mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
feat(policies): scaffold smolvla2 (smolvla + lm_head re-enabled)
PR 3 of the steerable-annotation plan, retargeted from Pi0.5 to SmolVLA
because the recipe stack (PR 1 + PR 2) outputs HF/TRL-compatible chat
messages, which a chat-pretrained backbone consumes natively. SmolVLA
strips the SmolVLM ``lm_head``, though, so it can only do flow-matching
action prediction. SmolVLA2 keeps the LM head so the same model can train
on the full Hi Robot / MEM / ECoT blend defined in the plan:
  * action-only sub-recipes (low_level_execution) → flow loss
  * text-only sub-recipes (memory_update / ask_vqa /
    user_interjection_response) → CE loss on lm_head
  * mixed sub-recipes → both summed
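
The routing in the list above can be sketched as below. This is a minimal illustration under stated assumptions, not code from this commit (the text-loss path is not wired up yet). `combine_losses` is a hypothetical helper, and treating an absent head as contributing nothing is an assumption. Only the `flow_loss_weight` / `text_loss_weight` names come from the config added here.

```python
from typing import Optional


def combine_losses(
    flow_loss: Optional[float],
    text_loss: Optional[float],
    flow_loss_weight: float = 1.0,
    text_loss_weight: float = 1.0,
) -> float:
    """Weighted sum of whichever loss terms a sample produced.

    Hypothetical helper: a head that did not run passes ``None`` and
    contributes nothing to the total.
    """
    total = 0.0
    if flow_loss is not None:
        total += flow_loss_weight * flow_loss
    if text_loss is not None:
        total += text_loss_weight * text_loss
    return total


# action-only sample: only the flow head ran
action_only = combine_losses(flow_loss=0.5, text_loss=None)
# text-only sample: only the LM head ran
text_only = combine_losses(flow_loss=None, text_loss=2.0, text_loss_weight=0.5)
# mixed sample: both heads ran, losses are summed with their weights
mixed = combine_losses(flow_loss=0.5, text_loss=2.0, text_loss_weight=0.5)
```

Setting `text_loss_weight=0` in this sketch removes the text term, which matches the config's description of reverting to flow-only training.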
This first commit lays down the structural scaffold:
- ``src/lerobot/policies/smolvla2/`` — new package with thin subclasses
  of ``SmolVLAConfig`` / ``SmolVLAPolicy`` so we don't fork the 900-line
  modeling code. ``SmolVLA2Config`` adds ``recipe_path``,
  ``apply_chat_template``, ``text_loss_weight``, ``flow_loss_weight``,
  and ``unfreeze_lm_head``. ``SmolVLA2Policy`` unfreezes the SmolVLM
  ``lm_head`` (and the surrounding norm + last text-model layer SmolVLA
  freezes) when ``unfreeze_lm_head=True`` and ``text_loss_weight>0``.
- ``factory.py`` registers ``smolvla2`` in ``get_policy_class``,
  ``make_policy_config``, and the pre/post-processor builder. Important:
  the ``smolvla2`` branch lives BEFORE the ``isinstance(config,
  SmolVLAConfig)`` check because ``SmolVLA2Config`` subclasses
  ``SmolVLAConfig`` — without the ordering, SmolVLA2 would silently
  pick up SmolVLA's processor.
- ``configs/recipes/smolvla2_hirobot.yaml`` — canonical Hi Robot blend
  for SmolVLA2. Same shape as ``pi05_hirobot.yaml`` (PR 1) so the
  recipe stack stays uniform across policy backbones.
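
The ordering constraint on the factory branch can be reproduced with stand-in classes. The sketch below does not import lerobot; `ToySmolVLAConfig`, `ToySmolVLA2Config` and both `pick_*` functions are hypothetical names used only to show why a subclass-specific branch has to precede an `isinstance` check on the parent class.

```python
class ToySmolVLAConfig:
    """Stand-in for the parent config (not the real lerobot class)."""
    type = "smolvla"


class ToySmolVLA2Config(ToySmolVLAConfig):
    """Stand-in for the subclass config (not the real lerobot class)."""
    type = "smolvla2"


def pick_parent_check_first(cfg) -> str:
    # The isinstance check on the parent also matches the subclass,
    # so the smolvla2 branch below is unreachable for subclass configs.
    if isinstance(cfg, ToySmolVLAConfig):
        return "smolvla processor"
    elif cfg.type == "smolvla2":
        return "smolvla2 processor"
    raise ValueError(cfg.type)


def pick_specific_branch_first(cfg) -> str:
    # Testing the more specific type first gives the intended dispatch.
    if cfg.type == "smolvla2":
        return "smolvla2 processor"
    elif isinstance(cfg, ToySmolVLAConfig):
        return "smolvla processor"
    raise ValueError(cfg.type)


wrong = pick_parent_check_first(ToySmolVLA2Config())
right = pick_specific_branch_first(ToySmolVLA2Config())
```

With the parent check first, the subclass config is routed to the parent's processor without any error, which is the silent failure the commit message warns about.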
Behaviour today is identical to SmolVLA: the modeling forward
delegates to ``SmolVLAPolicy.forward`` and the processor delegates to
``make_smolvla_pre_post_processors``. The next commit on this branch
adds the chat-template processor + ``text_labels`` / ``predict_actions``
batch keys; the commit after that wires the actual text-loss path
through ``vlm.lm_head``.
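
The two batch keys named above can be sketched in plain Python. This is an illustration under assumptions, not the processor that lands in the next commit: `build_text_labels`, `derive_predict_actions` and the per-token `token_message_index` list are hypothetical, since how the real processor maps tokens back to messages is not shown here. The `-100` fill value and the `bool(targets_by_stream.get("low_level"))` rule are taken from the docstrings in this diff; `-100` is also the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the CE loss


def build_text_labels(token_ids, token_message_index, target_message_indices):
    """Keep token ids only where the token belongs to a target message.

    token_message_index[i] is the index of the message that produced
    token i (hypothetical bookkeeping, assumed available after tokenizing).
    """
    targets = set(target_message_indices)
    return [
        tok if msg_idx in targets else IGNORE_INDEX
        for tok, msg_idx in zip(token_ids, token_message_index)
    ]


def derive_predict_actions(targets_by_stream):
    """True if the sample has at least one low_level target."""
    return bool(targets_by_stream.get("low_level"))


# Three messages: tokens 0-2 from message 0, 3-4 from message 1, 5-7 from message 2.
token_ids = [11, 12, 13, 21, 22, 31, 32, 33]
token_message_index = [0, 0, 0, 1, 1, 2, 2, 2]

labels = build_text_labels(token_ids, token_message_index, target_message_indices=[2])
text_only_sample = derive_predict_actions({"high_level": [2]})
action_sample = derive_predict_actions({"low_level": [1]})
```

Here only the tokens of message 2 keep their ids in `labels`; every other position is set to `IGNORE_INDEX`.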
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -0,0 +1,88 @@
# SmolVLA2 canonical training recipe — Hi Robot / MEM / ECoT blend.
#
# Same blend shape as pi05_hirobot.yaml. SmolVLA2 differs from Pi0.5 in
# how the renderer's output is consumed:
#
#   - SmolVLA2 calls SmolVLM's tokenizer.apply_chat_template(messages,
#     tools=DEFAULT_TOOLS) on the rendered messages, since SmolVLM is a
#     chat-pretrained backbone.
#   - The processor builds a `text_labels` tensor that masks every token
#     except those belonging to messages whose index is in
#     `target_message_indices`. Cross-entropy on those positions trains
#     the LM head.
#   - `predict_actions = bool(targets_by_stream.get("low_level"))` —
#     same convention as Pi0.5. ``low_level_execution`` is the only
#     branch that runs the action expert / flow head.

blend:

  memory_update:
    weight: 0.10
    bindings:
      prior_memory: "nth_prev(style=memory, offset=1)"
      current_memory: "emitted_at(t, style=memory)"
      completed_subtask: "nth_prev(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
      - {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
      - {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}

  user_interjection_response:
    weight: 0.16
    bindings:
      prior_plan: "nth_prev(style=plan, offset=1)"
      current_plan: "emitted_at(t, style=plan)"
      interjection: "emitted_at(t, style=interjection)"
      speech: "emitted_at(t, role=assistant, tool_name=say)"
    messages:
      - {role: user, content: "${task}", stream: high_level}
      - {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
      - {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
      - {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}

  high_level_subtask:
    weight: 0.15
    bindings:
      next_subtask: "nth_next(style=subtask, offset=1)"
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: user, content: "Current subtask: ${subtask}", stream: high_level, if_present: subtask}
      - {role: assistant, content: "${next_subtask}", stream: high_level, target: true}

  low_level_execution:
    weight: 0.35
    messages:
      - {role: user, content: "${task}\nPlan: ${plan}\nMemory: ${memory}", stream: high_level}
      - {role: assistant, content: "${subtask}", stream: low_level, target: true}

  # Per-camera VQA sub-recipes (PR 1's view-dependent style routing).
  # Adjust the camera keys (and add more sub-recipes) to match the
  # cameras present on your dataset.
  ask_vqa_top:
    weight: 0.10
    bindings:
      vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.top)"
      vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"
    messages:
      - role: user
        stream: high_level
        if_present: vqa_query
        content:
          - {type: image, feature: observation.images.top}
          - {type: text, text: "${vqa_query}"}
      - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}

  ask_vqa_wrist:
    weight: 0.10
    bindings:
      vqa_query: "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"
      vqa: "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"
    messages:
      - role: user
        stream: high_level
        if_present: vqa_query
        content:
          - {type: image, feature: observation.images.wrist}
          - {type: text, text: "${vqa_query}"}
      - {role: assistant, content: "${vqa}", stream: high_level, target: true, if_present: vqa}
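
The `weight` values in the recipe above can be read as relative sampling weights. The sketch below copies the six weights from the YAML and draws sub-recipes with the standard library; that the real blend sampler works this way is an assumption, and `sample_sub_recipe` is a hypothetical helper. Note that the six weights add up to about 0.96 rather than 1.0, which is harmless for `random.choices` because it only uses relative weights.

```python
import random
from collections import Counter

# Weights copied from the recipe above.
BLEND_WEIGHTS = {
    "memory_update": 0.10,
    "user_interjection_response": 0.16,
    "high_level_subtask": 0.15,
    "low_level_execution": 0.35,
    "ask_vqa_top": 0.10,
    "ask_vqa_wrist": 0.10,
}


def sample_sub_recipe(rng: random.Random) -> str:
    """Draw one sub-recipe name in proportion to its weight (hypothetical)."""
    names = list(BLEND_WEIGHTS)
    weights = list(BLEND_WEIGHTS.values())
    return rng.choices(names, weights=weights, k=1)[0]


total_weight = sum(BLEND_WEIGHTS.values())  # about 0.96, not 1.0
# Explicit normalization, only needed if probabilities are wanted.
probabilities = {name: w / total_weight for name, w in BLEND_WEIGHTS.items()}

rng = random.Random(0)
counts = Counter(sample_sub_recipe(rng) for _ in range(10_000))
```

After normalization `low_level_execution` has the largest share, so under this reading most samples still train the flow head.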
@@ -140,6 +140,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
        from .smolvla.modeling_smolvla import SmolVLAPolicy

        return SmolVLAPolicy
    elif name == "smolvla2":
        from .smolvla2.modeling_smolvla2 import SmolVLA2Policy

        return SmolVLA2Policy
    elif name == "sarm":
        from .sarm.modeling_sarm import SARMRewardModel

@@ -200,6 +204,10 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return SACConfig(**kwargs)
    elif policy_type == "smolvla":
        return SmolVLAConfig(**kwargs)
    elif policy_type == "smolvla2":
        from .smolvla2.configuration_smolvla2 import SmolVLA2Config

        return SmolVLA2Config(**kwargs)
    elif policy_type == "reward_classifier":
        return RewardClassifierConfig(**kwargs)
    elif policy_type == "groot":
@@ -386,6 +394,17 @@ def make_pre_post_processors(
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif policy_cfg.type == "smolvla2":
        # NOTE: SmolVLA2Config subclasses SmolVLAConfig, so this branch
        # MUST come before the SmolVLAConfig isinstance check below
        # (otherwise SmolVLA2 would silently pick up SmolVLA's processor).
        from .smolvla2.processor_smolvla2 import make_smolvla2_pre_post_processors

        processors = make_smolvla2_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )

    elif isinstance(policy_cfg, SmolVLAConfig):
        from .smolvla.processor_smolvla import make_smolvla_pre_post_processors

@@ -0,0 +1,38 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 — SmolVLA with the SmolVLM language head re-enabled.

SmolVLA strips the LM head from the SmolVLM backbone because it only does
flow-matching action prediction. SmolVLA2 keeps the LM head so the same
model can train on the full Hi Robot / MEM / ECoT message blend defined in
the steerable annotation plan (PR1 + PR2):

* action-only sub-recipes (e.g. ``low_level_execution``) → flow loss
* text-only sub-recipes (e.g. ``memory_update``, ``ask_vqa``,
  ``user_interjection_response``, ``high_level_subtask``) → CE loss on
  ``lm_head`` over the recipe's target message tokens
* mixed sub-recipes → both losses summed (weighted)

The ``predict_actions`` toggle follows the Pi0.5 convention from Section
I.7 of the plan: ``True`` if any ``low_level`` target is present in the
sample, else ``False``.

This package is a thin subclass of ``lerobot.policies.smolvla`` so most of
the model code stays in one place — only the dual-loss path and the
chat-template processor live here.
"""

from .configuration_smolvla2 import SmolVLA2Config

__all__ = ["SmolVLA2Config"]
@@ -0,0 +1,97 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

from lerobot.configs import PreTrainedConfig

from ..smolvla.configuration_smolvla import SmolVLAConfig


@PreTrainedConfig.register_subclass("smolvla2")
@dataclass
class SmolVLA2Config(SmolVLAConfig):
    """SmolVLA2 — SmolVLA with the underlying SmolVLM language head re-enabled.

    SmolVLA strips the LM head from the SmolVLM backbone because it only
    needs flow-matching action prediction. SmolVLA2 keeps the LM head so the
    same model can train on:

    * **action-only sub-recipes** (e.g. ``low_level_execution``) — flow loss
      on the action expert, same as SmolVLA. ``predict_actions=True``.
    * **text-only sub-recipes** (e.g. ``memory_update`` / ``ask_vqa`` /
      ``user_interjection_response`` / ``high_level_subtask``) —
      cross-entropy loss on the LM head over the recipe's target message
      tokens. Skips the flow head entirely. ``predict_actions=False``.
    * **mixed sub-recipes** — both heads run, losses summed (weighted).

    The split is controlled by
    ``predict_actions = bool(targets_by_stream.get("low_level"))`` per the
    Pi0.5 convention in the steerable annotation plan (Section I.7),
    implemented inside the processor / forward path. Recipes drive it via
    ``stream`` + ``target`` metadata.

    Compared to ``SmolVLAConfig`` this adds:

    - ``recipe_path``: path to a ``TrainingRecipe`` YAML (loaded by the
      train script). When ``None``, SmolVLA2 falls back to the SmolVLA
      task-only path so unannotated datasets still work.
    - ``text_loss_weight`` / ``flow_loss_weight``: relative weights when
      both losses are active in a single sample.
    - ``unfreeze_lm_head``: must be ``True`` for the text head to learn —
      SmolVLA freezes ``lm_head`` to "avoid unused params issues" and we
      need to undo that for SmolVLA2.
    - ``train_expert_only=False`` by default, since the VLM body now also
      participates in text-target gradients.
    """

    # Recipe / language stack ---------------------------------------------
    recipe_path: str | None = "recipes/smolvla2_hirobot.yaml"
    """Path (absolute or relative to ``src/lerobot/configs/``) to a
    ``TrainingRecipe`` YAML. The default points at the canonical Hi Robot
    blend shipped alongside SmolVLA2. Set to ``None`` to disable recipe
    rendering and fall back to SmolVLA's single-task prompt path
    (unannotated datasets keep working that way)."""

    apply_chat_template: bool = True
    """Apply the SmolVLM tokenizer's chat template to the rendered messages
    before tokenizing. SmolVLM's backbone is chat-pretrained, so this
    matches its training distribution."""

    # Loss weights --------------------------------------------------------
    text_loss_weight: float = 1.0
    """Weight on the LM-head cross-entropy term. Set to ``0`` to disable
    text training entirely (reverts to flow-only / SmolVLA behaviour)."""

    flow_loss_weight: float = 1.0
    """Weight on the action-expert flow-matching term."""

    # Backbone training ---------------------------------------------------
    unfreeze_lm_head: bool = True
    """Whether to unfreeze the SmolVLM ``lm_head`` (and the immediately
    preceding norm + last text-model layer that SmolVLA freezes). Must be
    ``True`` for the text head to learn. Setting this to ``False``
    effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
    which is occasionally useful for ablations."""

    def __post_init__(self) -> None:
        super().__post_init__()
        # Backbone needs gradients flowing through its text path when the
        # LM head is producing supervised text. Override the SmolVLA
        # default (`train_expert_only=True`) unless the user explicitly
        # opts out of text training via `text_loss_weight=0` or
        # `unfreeze_lm_head=False`.
        if self.text_loss_weight > 0 and self.unfreeze_lm_head:
            # NOTE: `__post_init__` runs after the field values are set,
            # so this assignment also replaces an explicitly passed
            # `train_expert_only=True` whenever SmolVLA2 is actually
            # training a text head.
            self.train_expert_only = False
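
The effect of the `__post_init__` hook above can be checked with a toy pair of dataclasses. This sketch does not import lerobot: `ToyBaseConfig` and `ToyConfig` are hypothetical stand-ins that copy only the three fields involved. It relies on the standard dataclass behaviour that `__post_init__` runs after the generated `__init__` has assigned all fields.

```python
from dataclasses import dataclass


@dataclass
class ToyBaseConfig:
    """Stand-in for the parent config; expert-only training by default."""
    train_expert_only: bool = True

    def __post_init__(self) -> None:
        pass


@dataclass
class ToyConfig(ToyBaseConfig):
    """Stand-in for the subclass config with the same override rule."""
    text_loss_weight: float = 1.0
    unfreeze_lm_head: bool = True

    def __post_init__(self) -> None:
        super().__post_init__()
        # Same condition as in the diff: text training is active.
        if self.text_loss_weight > 0 and self.unfreeze_lm_head:
            self.train_expert_only = False


default_cfg = ToyConfig()
explicit_cfg = ToyConfig(train_expert_only=True)  # explicit value is replaced
flow_only_cfg = ToyConfig(text_loss_weight=0.0)   # opt out of text training
frozen_head_cfg = ToyConfig(unfreeze_lm_head=False)
```

In this toy version, passing `train_expert_only=True` has no effect while text training is active; only `text_loss_weight=0` or `unfreeze_lm_head=False` keeps the parent default.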
@@ -0,0 +1,119 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 modeling — dual-head subclass of SmolVLAPolicy.

This module defines :class:`SmolVLA2Policy`, which extends SmolVLA with:

* an unfrozen SmolVLM ``lm_head`` so language tokens can be supervised,
* a forward path that routes to the flow head, the text head, or both,
  driven by ``batch["predict_actions"]`` and ``batch["text_labels"]``.

The text-head computation itself is NOT wired up in this scaffold commit
(the processor doesn't yet produce ``text_labels`` either). This file is
the structural placeholder that:

1. registers the ``SmolVLA2Policy`` class with the right config name so
   ``policies/factory.py`` can build it,
2. unfreezes ``lm_head`` at construction time when the config asks for it
   (otherwise SmolVLA's ``train_expert_only`` freezes it again on every
   ``train()`` call),
3. forwards to ``SmolVLAPolicy.forward`` so behaviour is identical to
   SmolVLA when no text labels are present — i.e. existing SmolVLA
   training scripts keep working.

The next commit on this branch fills in the actual text-loss path.
"""

from __future__ import annotations

from dataclasses import fields
from typing import Any

import torch
from torch import Tensor

from ..smolvla.modeling_smolvla import SmolVLAPolicy
from .configuration_smolvla2 import SmolVLA2Config


class SmolVLA2Policy(SmolVLAPolicy):
    """SmolVLA + re-enabled SmolVLM language head.

    Compatible drop-in for ``SmolVLAPolicy`` from a checkpoint or factory
    perspective. Behaviourally identical to SmolVLA until the text-head
    code path lands in the next commit on this branch.
    """

    config_class = SmolVLA2Config
    name = "smolvla2"

    def __init__(self, config: SmolVLA2Config, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
        if not isinstance(config, SmolVLA2Config):
            # Allow loading a SmolVLA checkpoint into a SmolVLA2 model by
            # widening the config type — the new fields fall back to their
            # defaults, which preserves the existing SmolVLA behaviour.
            # ``fields()`` skips ClassVar / InitVar pseudo-fields, and the
            # ``f.init`` filter drops fields the constructor does not accept.
            config = SmolVLA2Config(**{
                f.name: getattr(config, f.name)
                for f in fields(config)
                if f.init
            })
        super().__init__(config, dataset_stats=dataset_stats)
        if config.unfreeze_lm_head and config.text_loss_weight > 0:
            self._unfreeze_lm_head()

    # ------------------------------------------------------------------
    # Backbone surgery
    # ------------------------------------------------------------------

    def _unfreeze_lm_head(self) -> None:
        """Re-enable gradients on the SmolVLM ``lm_head`` (and the bits of
        the text path SmolVLA freezes) so the text-loss can flow back.

        SmolVLA's ``SmolVLMWithExpertModel.set_requires_grad`` freezes
        ``lm_head``, ``text_model.model.norm.weight``, and the last
        ``text_model.layers.<N-1>`` block. We undo that selectively when
        text training is enabled.
        """
        vlm_with_expert = getattr(self.model, "vlm_with_expert", None)
        if vlm_with_expert is None:
            return
        vlm = getattr(vlm_with_expert, "vlm", None)
        if vlm is None:
            return
        # NOTE: only ``lm_head`` and the final norm are matched here; the
        # last ``text_model.layers.<N-1>`` block named in the docstring is
        # not re-enabled by this filter.
        for name, param in vlm.named_parameters():
            if (
                "lm_head" in name
                or "text_model.model.norm.weight" in name
            ):
                param.requires_grad = True

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def forward(
        self,
        batch: dict[str, Tensor],
        noise: Tensor | None = None,
        time: Tensor | None = None,
        reduction: str = "mean",
    ) -> tuple[Tensor, dict[str, Any]]:
        """Forward pass with optional text-head loss.

        SCAFFOLD: forwards directly to ``SmolVLAPolicy.forward``. The
        actual text-loss / dual-head routing lands in the next commit on
        this branch — it will read ``batch["text_labels"]`` and
        ``batch["predict_actions"]`` (both produced by the SmolVLA2
        processor) to decide which head(s) to run.
        """
        return super().forward(batch, noise=noise, time=time, reduction=reduction)
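
The name-based filter in `_unfreeze_lm_head` can be exercised without PyTorch. The sketch below uses a hypothetical `ToyParam` class in place of real parameters, and the three parameter names are illustrative assumptions (the actual names inside SmolVLM, including the layer index, are not shown in this diff). It applies the same two substring patterns as the method above.

```python
from dataclasses import dataclass


@dataclass
class ToyParam:
    """Stand-in for a model parameter; starts frozen."""
    requires_grad: bool = False


# The two substrings matched by the method in the diff.
UNFREEZE_PATTERNS = ("lm_head", "text_model.model.norm.weight")


def unfreeze_matching(named_params, patterns=UNFREEZE_PATTERNS):
    """Set requires_grad=True on every parameter whose name contains a pattern.

    Returns the names that were touched, in iteration order.
    """
    touched = []
    for name, param in named_params:
        if any(pattern in name for pattern in patterns):
            param.requires_grad = True
            touched.append(name)
    return touched


# Hypothetical parameter names, chosen only to illustrate the matching.
params = {
    "lm_head.weight": ToyParam(),
    "model.text_model.model.norm.weight": ToyParam(),
    "model.text_model.model.layers.15.self_attn.q_proj.weight": ToyParam(),
}

touched = unfreeze_matching(params.items())
still_frozen = [name for name, p in params.items() if not p.requires_grad]
```

Under these assumed names the transformer-block parameter stays frozen, because neither pattern occurs in its name.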
@@ -0,0 +1,55 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SmolVLA2 processor pipelines.

SCAFFOLD: this currently delegates to SmolVLA's processor. The next
commit on this branch replaces that with a chat-template aware pipeline:

    RenderMessagesStep (PR1) → SmolVLA2ChatTokenizerStep → existing SmolVLA
    normalization / device steps.

The chat tokenizer step will:

* take ``messages`` / ``message_streams`` / ``target_message_indices``
  from the rendered sample,
* call ``apply_chat_template(messages, tools=DEFAULT_TOOLS, ...)`` on the
  SmolVLM tokenizer,
* tokenize the resulting prompt,
* build a ``text_labels`` tensor with ``-100`` everywhere except the
  token positions belonging to messages whose index is in
  ``target_message_indices``,
* derive ``predict_actions = bool(targets_by_stream.get("low_level"))``.
"""

from __future__ import annotations

from typing import Any

import torch

from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
from .configuration_smolvla2 import SmolVLA2Config


def make_smolvla2_pre_post_processors(
    config: SmolVLA2Config,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
    """Build SmolVLA2's pre/post-processor pipelines.

    SCAFFOLD: just delegates to ``make_smolvla_pre_post_processors`` so
    SmolVLA2 inherits SmolVLA's tokenization + normalization for now.
    The recipe-driven chat-template rendering arrives in the next commit.
    """
    return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)