mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
feat(smolvla2): chat-template processor + label mask + predict_actions
Wires PR 1's recipe stack into the SmolVLA2 pipeline so multi-target
sub-recipes (memory_update, ask_vqa, user_interjection_response,
high_level_subtask) carry meaningful supervision through to the model.
- New ``chat_processor_smolvla2.py`` with
``SmolVLA2ChatTokenizerStep``: reads ``messages`` /
``message_streams`` / ``target_message_indices`` from the rendered
sample (PR 1 ``RenderMessagesStep``), calls
``apply_chat_template(messages, tools=DEFAULT_TOOLS, ...)`` on the
SmolVLM tokenizer, and writes:
OBS_LANGUAGE_TOKENS / _ATTENTION_MASK ← chat-templated prompt
text_labels ← -100 except target msg tokens
predict_actions ← True iff any low_level target
Builds the label mask robustly by re-rendering the chat through
each target's prefix and reading off the prefix length — same
tokenizer, same tools, so the prefix tokens are guaranteed to be
a prefix of the full sequence. Image/video content blocks
(LeRobot ``feature``-keyed) are stripped before tokenizing; the
actual image tensors flow through SmolVLA's existing
``OBS_IMAGES_*`` channels and ``embed_prefix`` puts them before
the language embeddings, matching the chat-template-stripped
text order.
- ``processor_smolvla2.py``: when ``config.recipe_path`` is set,
build a new pipeline with ``RenderMessagesStep`` +
``SmolVLA2ChatTokenizerStep`` instead of SmolVLA's plain
``TokenizerProcessorStep``. When ``recipe_path`` is ``None``,
fall back to SmolVLA's pipeline so unannotated datasets still
work unchanged. Resolves recipe paths relative to
``src/lerobot/configs/`` so ``recipes/smolvla2_hirobot.yaml``
works directly.
The next commit on this branch picks up ``text_labels`` and
``predict_actions`` from the batch and routes them through the
SmolVLM ``lm_head`` for the actual dual-loss training.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,271 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""SmolVLA2's chat-template tokenization step.
|
||||||
|
|
||||||
|
Replaces SmolVLA's plain ``TokenizerProcessorStep`` for SmolVLA2 when a
|
||||||
|
``recipe_path`` is set. Reads the rendered messages produced by
|
||||||
|
``RenderMessagesStep`` (PR 1) and produces:
|
||||||
|
|
||||||
|
* ``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK`` —
|
||||||
|
the chat-templated prompt tokenized by SmolVLM's tokenizer, with
|
||||||
|
``tools=meta.tools`` (PR 1's catalog).
|
||||||
|
* ``text_labels`` — same shape as token ids, ``-100`` everywhere except
|
||||||
|
the positions belonging to messages whose index is in
|
||||||
|
``target_message_indices``. The next commit's modeling forward path
|
||||||
|
applies cross-entropy on those positions via the SmolVLM ``lm_head``.
|
||||||
|
* ``predict_actions`` — bool tensor, ``True`` iff any of the rendered
|
||||||
|
target messages has ``message_streams[i] == "low_level"``. The
|
||||||
|
modeling forward uses this to gate the flow head.
|
||||||
|
|
||||||
|
Image / video content blocks in the rendered messages are dropped
|
||||||
|
before tokenization — the chat template only handles text, and SmolVLA
|
||||||
|
already passes camera tensors out-of-band via the standard
|
||||||
|
``OBS_IMAGES_*`` features. This keeps the prefix layout unchanged
|
||||||
|
(``embed_prefix`` puts image embeddings before language embeddings,
|
||||||
|
matching the chat-template-stripped text order).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs import PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||||
|
from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry
|
||||||
|
from lerobot.types import EnvTransition, TransitionKey
|
||||||
|
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
|
||||||
|
class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||||
|
"""Render messages → token ids + label mask + predict_actions flag.
|
||||||
|
|
||||||
|
This is the bridge between the recipe stack (PR 1's
|
||||||
|
``RenderMessagesStep`` outputs) and the SmolVLA2 modeling forward
|
||||||
|
(next commit, which reads ``text_labels`` / ``predict_actions``).
|
||||||
|
Pure-text turns and multi-stream targets are both handled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"
|
||||||
|
max_length: int = 2048
|
||||||
|
padding: str = "longest"
|
||||||
|
padding_side: str = "right"
|
||||||
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Lazy: don't load the tokenizer until the step actually runs,
|
||||||
|
# so unit tests that import the module without transformers
|
||||||
|
# installed still pass.
|
||||||
|
self._tokenizer: Any = None
|
||||||
|
if self.tools is None:
|
||||||
|
# Default: ship the canonical ``say`` schema. Users who set
|
||||||
|
# ``meta.tools`` differently can override via
|
||||||
|
# ``with_tools(meta.tools)``.
|
||||||
|
self.tools = list(DEFAULT_TOOLS)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def with_tools(self, tools: list[dict[str, Any]]) -> "SmolVLA2ChatTokenizerStep":
|
||||||
|
"""Override the tools catalog rendered into the system prompt."""
|
||||||
|
self.tools = list(tools)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||||
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
messages = comp.get("messages")
|
||||||
|
if not messages:
|
||||||
|
# No recipe rendering happened — nothing to do; downstream
|
||||||
|
# falls back to whatever ``task`` is in the transition.
|
||||||
|
return transition
|
||||||
|
|
||||||
|
message_streams: list[str | None] = list(comp.get("message_streams") or [])
|
||||||
|
target_indices: list[int] = sorted(
|
||||||
|
int(i) for i in (comp.get("target_message_indices") or [])
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = self._get_tokenizer()
|
||||||
|
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||||
|
|
||||||
|
# Tokenize the full chat once.
|
||||||
|
full_ids = tokenizer.apply_chat_template(
|
||||||
|
text_messages,
|
||||||
|
tools=self.tools,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
if isinstance(full_ids, list) and full_ids and isinstance(full_ids[0], list):
|
||||||
|
full_ids = full_ids[0]
|
||||||
|
|
||||||
|
# Build the label mask by re-rendering progressively up to each
|
||||||
|
# target message and reading off the prefix length. This is the
|
||||||
|
# robust way to get exact token boundaries: we use the same
|
||||||
|
# tokenizer, the same ``tools=`` argument, and the same chat
|
||||||
|
# template — so the prefix tokens are guaranteed to be a prefix
|
||||||
|
# of the full sequence.
|
||||||
|
labels = [-100] * len(full_ids)
|
||||||
|
for tgt in target_indices:
|
||||||
|
prefix_ids = tokenizer.apply_chat_template(
|
||||||
|
text_messages[:tgt],
|
||||||
|
tools=self.tools,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
full_through_target = tokenizer.apply_chat_template(
|
||||||
|
text_messages[: tgt + 1],
|
||||||
|
tools=self.tools,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_tensors=None,
|
||||||
|
)
|
||||||
|
if isinstance(prefix_ids, list) and prefix_ids and isinstance(prefix_ids[0], list):
|
||||||
|
prefix_ids = prefix_ids[0]
|
||||||
|
if (
|
||||||
|
isinstance(full_through_target, list)
|
||||||
|
and full_through_target
|
||||||
|
and isinstance(full_through_target[0], list)
|
||||||
|
):
|
||||||
|
full_through_target = full_through_target[0]
|
||||||
|
start = len(prefix_ids)
|
||||||
|
end = min(len(full_through_target), len(full_ids))
|
||||||
|
for pos in range(start, end):
|
||||||
|
labels[pos] = int(full_ids[pos])
|
||||||
|
|
||||||
|
# Truncate / pad to ``max_length`` so batches collate cleanly.
|
||||||
|
# The SmolVLA pipeline downstream relies on a fixed length
|
||||||
|
# behaviour ("longest" or "max_length") — we mirror it here.
|
||||||
|
if len(full_ids) > self.max_length:
|
||||||
|
full_ids = full_ids[: self.max_length]
|
||||||
|
labels = labels[: self.max_length]
|
||||||
|
attn = [1] * len(full_ids)
|
||||||
|
if self.padding == "max_length" and len(full_ids) < self.max_length:
|
||||||
|
pad_id = (
|
||||||
|
tokenizer.pad_token_id
|
||||||
|
if tokenizer.pad_token_id is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
n_pad = self.max_length - len(full_ids)
|
||||||
|
full_ids = full_ids + [pad_id] * n_pad
|
||||||
|
labels = labels + [-100] * n_pad
|
||||||
|
attn = attn + [0] * n_pad
|
||||||
|
|
||||||
|
ids_t = torch.tensor(full_ids, dtype=torch.long)
|
||||||
|
attn_t = torch.tensor(attn, dtype=torch.bool)
|
||||||
|
labels_t = torch.tensor(labels, dtype=torch.long)
|
||||||
|
predict_actions = any(
|
||||||
|
i < len(message_streams) and message_streams[i] == "low_level"
|
||||||
|
for i in target_indices
|
||||||
|
)
|
||||||
|
|
||||||
|
new_complementary = dict(comp)
|
||||||
|
# Drop the per-recipe sidecar keys; everything downstream needs
|
||||||
|
# is now in the tokenized form.
|
||||||
|
new_complementary.pop("messages", None)
|
||||||
|
new_complementary.pop("message_streams", None)
|
||||||
|
new_complementary.pop("target_message_indices", None)
|
||||||
|
# SmolVLA's pipeline expects ``OBS_LANGUAGE_TOKENS`` /
|
||||||
|
# ``OBS_LANGUAGE_ATTENTION_MASK`` on the OBSERVATION key. Place
|
||||||
|
# them there — and drop ``task`` so the upstream
|
||||||
|
# ``TokenizerProcessorStep`` (which we replace) doesn't double-
|
||||||
|
# tokenize.
|
||||||
|
observation = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||||
|
observation[OBS_LANGUAGE_TOKENS] = ids_t
|
||||||
|
observation[OBS_LANGUAGE_ATTENTION_MASK] = attn_t
|
||||||
|
new_complementary["text_labels"] = labels_t
|
||||||
|
new_complementary["predict_actions"] = torch.tensor(predict_actions, dtype=torch.bool)
|
||||||
|
new_complementary.pop("task", None)
|
||||||
|
|
||||||
|
new_transition = dict(transition)
|
||||||
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary
|
||||||
|
new_transition[TransitionKey.OBSERVATION] = observation
|
||||||
|
return new_transition
|
||||||
|
|
||||||
|
def transform_features(
|
||||||
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
|
"""Pass-through; this step writes runtime tensors not features."""
|
||||||
|
return features
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _get_tokenizer(self): # noqa: ANN202
|
||||||
|
if self._tokenizer is not None:
|
||||||
|
return self._tokenizer
|
||||||
|
try:
|
||||||
|
from transformers import AutoTokenizer # noqa: PLC0415
|
||||||
|
except ImportError as exc: # pragma: no cover
|
||||||
|
raise ImportError(
|
||||||
|
"SmolVLA2ChatTokenizerStep requires transformers. "
|
||||||
|
"`pip install lerobot[transformers-dep]`."
|
||||||
|
) from exc
|
||||||
|
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
|
||||||
|
if self._tokenizer.pad_token_id is None and self._tokenizer.eos_token_id is not None:
|
||||||
|
self._tokenizer.pad_token = self._tokenizer.eos_token
|
||||||
|
return self._tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_lerobot_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Remove LeRobot-specific multimodal blocks from ``message`` content.
|
||||||
|
|
||||||
|
The recipe DSL allows authors to write multimodal content like
|
||||||
|
``{"type": "image", "feature": "observation.images.top"}``. SmolVLM's
|
||||||
|
tokenizer doesn't know that ``feature`` key (it expects ``url`` or
|
||||||
|
``path``). The actual image tensor flows through SmolVLA's
|
||||||
|
``OBS_IMAGES_*`` channels separately; the chat template only needs
|
||||||
|
the text. So we strip non-text blocks before tokenizing.
|
||||||
|
"""
|
||||||
|
new = dict(message)
|
||||||
|
content = new.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[dict[str, Any]] = []
|
||||||
|
for block in content:
|
||||||
|
if not isinstance(block, dict):
|
||||||
|
continue
|
||||||
|
if block.get("type") == "text":
|
||||||
|
text_parts.append({"type": "text", "text": str(block.get("text", ""))})
|
||||||
|
# If only one text block survives, flatten to a string for
|
||||||
|
# template friendliness; some chat templates choke on a single-
|
||||||
|
# element list.
|
||||||
|
if len(text_parts) == 1:
|
||||||
|
new["content"] = text_parts[0]["text"]
|
||||||
|
elif text_parts:
|
||||||
|
new["content"] = text_parts
|
||||||
|
else:
|
||||||
|
new["content"] = ""
|
||||||
|
if "tool_calls" in new and not new["tool_calls"]:
|
||||||
|
# Drop empty tool_calls — some templates render them as a
|
||||||
|
# spurious empty marker.
|
||||||
|
new.pop("tool_calls")
|
||||||
|
# ``stream`` and ``target`` were recipe metadata; templates don't
|
||||||
|
# know them and may warn or crash.
|
||||||
|
new.pop("stream", None)
|
||||||
|
new.pop("target", None)
|
||||||
|
return new
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export for tests / introspection
|
||||||
|
strip_lerobot_blocks = _strip_lerobot_blocks
|
||||||
@@ -13,43 +13,119 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""SmolVLA2 processor pipelines.
|
"""SmolVLA2 processor pipelines.
|
||||||
|
|
||||||
SCAFFOLD: this currently delegates to SmolVLA's processor. The next
|
When ``config.recipe_path`` is set, the pre-processor pipeline becomes:
|
||||||
commit on this branch replaces that with a chat-template aware pipeline:
|
|
||||||
|
|
||||||
RenderMessagesStep (PR1) → SmolVLA2ChatTokenizerStep → existing SmolVLA
|
rename observations
|
||||||
normalization / device steps.
|
add batch dim
|
||||||
|
RenderMessagesStep(recipe) # PR 1: language_* → messages
|
||||||
|
SmolVLA2ChatTokenizerStep(...) # chat template + label mask + predict_actions
|
||||||
|
DeviceProcessorStep
|
||||||
|
NormalizerProcessorStep
|
||||||
|
|
||||||
The chat tokenizer step will:
|
When ``config.recipe_path`` is ``None``, we delegate to SmolVLA's
|
||||||
|
plain task-string pipeline so unannotated datasets still work.
|
||||||
|
|
||||||
* take ``messages`` / ``message_streams`` / ``target_message_indices``
|
Post-processor is unchanged from SmolVLA.
|
||||||
from the rendered sample,
|
|
||||||
* call ``apply_chat_template(messages, tools=DEFAULT_TOOLS, ...)`` on the
|
|
||||||
SmolVLM tokenizer,
|
|
||||||
* tokenize the resulting prompt,
|
|
||||||
* build a ``text_labels`` tensor with ``-100`` everywhere except the
|
|
||||||
token positions belonging to messages whose index is in
|
|
||||||
``target_message_indices``,
|
|
||||||
* derive ``predict_actions = bool(targets_by_stream.get("low_level"))``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.recipe import TrainingRecipe
|
||||||
|
from lerobot.processor import (
|
||||||
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeviceProcessorStep,
|
||||||
|
NormalizerProcessorStep,
|
||||||
|
PolicyAction,
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
RenameObservationsProcessorStep,
|
||||||
|
RenderMessagesStep,
|
||||||
|
UnnormalizerProcessorStep,
|
||||||
|
policy_action_to_transition,
|
||||||
|
transition_to_policy_action,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
|
|
||||||
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
from ..smolvla.processor_smolvla import make_smolvla_pre_post_processors
|
||||||
|
from .chat_processor_smolvla2 import SmolVLA2ChatTokenizerStep
|
||||||
from .configuration_smolvla2 import SmolVLA2Config
|
from .configuration_smolvla2 import SmolVLA2Config
|
||||||
|
|
||||||
|
|
||||||
def make_smolvla2_pre_post_processors(
|
def make_smolvla2_pre_post_processors(
|
||||||
config: SmolVLA2Config,
|
config: SmolVLA2Config,
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[Any, Any]:
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
"""Build SmolVLA2's pre/post-processor pipelines.
|
"""Build SmolVLA2's pre/post-processor pipelines.
|
||||||
|
|
||||||
SCAFFOLD: just delegates to ``make_smolvla_pre_post_processors`` so
|
With ``recipe_path`` set, inserts the recipe-rendering step and the
|
||||||
SmolVLA2 inherits SmolVLA's tokenization + normalization for now.
|
chat-template tokenizer that emits ``text_labels`` and
|
||||||
The recipe-driven chat-template rendering arrives in the next commit.
|
``predict_actions`` for the dual-loss path. Without it, falls back
|
||||||
|
to SmolVLA's plain task-string pipeline so unannotated datasets
|
||||||
|
keep working unchanged.
|
||||||
"""
|
"""
|
||||||
|
if not config.recipe_path:
|
||||||
return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)
|
return make_smolvla_pre_post_processors(config, dataset_stats=dataset_stats)
|
||||||
|
|
||||||
|
recipe = _load_recipe(config.recipe_path)
|
||||||
|
|
||||||
|
input_steps = [
|
||||||
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
|
AddBatchDimensionProcessorStep(),
|
||||||
|
RenderMessagesStep(recipe=recipe),
|
||||||
|
SmolVLA2ChatTokenizerStep(
|
||||||
|
tokenizer_name=config.vlm_model_name,
|
||||||
|
max_length=config.tokenizer_max_length,
|
||||||
|
padding=config.pad_language_to,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device=config.device),
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features={**config.input_features, **config.output_features},
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
output_steps = [
|
||||||
|
UnnormalizerProcessorStep(
|
||||||
|
features=config.output_features,
|
||||||
|
norm_map=config.normalization_mapping,
|
||||||
|
stats=dataset_stats,
|
||||||
|
),
|
||||||
|
DeviceProcessorStep(device="cpu"),
|
||||||
|
]
|
||||||
|
return (
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||||
|
steps=input_steps,
|
||||||
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
|
),
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||||
|
steps=output_steps,
|
||||||
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_recipe(path_str: str) -> TrainingRecipe:
|
||||||
|
"""Resolve ``path_str`` to a ``TrainingRecipe``.
|
||||||
|
|
||||||
|
Accepts an absolute path or a path relative to
|
||||||
|
``src/lerobot/configs/`` so recipe authors can write
|
||||||
|
``--policy.recipe_path=recipes/smolvla2_hirobot.yaml``.
|
||||||
|
"""
|
||||||
|
p = Path(path_str)
|
||||||
|
if not p.is_absolute() and not p.exists():
|
||||||
|
from lerobot.configs import recipe as _recipe_module # noqa: PLC0415
|
||||||
|
|
||||||
|
configs_dir = Path(_recipe_module.__file__).resolve().parent
|
||||||
|
candidate = configs_dir / path_str
|
||||||
|
if candidate.exists():
|
||||||
|
p = candidate
|
||||||
|
return TrainingRecipe.from_yaml(p)
|
||||||
|
|||||||
Reference in New Issue
Block a user