From b2aa372fcf48b92af75c17786121d8399fe14b69 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Wed, 13 May 2026 12:51:09 +0200 Subject: [PATCH] refactor(recipes): fold memory into action_execution, drop interjection, fuse smolvla2 forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Recipe changes: * action_execution now bundles the memory update as a second assistant target gated on a new ``new_memory`` binding (fires only at subtask-boundary frames). No "Completed subtask: X" filler — the model emits the new subtask AND the updated memory back-to-back in one prefix. * user_interjection_response sub-recipe removed (current datasets don't have interjection / say() annotations). * Standalone memory_update sub-recipe removed (folded above). * Weights rebalanced: action_execution 0.85, ask_vqa_top/wrist 0.075 each (sums to 1.0). Runtime ``_msgs_for_memory`` updated to match the new boundary-frame prompt layout. Modeling: * SmolVLA2Policy now fuses the flow + text losses into a SINGLE backbone forward via ``_compute_fused_loss`` (one vlm_with_expert pass with [prefix, suffix] embeds, then both lm_head CE on lang slice + action_out_proj MSE on suffix). Mirrors pi052's existing ``_compute_all_losses_fused`` — saves one backbone pass per training step. Examples: * Removed the two training SLURM scaffolds; they were out-of-date with the recipe refactor. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/training/pi052_hirobot.slurm | 75 -------- examples/training/smolvla2_hirobot.slurm | 82 --------- .../configs/recipes/pi052_hirobot.yaml | 51 +++--- .../configs/recipes/smolvla2_hirobot.yaml | 55 +++--- .../policies/smolvla2/inference/steps.py | 38 ++-- .../policies/smolvla2/modeling_smolvla2.py | 170 ++++++++++++++++-- 6 files changed, 228 insertions(+), 243 deletions(-) delete mode 100644 examples/training/pi052_hirobot.slurm delete mode 100644 examples/training/smolvla2_hirobot.slurm diff --git a/examples/training/pi052_hirobot.slurm b/examples/training/pi052_hirobot.slurm deleted file mode 100644 index e0a902177..000000000 --- a/examples/training/pi052_hirobot.slurm +++ /dev/null @@ -1,75 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=pi052-hirobot -#SBATCH --partition=hopper-prod -#SBATCH --qos=high -#SBATCH --time=48:00:00 -#SBATCH --ntasks=1 -#SBATCH --gpus-per-task=8 - -# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe. -# -# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml), -# just on the PaliGemma 2B + Gemma-300m action-expert backbone the -# paper uses. The text head learns subtask prediction via cross- -# entropy on supervised spans; the action expert learns the flow -# field. Paper §IV.D mixes the two losses with α=10, which we encode -# as flow_loss_weight=10 / text_loss_weight=1. 
- -set -euo pipefail - -cd "${LEROBOT_ROOT:-$HOME/lerobot}" - -export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" -export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" -export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}" -export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}" -export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}" - -DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}" -POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}" -JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}" -NUM_PROCESSES="${NUM_PROCESSES:-8}" -BATCH_SIZE="${BATCH_SIZE:-32}" -STEPS="${STEPS:-15000}" -RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}" -OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}" - -echo "Training pi052 on $DATASET" -echo " GPUs: $NUM_PROCESSES" -echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))" -echo " steps: $STEPS" -echo " output: $OUTPUT_DIR" -echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1" -echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}" - -accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \ - -m lerobot.scripts.lerobot_train \ - --policy.type=pi052 \ - --policy.recipe_path=recipes/pi052_hirobot.yaml \ - --dataset.repo_id="$DATASET" \ - --dataset.revision=main \ - --dataset.video_backend=pyav \ - --output_dir="$OUTPUT_DIR" \ - --job_name="$JOB_NAME" \ - --policy.repo_id="$POLICY_REPO_ID" \ - --policy.compile_model=false \ - --policy.device=cuda \ - --policy.tokenizer_max_length=512 \ - --policy.text_loss_weight=1.0 \ - --policy.flow_loss_weight=10.0 \ - --policy.unfreeze_lm_head=true \ - --steps="$STEPS" \ - --policy.scheduler_decay_steps="$STEPS" \ - --batch_size="$BATCH_SIZE" \ - --wandb.enable=true \ - --wandb.disable_artifact=true \ - --wandb.project=hirobot \ - --log_freq=100 \ - --save_freq="$STEPS" \ - --num_workers=0 \ - --dataset.image_transforms.enable=true \ - 
--dataset.image_transforms.max_num_transforms=3 \ - --dataset.image_transforms.random_order=true \ - --policy.plan_dropout_prob=0.30 \ - --policy.memory_dropout_prob=0.30 \ - --policy.subtask_dropout_prob=0.20 diff --git a/examples/training/smolvla2_hirobot.slurm b/examples/training/smolvla2_hirobot.slurm deleted file mode 100644 index 2a3eac1f8..000000000 --- a/examples/training/smolvla2_hirobot.slurm +++ /dev/null @@ -1,82 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=smolvla2-hirobot -#SBATCH --partition=hopper-prod -#SBATCH --qos=high -#SBATCH --time=48:00:00 -#SBATCH --ntasks=1 -#SBATCH --gpus-per-task=8 - -# SmolVLA2 training on an annotated dataset. -# -# The high_level_subtask recipe (recipes/smolvla2_hirobot.yaml) was -# fixed in PR3 to supervise the LM head with the *current* active -# subtask span at every frame, not the next-span target which is -# empty on stable phases. With the old recipe the head learned to -# emit ``\n`` on every chunk boundary; the new one supervises a -# real, scene-grounded string at every frame. -# -# Two regularisers are still on: -# -# * --dataset.image_transforms.enable=true: torchvision-v2 -# ColorJitter + SharpnessJitter + RandomAffine per frame; default -# envelope (brightness ±20% etc). -# * --policy.{plan,memory,subtask}_dropout_prob: randomly drop the -# context messages carrying the named recipe binding so the model -# handles missing/stale context. Mirrors Pi0.7 §V.E. 
- -set -euo pipefail - -cd "${LEROBOT_ROOT:-$HOME/lerobot}" - -export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" -export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" -export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}" -export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}" -export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}" - -DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}" -POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool6}" -JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool6}" -NUM_PROCESSES="${NUM_PROCESSES:-8}" -BATCH_SIZE="${BATCH_SIZE:-32}" -STEPS="${STEPS:-15000}" -RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}" -OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_super_poulain_tool3_${STEPS}_${RUN_ID}}" - -echo "Training smolvla2 on $DATASET" -echo " GPUs: $NUM_PROCESSES" -echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))" -echo " steps: $STEPS" -echo " output: $OUTPUT_DIR" -echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}" - -accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \ - -m lerobot.scripts.lerobot_train \ - --policy.type=smolvla2 \ - --policy.recipe_path=recipes/smolvla2_hirobot.yaml \ - --dataset.repo_id="$DATASET" \ - --dataset.revision=main \ - --dataset.video_backend=pyav \ - --output_dir="$OUTPUT_DIR" \ - --job_name="$JOB_NAME" \ - --policy.repo_id="$POLICY_REPO_ID" \ - --policy.compile_model=false \ - --policy.device=cuda \ - --policy.tokenizer_max_length=512 \ - --policy.text_loss_weight=1.0 \ - --policy.flow_loss_weight=10.0 \ - --steps="$STEPS" \ - --policy.scheduler_decay_steps="$STEPS" \ - --batch_size="$BATCH_SIZE" \ - --wandb.enable=true \ - --wandb.disable_artifact=true \ - --wandb.project=hirobot \ - --log_freq=100 \ - --save_freq="$STEPS" \ - --num_workers=0 \ - --dataset.image_transforms.enable=true \ - --dataset.image_transforms.max_num_transforms=3 \ - 
--dataset.image_transforms.random_order=true \ - --policy.plan_dropout_prob=0.30 \ - --policy.memory_dropout_prob=0.30 \ - --policy.subtask_dropout_prob=0.20 diff --git a/src/lerobot/configs/recipes/pi052_hirobot.yaml b/src/lerobot/configs/recipes/pi052_hirobot.yaml index 4968ee80a..40a20387d 100644 --- a/src/lerobot/configs/recipes/pi052_hirobot.yaml +++ b/src/lerobot/configs/recipes/pi052_hirobot.yaml @@ -22,53 +22,50 @@ # Pi 0.7 §V.A — subtask in the prompt + flow on actions. # # Flavor 2 — event-driven text-only recipes -# ``memory_update``, ``user_interjection_response``, # ``ask_vqa_*``. Each handles a specific high-level event # with a TEXT output. ``if_present`` guards keep them from # firing on frames without the relevant annotation. +# +# Memory updates are folded INTO ``action_execution`` as a +# conditional second target gated on boundary frames — see +# ``smolvla2_hirobot.yaml`` for the rationale. The +# ``user_interjection_response`` recipe was dropped — the +# current datasets don't include interjection / say() annotations. blend: # ---------------------------------------------------------- # FLAVOR 1: action_execution (main path) + # + # Bundles memory updates inline. On most frames the binding + # ``new_memory: emitted_at(t, style=memory)`` returns None and + # only the subtask is supervised. On *boundary* frames (the + # exact timestamp a new memory was annotated — i.e. when a + # subtask just completed) the binding fires and the recipe + # supervises the new memory as a follow-up assistant turn, + # emitted directly after the subtask — no intervening + # user message separates the two outputs in the rendered prefix. 
# ---------------------------------------------------------- action_execution: - weight: 0.60 + weight: 0.85 + bindings: + new_memory: "emitted_at(t, style=memory)" messages: - role: user stream: high_level content: "${task}\nPlan: ${plan}\nMemory: ${memory}" - {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask} + # Memory-update tail — only renders at boundary frames where + # ``new_memory`` fires. The new memory is appended as a second + # assistant turn right after the subtask, with no intervening + # user filler: at a subtask boundary the model emits the new + # subtask AND the updated memory in one forward pass. + - {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory} # ---------------------------------------------------------- # FLAVOR 2: event-driven text-only paths # ---------------------------------------------------------- - memory_update: - weight: 0.10 - bindings: - prior_memory: "nth_prev(style=memory, offset=1)" - current_memory: "emitted_at(t, style=memory)" - completed_subtask: "nth_prev(style=subtask, offset=1)" - messages: - - {role: user, content: "${task}", stream: high_level} - - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory} - - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask} - - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory} - - user_interjection_response: - weight: 0.15 - bindings: - prior_plan: "nth_prev(style=plan, offset=1)" - current_plan: "emitted_at(t, style=plan)" - interjection: "emitted_at(t, style=interjection)" - speech: "emitted_at(t, role=assistant, tool_name=say)" - messages: - - {role: user, content: "${task}", stream: high_level} - - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan} - - {role: 
user, content: "${interjection}", stream: high_level, if_present: interjection} - - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech} - ask_vqa_top: weight: 0.075 bindings: diff --git a/src/lerobot/configs/recipes/smolvla2_hirobot.yaml b/src/lerobot/configs/recipes/smolvla2_hirobot.yaml index 97c786fee..d96bd168d 100644 --- a/src/lerobot/configs/recipes/smolvla2_hirobot.yaml +++ b/src/lerobot/configs/recipes/smolvla2_hirobot.yaml @@ -28,16 +28,17 @@ # Each handles a specific high-level event with a TEXT # output (no action supervision). They fire when the # binding for the event resolves to non-None: -# * ``memory_update``: at subtask boundary, predict new -# memory from task + prior memory + completed subtask. -# * ``user_interjection_response``: on user input, predict -# new plan + paired ``say()`` tool call. # * ``ask_vqa_top`` / ``ask_vqa_wrist``: answer a # camera-grounded visual question. # All use ``stream: high_level`` (no flow loss) and rely on # ``if_present`` guards so they only fire on frames where # the relevant event annotation is present. # +# ``memory_update`` is folded into Flavor 1 (gated on the +# ``new_memory`` binding at boundary frames). +# ``user_interjection_response`` was dropped — the current +# datasets don't include interjection / say() annotations. +# # How the chat tokenizer interprets the flavor split # --------------------------------------------------- # * predict_actions = bool(targets_by_stream.get("low_level")) @@ -50,44 +51,38 @@ blend: # ---------------------------------------------------------- # FLAVOR 1: action_execution (main path) + # + # Bundles memory updates inline. On most frames the binding + # ``new_memory: emitted_at(t, style=memory)`` returns None and + # only the subtask is supervised. On *boundary* frames (the + # exact timestamp a new memory was annotated — i.e. 
when a + # subtask just completed) the binding fires and the recipe + # supervises the new memory as a follow-up assistant turn, + # emitted directly after the subtask — no intervening + # user message separates the two outputs in the chat sequence. Mirrors the + # behaviour of the old standalone ``memory_update`` recipe + # but keeps everything inside the unified action_execution. # ---------------------------------------------------------- action_execution: - weight: 0.60 + weight: 0.85 + bindings: + new_memory: "emitted_at(t, style=memory)" messages: - role: user stream: high_level content: "${task}\nPlan: ${plan}\nMemory: ${memory}" - {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask} + # Memory-update tail — only renders at boundary frames where + # ``new_memory`` fires. The new memory is appended as a second + # assistant turn right after the subtask, with no intervening + # user filler: at a subtask boundary the model emits the new + # subtask AND the updated memory in one forward pass. 
+ - {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory} # ---------------------------------------------------------- # FLAVOR 2: event-driven text-only paths # ---------------------------------------------------------- - memory_update: - weight: 0.10 - bindings: - prior_memory: "nth_prev(style=memory, offset=1)" - current_memory: "emitted_at(t, style=memory)" - completed_subtask: "nth_prev(style=subtask, offset=1)" - messages: - - {role: user, content: "${task}", stream: high_level} - - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory} - - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask} - - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory} - - user_interjection_response: - weight: 0.15 - bindings: - prior_plan: "nth_prev(style=plan, offset=1)" - current_plan: "emitted_at(t, style=plan)" - interjection: "emitted_at(t, style=interjection)" - speech: "emitted_at(t, role=assistant, tool_name=say)" - messages: - - {role: user, content: "${task}", stream: high_level} - - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan} - - {role: user, content: "${interjection}", stream: high_level, if_present: interjection} - - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech} - ask_vqa_top: weight: 0.075 bindings: diff --git a/src/lerobot/policies/smolvla2/inference/steps.py b/src/lerobot/policies/smolvla2/inference/steps.py index 0eef3b493..e49ef3355 100644 --- a/src/lerobot/policies/smolvla2/inference/steps.py +++ b/src/lerobot/policies/smolvla2/inference/steps.py @@ -738,24 +738,30 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]: def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, 
Any]]: - """``memory_update`` recipe layout.""" - msgs: list[dict[str, Any]] = [ - {"role": "user", "content": state.get("task") or ""} - ] + """Memory-update prompt — matches the boundary-frame tail of + ``action_execution`` in the v2 recipes. + + Recipe layout on a boundary frame: + user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}" + assistant: "${subtask}" + assistant: → predicts new memory + + At inference we fire this when the runtime detects a subtask + transition; the freshly-predicted subtask lives in + ``state['current_subtask']``. No "Completed subtask: X" user + filler — the second assistant turn is generated immediately + after the subtask turn. + """ + head_parts = [state.get("task") or ""] + if state.get("current_plan"): + head_parts.append(f"Plan: {state['current_plan']}") if state.get("current_memory"): - msgs.append( - { - "role": "assistant", - "content": f"Previous memory: {state['current_memory']}", - } - ) + head_parts.append(f"Memory: {state['current_memory']}") + msgs: list[dict[str, Any]] = [ + {"role": "user", "content": "\n".join(head_parts)}, + ] if state.get("current_subtask"): - msgs.append( - { - "role": "user", - "content": f"Completed subtask: {state['current_subtask']}", - } - ) + msgs.append({"role": "assistant", "content": state["current_subtask"]}) return msgs diff --git a/src/lerobot/policies/smolvla2/modeling_smolvla2.py b/src/lerobot/policies/smolvla2/modeling_smolvla2.py index e0235e53f..cb7f3d8a3 100644 --- a/src/lerobot/policies/smolvla2/modeling_smolvla2.py +++ b/src/lerobot/policies/smolvla2/modeling_smolvla2.py @@ -133,18 +133,29 @@ class SmolVLA2Policy(SmolVLAPolicy): device = batch[OBS_STATE].device total = torch.zeros((), device=device, dtype=torch.float32) - # ------------------------------------------------------------ - # Flow loss path — only when at least one sample wants actions. 
- # ------------------------------------------------------------ - run_flow = self.config.flow_loss_weight > 0 and ( - not has_per_sample_routing or bool(predict_actions_t.any().item()) + run_flow = ( + self.config.flow_loss_weight > 0 + and ACTION in batch + and (not has_per_sample_routing or bool(predict_actions_t.any().item())) ) - if run_flow and ACTION in batch: + + # ------------------------------------------------------------ + # Fused path — one backbone forward for flow + text together. + # ------------------------------------------------------------ + if run_flow and has_text_data: + flow_loss, text_loss, flow_diag = self._compute_fused_loss( + batch, text_labels, predict_actions_t, noise=noise, time=time + ) + total = total + self.config.flow_loss_weight * flow_loss + total = total + self.config.text_loss_weight * text_loss + loss_dict["flow_loss"] = float(flow_loss.detach().item()) + loss_dict["text_loss"] = float(text_loss.detach().item()) + for k, v in flow_diag.items(): + loss_dict[f"flow_{k}"] = v + elif run_flow: per_sample_flow, flow_diag = super().forward( batch, noise=noise, time=time, reduction="none" ) - # ``per_sample_flow`` has shape (B,) from the SmolVLA - # reduction="none" branch. if has_per_sample_routing: mask = predict_actions_t.to(per_sample_flow.dtype) masked = per_sample_flow * mask @@ -156,11 +167,7 @@ class SmolVLA2Policy(SmolVLAPolicy): loss_dict["flow_loss"] = float(flow_loss.detach().item()) for k, v in flow_diag.items(): loss_dict[f"flow_{k}"] = v - - # ------------------------------------------------------------ - # Text loss path — prefix-only forward → lm_head → CE. 
- # ------------------------------------------------------------ - if has_text_data: + elif has_text_data: text_loss = self._compute_text_loss(batch, text_labels) total = total + self.config.text_loss_weight * text_loss loss_dict["text_loss"] = float(text_loss.detach().item()) @@ -253,6 +260,143 @@ class SmolVLA2Policy(SmolVLAPolicy): ) return loss / valid_labels.sum().clamp(min=1) + # ------------------------------------------------------------------ + # Fused flow + text loss (single backbone forward) + # ------------------------------------------------------------------ + + def _compute_fused_loss( + self, + batch: dict[str, Tensor], + text_labels: Tensor, + predict_actions_t: Tensor | None, + noise: Tensor | None = None, + time: Tensor | None = None, + ) -> tuple[Tensor, Tensor, dict[str, Any]]: + """One backbone forward → both flow MSE and text CE. + + Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one + ``vlm_with_expert`` call) but captures **both** outputs: + + * ``prefix_out[:, lang_start:lang_end]`` → ``lm_head`` → CE on + ``text_labels`` (same slicing as ``_compute_text_loss``). + * ``suffix_out[:, -chunk_size:]`` → ``action_out_proj`` → flow + MSE against ``noise - actions`` (same as the parent forward). + + Saves one backbone pass per training step vs. running the flow + and text paths separately — same trick PI052Policy uses in + ``_compute_all_losses_fused``. 
+ """ + from ..smolvla.modeling_smolvla import resize_with_pad # noqa: F401 (kept for parity) + + cfg = self.config + if cfg.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens = batch[OBS_LANGUAGE_TOKENS] + lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK] + actions = self.prepare_action(batch) + + inner = self.model + if noise is None: + noise = inner.sample_noise(actions.shape, actions.device) + if time is None: + time = inner.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + out_pair, _ = inner.vlm_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + prefix_out, suffix_out = out_pair[0], out_pair[1] + if prefix_out is None or suffix_out is None: + raise RuntimeError( + "SmolVLA2: fused forward expected both prefix and suffix " + "hidden states from vlm_with_expert." 
+ ) + + # ---------------- flow loss (per-sample maskable) ---------------- + chunk = cfg.chunk_size + suffix_chunk = suffix_out[:, -chunk:].to(torch.float32) + v_t = inner.action_out_proj(suffix_chunk) + losses = F.mse_loss(u_t, v_t, reduction="none") + + original_action_dim = cfg.action_feature.shape[0] + losses = losses[:, :, :original_action_dim] + flow_diag = {"losses_after_forward": float(losses.detach().mean().item())} + + actions_is_pad = batch.get("action_is_pad") + if actions_is_pad is not None: + in_episode = ~actions_is_pad + losses = losses * in_episode.unsqueeze(-1) + flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item()) + + losses = losses[:, :, : cfg.max_action_dim] + flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item()) + + per_sample_flow = losses.mean(dim=(1, 2)) + if predict_actions_t is not None: + mask = predict_actions_t.to(per_sample_flow.dtype) + flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0) + else: + flow_loss = per_sample_flow.mean() + + # ---------------- text loss (lang slice of prefix) --------------- + num_lang = lang_tokens.shape[1] + state_for_dim = state if state.ndim >= 2 else state[:, None] + num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1 + if num_state < 1: + num_state = 1 + prefix_len = prefix_out.shape[1] + lang_end = prefix_len - num_state + lang_start = lang_end - num_lang + if lang_start < 0 or lang_end > prefix_len: + raise RuntimeError( + f"SmolVLA2: fused forward could not locate lang range " + f"(prefix_len={prefix_len}, num_lang={num_lang}, " + f"num_state={num_state})." 
+ ) + vlm = inner.vlm_with_expert.vlm + lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype) + logits = vlm.lm_head(lang_hidden) + + if text_labels.shape[1] != num_lang: + common = min(text_labels.shape[1], num_lang) + logits = logits[:, :common] + text_labels = text_labels[:, :common] + + shift_logits = logits[:, :-1, :].contiguous() + shift_labels = text_labels[:, 1:].contiguous().long() + valid_labels = shift_labels != -100 + ce = F.cross_entropy( + shift_logits.reshape(-1, shift_logits.shape[-1]), + shift_labels.reshape(-1), + ignore_index=-100, + reduction="sum", + ) + text_loss = ce / valid_labels.sum().clamp(min=1) + + return flow_loss, text_loss, flow_diag + # ------------------------------------------------------------------ # Inference: text generation # ------------------------------------------------------------------