mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
`recipe.py` and `language_render.py`. Promote to module-level
`PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
hardcoded `{"language_persistent", "language_events"}` literals at
two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
silently skipped samples that lacked a preserved key (e.g. `messages`)
when building the preserved lists, which would hand downstream
consumers a preserved list shorter than the tensor batch. Now: if any
sample carries a key, every sample in the batch must carry it;
otherwise raise a `ValueError` so the upstream rendering bug surfaces
at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
or multi-element list fell through and triggered confusing
`float([])` errors downstream. Now raises `ValueError` with the
actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
of `key = {... if ... else {}}` plus an 11-line splat dict with a
single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
comment. Add a comment block explaining it's an intentional extension
point: downstream modules add project-local styles to the set before
`column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
page referenced `src/lerobot/tools/`, `registry.py`, and
`get_tools(meta)` — none exist in this PR. Added a callout at the
start of "How to add your own tool" plus a note on the
implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
validation.** `test_recipe.py` grew from 1 case to 12 covering:
blend-or-messages exclusivity, target-turn requirement, blend
emptiness, weight presence/positivity, nested-blend rejection,
`from_dict` with nested blends, `from_yaml` / `load_recipe`
agreement, top-level non-mapping rejection. Added a malformed-row
test for `_normalize_rows` that asserts non-dict entries raise
`TypeError`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+13
-3
@@ -66,9 +66,11 @@ prompt_str = tokenizer.apply_chat_template(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
**The implementations** — runnable Python — live under
|
**The implementations** — runnable Python — will live under
|
||||||
`src/lerobot/tools/`, one file per tool. The canonical `say`
|
`src/lerobot/tools/`, one file per tool. The runtime dispatcher and
|
||||||
implementation wraps Kyutai's pocket-tts model.
|
the canonical `say` implementation (wrapping Kyutai's pocket-tts) land
|
||||||
|
in a follow-up PR; this PR ships only the catalog storage and
|
||||||
|
fallback constant.
|
||||||
|
|
||||||
## Per-row tool _invocations_
|
## Per-row tool _invocations_
|
||||||
|
|
||||||
@@ -114,6 +116,14 @@ the matching implementation.
|
|||||||
|
|
||||||
## How to add your own tool
|
## How to add your own tool
|
||||||
|
|
||||||
|
> **Note:** Steps 2 and 3 below describe the runtime layer
|
||||||
|
> (`src/lerobot/tools/`, the `Tool` protocol, `TOOL_REGISTRY`,
|
||||||
|
> `get_tools(meta)`) which lands in a follow-up PR. Today (this PR
|
||||||
|
> only), Step 1 is enough to make the tool visible to the chat
|
||||||
|
> template via `meta.tools` so the model can learn to _generate_ the
|
||||||
|
> call. Executing the call at inference is what the follow-up PR
|
||||||
|
> wires up.
|
||||||
|
|
||||||
Three steps. Concrete example: a `record_observation` tool the policy
|
Three steps. Concrete example: a `record_observation` tool the policy
|
||||||
can call to capture an extra observation outside the regular control
|
can call to capture an extra observation outside the regular control
|
||||||
loop.
|
loop.
|
||||||
|
|||||||
@@ -34,7 +34,10 @@ DEFAULT_BINDINGS = {
|
|||||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||||
}
|
}
|
||||||
|
|
||||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||||
|
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||||
|
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||||
|
|
||||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||||
|
|
||||||
@@ -178,13 +181,13 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
|
|||||||
if content is None:
|
if content is None:
|
||||||
return set()
|
return set()
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return set(_PLACEHOLDER_RE.findall(content))
|
return set(PLACEHOLDER_RE.findall(content))
|
||||||
|
|
||||||
names: set[str] = set()
|
names: set[str] = set()
|
||||||
for block in content:
|
for block in content:
|
||||||
for value in block.values():
|
for value in block.values():
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
names.update(_PLACEHOLDER_RE.findall(value))
|
names.update(PLACEHOLDER_RE.findall(value))
|
||||||
return names
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from torchvision import transforms
|
|||||||
from lerobot.utils.io_utils import load_json, write_json
|
from lerobot.utils.io_utils import load_json, write_json
|
||||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||||
|
|
||||||
|
from .language import LANGUAGE_COLUMNS
|
||||||
from .utils import (
|
from .utils import (
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
@@ -256,7 +257,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
|||||||
dict: The batch with items converted to torch tensors.
|
dict: The batch with items converted to torch tensors.
|
||||||
"""
|
"""
|
||||||
for key in items_dict:
|
for key in items_dict:
|
||||||
if key in {"language_persistent", "language_events"}:
|
if key in LANGUAGE_COLUMNS:
|
||||||
continue
|
continue
|
||||||
first_item = items_dict[key][0]
|
first_item = items_dict[key][0]
|
||||||
if isinstance(first_item, PILImage.Image):
|
if isinstance(first_item, PILImage.Image):
|
||||||
@@ -297,12 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||||
"""
|
"""
|
||||||
|
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||||
for key, val in item.items():
|
for key, val in item.items():
|
||||||
if isinstance(val, (np.ndarray | list)) and key not in [
|
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||||
"task",
|
|
||||||
"language_persistent",
|
|
||||||
"language_events",
|
|
||||||
]:
|
|
||||||
# Convert numpy arrays and lists to torch tensors
|
# Convert numpy arrays and lists to torch tensors
|
||||||
item[key] = torch.tensor(val)
|
item[key] = torch.tensor(val)
|
||||||
return item
|
return item
|
||||||
|
|||||||
@@ -37,7 +37,13 @@ CORE_STYLES = {
|
|||||||
"trace",
|
"trace",
|
||||||
"task_aug",
|
"task_aug",
|
||||||
}
|
}
|
||||||
EXTENDED_STYLES = set()
|
# Project-local styles can be registered at import time by appending to
|
||||||
|
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||||
|
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||||
|
# validation. Empty by default — populate from a downstream module that
|
||||||
|
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||||
|
# the new style's column.
|
||||||
|
EXTENDED_STYLES: set[str] = set()
|
||||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||||
|
|
||||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import re
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||||
|
|
||||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||||
|
|
||||||
@@ -30,7 +30,6 @@ LanguageRow = dict[str, Any]
|
|||||||
RenderedMessages = dict[str, list[Any]]
|
RenderedMessages = dict[str, list[Any]]
|
||||||
|
|
||||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
|
||||||
|
|
||||||
|
|
||||||
def active_at(
|
def active_at(
|
||||||
@@ -376,7 +375,7 @@ def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) ->
|
|||||||
return "" if content is None else str(content)
|
return "" if content is None else str(content)
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
return _PLACEHOLDER_RE.sub(replace, template)
|
return PLACEHOLDER_RE.sub(replace, template)
|
||||||
|
|
||||||
|
|
||||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||||
|
|||||||
@@ -153,49 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
_COMPLEMENTARY_KEYS = (
|
||||||
|
"task",
|
||||||
|
"index",
|
||||||
|
"task_index",
|
||||||
|
"episode_index",
|
||||||
|
"timestamp",
|
||||||
|
"language_persistent",
|
||||||
|
"language_events",
|
||||||
|
"messages",
|
||||||
|
"message_streams",
|
||||||
|
"target_message_indices",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""Extract complementary data from a batch dictionary.
|
||||||
Extract complementary data from a batch dictionary.
|
|
||||||
|
|
||||||
This includes padding flags, task description, and indices.
|
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||||
|
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||||
Args:
|
each only when present in ``batch``.
|
||||||
batch: The batch dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary with the extracted complementary data.
|
|
||||||
"""
|
"""
|
||||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
return {**pad_keys, **extras}
|
||||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
|
||||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
|
||||||
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
|
|
||||||
language_persistent_key = (
|
|
||||||
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
|
|
||||||
)
|
|
||||||
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
|
|
||||||
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
|
|
||||||
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
|
|
||||||
target_message_indices_key = (
|
|
||||||
{"target_message_indices": batch["target_message_indices"]}
|
|
||||||
if "target_message_indices" in batch
|
|
||||||
else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
**pad_keys,
|
|
||||||
**task_key,
|
|
||||||
**index_key,
|
|
||||||
**task_index_key,
|
|
||||||
**episode_index_key,
|
|
||||||
**timestamp_key,
|
|
||||||
**language_persistent_key,
|
|
||||||
**language_events_key,
|
|
||||||
**messages_key,
|
|
||||||
**message_streams_key,
|
|
||||||
**target_message_indices_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_transition(
|
def create_transition(
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ def _scalar(value: Any) -> float | int:
|
|||||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||||
if hasattr(value, "item"):
|
if hasattr(value, "item"):
|
||||||
return value.item()
|
return value.item()
|
||||||
if isinstance(value, list) and len(value) == 1:
|
if isinstance(value, list):
|
||||||
|
if len(value) != 1:
|
||||||
|
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||||
return _scalar(value[0])
|
return _scalar(value[0])
|
||||||
return value
|
return value
|
||||||
|
|||||||
@@ -36,11 +36,22 @@ def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | N
|
|||||||
if not batch:
|
if not batch:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
preserved = {
|
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
|
||||||
key: [sample[key] for sample in batch if key in sample]
|
# carry `messages` and half don't) is a real bug in the upstream
|
||||||
for key in _PYTHON_LIST_KEYS
|
# rendering step — silently filtering would hand downstream consumers a
|
||||||
if any(key in sample for sample in batch)
|
# preserved list shorter than the tensor batch. Raise instead so the
|
||||||
}
|
# mismatch surfaces at the boundary.
|
||||||
|
preserved: dict[str, list[Any]] = {}
|
||||||
|
for key in _PYTHON_LIST_KEYS:
|
||||||
|
presence = [key in sample for sample in batch]
|
||||||
|
if not any(presence):
|
||||||
|
continue
|
||||||
|
if not all(presence):
|
||||||
|
raise ValueError(
|
||||||
|
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
|
||||||
|
f"every sample in a batch must agree."
|
||||||
|
)
|
||||||
|
preserved[key] = [sample[key] for sample in batch]
|
||||||
tensorizable = [
|
tensorizable = [
|
||||||
{
|
{
|
||||||
key: value
|
key: value
|
||||||
|
|||||||
@@ -1,8 +1,22 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
from lerobot.configs.recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_message_turn(content: str = "${task}") -> MessageTurn:
|
||||||
|
return MessageTurn(role="user", content=content, stream="high_level")
|
||||||
|
|
||||||
|
|
||||||
|
def _minimal_target_turn() -> MessageTurn:
|
||||||
|
return MessageTurn(role="assistant", content="ok", stream="high_level", target=True)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message-recipe validation ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_message_recipe_validates_unknown_binding():
|
def test_message_recipe_validates_unknown_binding():
|
||||||
@@ -10,6 +24,134 @@ def test_message_recipe_validates_unknown_binding():
|
|||||||
TrainingRecipe(
|
TrainingRecipe(
|
||||||
messages=[
|
messages=[
|
||||||
MessageTurn(role="user", content="${missing}", stream="high_level"),
|
MessageTurn(role="user", content="${missing}", stream="high_level"),
|
||||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
_minimal_target_turn(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_recipe_requires_at_least_one_target():
|
||||||
|
with pytest.raises(ValueError, match="target"):
|
||||||
|
TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
_minimal_message_turn(),
|
||||||
|
MessageTurn(role="assistant", content="no target", stream="high_level"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recipe_rejects_both_messages_and_blend():
|
||||||
|
with pytest.raises(ValueError, match="only one"):
|
||||||
|
TrainingRecipe(
|
||||||
|
messages=[_minimal_message_turn(), _minimal_target_turn()],
|
||||||
|
blend={"a": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recipe_rejects_neither_messages_nor_blend():
|
||||||
|
with pytest.raises(ValueError, match="must set one"):
|
||||||
|
TrainingRecipe()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Blend validation ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_must_be_non_empty():
|
||||||
|
with pytest.raises(ValueError, match="at least one component"):
|
||||||
|
TrainingRecipe(blend={})
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_component_must_define_weight():
|
||||||
|
with pytest.raises(ValueError, match="weight"):
|
||||||
|
TrainingRecipe(blend={"a": TrainingRecipe(messages=[_minimal_target_turn()])})
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_component_weight_must_be_positive():
|
||||||
|
with pytest.raises(ValueError, match="positive weight"):
|
||||||
|
TrainingRecipe(blend={"a": TrainingRecipe(weight=0.0, messages=[_minimal_target_turn()])})
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_component_must_define_messages():
|
||||||
|
# A bare TrainingRecipe(weight=1.0) would itself raise; build it without
|
||||||
|
# going through __post_init__ to exercise the blend-level validator.
|
||||||
|
bad = TrainingRecipe.__new__(TrainingRecipe)
|
||||||
|
bad.messages = None
|
||||||
|
bad.bindings = None
|
||||||
|
bad.blend = None
|
||||||
|
bad.weight = 1.0
|
||||||
|
with pytest.raises(ValueError, match="must define messages"):
|
||||||
|
TrainingRecipe(blend={"a": bad})
|
||||||
|
|
||||||
|
|
||||||
|
def test_blend_components_cannot_themselves_define_a_blend():
|
||||||
|
inner = TrainingRecipe(blend={"x": TrainingRecipe(weight=1.0, messages=[_minimal_target_turn()])})
|
||||||
|
# Force-bypass the inner component's normal validation so the test
|
||||||
|
# exercises the outer blend's "no nested blends" rule directly.
|
||||||
|
nested = TrainingRecipe.__new__(TrainingRecipe)
|
||||||
|
nested.messages = None
|
||||||
|
nested.bindings = None
|
||||||
|
nested.blend = inner.blend
|
||||||
|
nested.weight = 1.0
|
||||||
|
with pytest.raises(ValueError, match="cannot itself define a blend"):
|
||||||
|
TrainingRecipe(blend={"outer": nested})
|
||||||
|
|
||||||
|
|
||||||
|
# ── from_dict / from_yaml round-trips ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_with_nested_blend():
|
||||||
|
recipe = TrainingRecipe.from_dict(
|
||||||
|
{
|
||||||
|
"blend": {
|
||||||
|
"a": {
|
||||||
|
"weight": 1.0,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||||
|
{"role": "assistant", "content": "a", "stream": "high_level", "target": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"b": {
|
||||||
|
"weight": 2.0,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "${task}", "stream": "high_level"},
|
||||||
|
{"role": "assistant", "content": "b", "stream": "high_level", "target": True},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert recipe.blend is not None
|
||||||
|
assert set(recipe.blend) == {"a", "b"}
|
||||||
|
assert recipe.blend["b"].weight == 2.0
|
||||||
|
# Inner messages were promoted to MessageTurn instances.
|
||||||
|
assert isinstance(recipe.blend["a"].messages[0], MessageTurn)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_yaml_round_trips_through_load_recipe(tmp_path: Path):
|
||||||
|
yaml_text = dedent(
|
||||||
|
"""
|
||||||
|
bindings:
|
||||||
|
custom: "active_at(t, style=subtask)"
|
||||||
|
messages:
|
||||||
|
- {role: user, content: "${task}: ${custom}", stream: high_level}
|
||||||
|
- {role: assistant, content: "ok", stream: high_level, target: true}
|
||||||
|
"""
|
||||||
|
).strip()
|
||||||
|
path = tmp_path / "recipe.yaml"
|
||||||
|
path.write_text(yaml_text)
|
||||||
|
|
||||||
|
via_classmethod = TrainingRecipe.from_yaml(path)
|
||||||
|
via_helper = load_recipe(path)
|
||||||
|
|
||||||
|
assert via_classmethod.bindings == {"custom": "active_at(t, style=subtask)"}
|
||||||
|
assert via_classmethod.messages[1].target is True
|
||||||
|
# ``load_recipe`` is just a wrapper, but assert the two paths agree
|
||||||
|
# on the structural result so a future divergence is caught here.
|
||||||
|
assert via_helper.bindings == via_classmethod.bindings
|
||||||
|
assert len(via_helper.messages) == len(via_classmethod.messages)
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_yaml_rejects_non_mapping(tmp_path: Path):
|
||||||
|
path = tmp_path / "bad.yaml"
|
||||||
|
path.write_text("- just\n- a\n- list\n")
|
||||||
|
with pytest.raises(ValueError, match="mapping at the top level"):
|
||||||
|
TrainingRecipe.from_yaml(path)
|
||||||
|
|||||||
@@ -342,6 +342,29 @@ def test_resolve_task_explicit_override_beats_rephrasings():
|
|||||||
assert rendered["messages"][0]["content"] == "explicit override wins"
|
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_sample_rejects_non_dict_language_rows():
|
||||||
|
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||||
|
|
||||||
|
A pipeline that hands the renderer a non-dict (e.g. a stray string)
|
||||||
|
is a real upstream bug — silent skipping would let it propagate.
|
||||||
|
"""
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
with pytest.raises(TypeError, match="must be dictionaries"):
|
||||||
|
render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=["not a dict"],
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=0,
|
||||||
|
task="x",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_low_level_branch_renders_active_subtask():
|
def test_low_level_branch_renders_active_subtask():
|
||||||
low_level = TrainingRecipe(
|
low_level = TrainingRecipe(
|
||||||
blend={
|
blend={
|
||||||
|
|||||||
Reference in New Issue
Block a user