mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
Compare commits
2 Commits
1ca38d9748
...
e3e9374e2c
| Author | SHA1 | Date | |
|---|---|---|---|
| e3e9374e2c | |||
| c1a0c601e2 |
@@ -33,6 +33,8 @@
|
|||||||
title: Using the Dataset Tools
|
title: Using the Dataset Tools
|
||||||
- local: language_and_recipes
|
- local: language_and_recipes
|
||||||
title: Language Columns and Recipes
|
title: Language Columns and Recipes
|
||||||
|
- local: tools
|
||||||
|
title: Tools
|
||||||
- local: streaming_video_encoding
|
- local: streaming_video_encoding
|
||||||
title: Streaming Video Encoding
|
title: Streaming Video Encoding
|
||||||
title: "Datasets"
|
title: "Datasets"
|
||||||
|
|||||||
@@ -0,0 +1,198 @@
|
|||||||
|
# Tools
|
||||||
|
|
||||||
|
LeRobot v3.1 supports **tool calls** in policies — assistant messages can
|
||||||
|
emit structured invocations like `say(text="OK, starting now")` that the
|
||||||
|
runtime dispatches to a real implementation (TTS, controller, logger, …).
|
||||||
|
|
||||||
|
This page covers:
|
||||||
|
|
||||||
|
1. Where the tool catalog lives (PR 1).
|
||||||
|
2. How the annotation pipeline produces tool-call atoms (PR 2).
|
||||||
|
3. How to add your own tool (PR 3).
|
||||||
|
|
||||||
|
## Where tools are declared
|
||||||
|
|
||||||
|
Two layers.
|
||||||
|
|
||||||
|
**The catalog** — a list of OpenAI-style function schemas — lives at
|
||||||
|
`meta/info.json["tools"]` on each dataset. Example:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"features": { "...": "..." },
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "say",
|
||||||
|
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": { "type": "string", "description": "The verbatim text to speak." }
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Read it via the dataset metadata accessor:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||||
|
|
||||||
|
meta = LeRobotDatasetMetadata(repo_id="pepijn/super_poulain_final_annotations")
|
||||||
|
tools = meta.tools # list[dict] — OpenAI tool schemas
|
||||||
|
```
|
||||||
|
|
||||||
|
If the dataset's `info.json` doesn't declare any tools, `meta.tools`
|
||||||
|
returns `DEFAULT_TOOLS` from `lerobot.datasets.language` — currently a
|
||||||
|
single-entry list with the canonical `say` schema. So unannotated
|
||||||
|
datasets and chat-template consumers keep working without any
|
||||||
|
configuration:
|
||||||
|
|
||||||
|
```python
|
||||||
|
prompt_str = tokenizer.apply_chat_template(
|
||||||
|
sample["messages"],
|
||||||
|
tools=meta.tools, # works either way
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=False,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**The implementations** — runnable Python — live under
|
||||||
|
`src/lerobot/tools/`, one file per tool. The `say` implementation
|
||||||
|
arrives in PR 3 and wraps Kyutai's pocket-tts model.
|
||||||
|
|
||||||
|
## Per-row tool *invocations*
|
||||||
|
|
||||||
|
The catalog above describes *what can be called*. The actual *call* — the
|
||||||
|
function name plus the argument values — is stored per-row, on the
|
||||||
|
assistant atoms in `language_events`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": null,
|
||||||
|
"style": null,
|
||||||
|
"timestamp": 12.4,
|
||||||
|
"camera": null,
|
||||||
|
"tool_calls": [
|
||||||
|
{ "type": "function",
|
||||||
|
"function": { "name": "say", "arguments": { "text": "On it." } } }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Recipes splice these into rendered messages via `tool_calls_from`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
user_interjection_response:
|
||||||
|
bindings:
|
||||||
|
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
||||||
|
messages:
|
||||||
|
- { role: user, content: "${task}", stream: high_level }
|
||||||
|
- { role: assistant, content: "${current_plan}", stream: high_level,
|
||||||
|
target: true, tool_calls_from: speech }
|
||||||
|
```
|
||||||
|
|
||||||
|
The model's training target is one assistant turn that carries both the
|
||||||
|
plan text *and* the `say` tool call. At inference, the runtime parses
|
||||||
|
the generated text back into structured `tool_calls` and dispatches to
|
||||||
|
the matching implementation.
|
||||||
|
|
||||||
|
## How to add your own tool
|
||||||
|
|
||||||
|
Three steps. Concrete example: a `record_observation` tool the policy
|
||||||
|
can call to capture an extra observation outside the regular control
|
||||||
|
loop.
|
||||||
|
|
||||||
|
### Step 1 — declare the schema
|
||||||
|
|
||||||
|
Add an entry under `meta/info.json["tools"]`. Either edit the file
|
||||||
|
directly on disk *before* running the annotation pipeline (it'll be
|
||||||
|
preserved) or hand it to `lerobot-annotate` via a config flag (PR 2 —
|
||||||
|
exact CLI lands with the pipeline change).
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": [
|
||||||
|
{ "type": "function", "function": { "name": "say", "...": "..." } },
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "record_observation",
|
||||||
|
"description": "Capture a high-resolution still image for the user.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"label": { "type": "string", "description": "Short label for the saved image." }
|
||||||
|
},
|
||||||
|
"required": ["label"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The schema follows OpenAI's function-calling convention exactly, so the
|
||||||
|
chat template can render it natively.
|
||||||
|
|
||||||
|
### Step 2 — implement the call
|
||||||
|
|
||||||
|
Create `src/lerobot/tools/record_observation.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .base import Tool
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
RECORD_OBSERVATION_SCHEMA: dict[str, Any] = { "...": "..." } # mirrors the JSON above
|
||||||
|
|
||||||
|
|
||||||
|
class RecordObservationTool:
|
||||||
|
name = "record_observation"
|
||||||
|
schema = RECORD_OBSERVATION_SCHEMA
|
||||||
|
|
||||||
|
def __init__(self, schema: dict | None = None, output_dir: str = "."):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
|
||||||
|
def call(self, arguments: dict) -> str:
|
||||||
|
label = arguments["label"]
|
||||||
|
# ... save the latest camera frame to <output_dir>/<label>.png ...
|
||||||
|
return f"saved {label}.png"
|
||||||
|
```
|
||||||
|
|
||||||
|
One file per tool keeps dependencies isolated — `record_observation`
|
||||||
|
might pull `pillow`, while `say` (PR 3) pulls `pocket-tts`. Users
|
||||||
|
installing only the tools they need avoid heavy transitive deps.
|
||||||
|
|
||||||
|
### Step 3 — register it
|
||||||
|
|
||||||
|
Add to `src/lerobot/tools/registry.py` (PR 3):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .record_observation import RecordObservationTool
|
||||||
|
|
||||||
|
TOOL_REGISTRY["record_observation"] = RecordObservationTool
|
||||||
|
```
|
||||||
|
|
||||||
|
That's it. At runtime `get_tools(meta)` looks up each schema in
|
||||||
|
`meta.tools`, instantiates the matching registered class, and returns
|
||||||
|
a name → instance dict the dispatcher can route into.
|
||||||
|
|
||||||
|
## Where this fits in the three-PR stack
|
||||||
|
|
||||||
|
| Layer | PR | What lands |
|
||||||
|
|---|---|---|
|
||||||
|
| Catalog storage in `meta/info.json` + `meta.tools` accessor | PR 1 | This page; `SAY_TOOL_SCHEMA`, `DEFAULT_TOOLS` constants in `lerobot.datasets.language`; `LeRobotDatasetMetadata.tools` property |
|
||||||
|
| Annotation pipeline writes `tools` to meta after a run; honors anything users pre-populated | PR 2 | `lerobot-annotate` ensures `meta/info.json["tools"]` includes the canonical `say` and merges any user-declared tools |
|
||||||
|
| Runnable implementations under `src/lerobot/tools/`; runtime dispatcher; `say.py` wired to Kyutai's pocket-tts | PR 3 | One file per tool; `Tool` protocol; `TOOL_REGISTRY`; optional `[tools]` extra in `pyproject.toml` |
|
||||||
|
|
||||||
|
If you want to use a tool *without* writing an implementation (e.g. for
|
||||||
|
training-time chat-template formatting only), step 1 alone is enough —
|
||||||
|
the model still learns to *generate* the call. Steps 2 and 3 are only
|
||||||
|
needed to actually *execute* it at inference.
|
||||||
@@ -318,6 +318,28 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tools(self) -> list[dict]:
|
||||||
|
"""OpenAI-style tool schemas declared by this dataset.
|
||||||
|
|
||||||
|
Read from ``meta/info.json["tools"]``. Returns a copy, so callers
|
||||||
|
can mutate the result safely. Falls back to
|
||||||
|
:data:`lerobot.datasets.language.DEFAULT_TOOLS` (the canonical
|
||||||
|
``say`` schema) when the dataset doesn't declare any — that way
|
||||||
|
unannotated datasets and chat-template consumers
|
||||||
|
(``apply_chat_template(messages, tools=meta.tools)``) keep
|
||||||
|
working out of the box.
|
||||||
|
|
||||||
|
Implementations live under :mod:`lerobot.tools` (one file per
|
||||||
|
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||||
|
"""
|
||||||
|
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||||
|
|
||||||
|
declared = self.info.get("tools")
|
||||||
|
if isinstance(declared, list) and declared:
|
||||||
|
return [dict(t) for t in declared]
|
||||||
|
return [dict(t) for t in DEFAULT_TOOLS]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self) -> dict[str, list | dict]:
|
def names(self) -> dict[str, list | dict]:
|
||||||
"""Names of the various dimensions of vector modalities."""
|
"""Names of the various dimensions of vector modalities."""
|
||||||
|
|||||||
@@ -27,11 +27,20 @@ LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
|||||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "camera", "tool_calls")
|
||||||
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
EVENT_ROW_FIELDS = ("role", "content", "style", "camera", "tool_calls")
|
||||||
|
|
||||||
CORE_STYLES = {"subtask", "plan", "memory", "motion", "interjection", "vqa", "trace"}
|
CORE_STYLES = {
|
||||||
|
"subtask",
|
||||||
|
"plan",
|
||||||
|
"memory",
|
||||||
|
"motion",
|
||||||
|
"interjection",
|
||||||
|
"vqa",
|
||||||
|
"trace",
|
||||||
|
"task_aug",
|
||||||
|
}
|
||||||
EXTENDED_STYLES = set()
|
EXTENDED_STYLES = set()
|
||||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||||
|
|
||||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion"}
|
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||||
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
EVENT_ONLY_STYLES = {"interjection", "vqa", "trace"}
|
||||||
|
|
||||||
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
# Styles whose ``content`` is grounded in a specific camera view. Rows of these
|
||||||
@@ -174,6 +183,43 @@ def validate_camera_field(style: str | None, camera: str | None) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Tool registry --------------------------------------------------------
|
||||||
|
# Tools declared on a dataset live in ``meta/info.json["tools"]`` as a list
|
||||||
|
# of OpenAI-style function schemas. The runtime / training stack reads them
|
||||||
|
# through :class:`LeRobotDatasetMetadata.tools` (with these constants as
|
||||||
|
# fallback when the dataset doesn't declare any). Implementations live
|
||||||
|
# under :mod:`lerobot.tools` (one file per tool); see
|
||||||
|
# ``docs/source/tools.mdx`` for the authoring guide.
|
||||||
|
|
||||||
|
SAY_TOOL_SCHEMA: dict = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "say",
|
||||||
|
"description": "Speak a short utterance to the user via the TTS executor.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The verbatim text to speak.",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"""Canonical schema for the ``say`` tool emitted by the steerable
|
||||||
|
annotation pipeline (PR 2 Module 2). Single source of truth — PR 2's
|
||||||
|
writer, PR 3's runtime tool registry, and the dataset visualizer all
|
||||||
|
import this constant rather than duplicating the dict."""
|
||||||
|
|
||||||
|
DEFAULT_TOOLS: list[dict] = [SAY_TOOL_SCHEMA]
|
||||||
|
"""Fallback tools list. Returned by ``LeRobotDatasetMetadata.tools``
|
||||||
|
when ``meta/info.json["tools"]`` is unset, so unannotated datasets and
|
||||||
|
chat-template consumers (``apply_chat_template(messages, tools=...)``)
|
||||||
|
keep working out of the box."""
|
||||||
|
|
||||||
|
|
||||||
def column_for_style(style: str | None) -> LanguageColumn:
|
def column_for_style(style: str | None) -> LanguageColumn:
|
||||||
"""Map a language style to the column where rows of that style are stored.
|
"""Map a language style to the column where rows of that style are stored.
|
||||||
|
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ def render_sample(
|
|||||||
persistent=persistent_rows,
|
persistent=persistent_rows,
|
||||||
events=event_rows,
|
events=event_rows,
|
||||||
t=t,
|
t=t,
|
||||||
|
sample_idx=sample_idx,
|
||||||
task=task,
|
task=task,
|
||||||
dataset_ctx=dataset_ctx,
|
dataset_ctx=dataset_ctx,
|
||||||
)
|
)
|
||||||
@@ -232,21 +233,65 @@ def _resolve_bindings(
|
|||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
events: Sequence[LanguageRow],
|
events: Sequence[LanguageRow],
|
||||||
t: float,
|
t: float,
|
||||||
|
sample_idx: int,
|
||||||
task: str | None,
|
task: str | None,
|
||||||
dataset_ctx: Any | None,
|
dataset_ctx: Any | None,
|
||||||
) -> dict[str, LanguageRow | str | None]:
|
) -> dict[str, LanguageRow | str | None]:
|
||||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||||
bindings: dict[str, LanguageRow | str | None] = {"task": _resolve_task(task, dataset_ctx)}
|
bindings: dict[str, LanguageRow | str | None] = {
|
||||||
|
"task": _resolve_task(
|
||||||
|
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
||||||
|
),
|
||||||
|
}
|
||||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||||
for name, spec in specs.items():
|
for name, spec in specs.items():
|
||||||
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
bindings[name] = _resolve_spec(spec, persistent=persistent, events=events, t=t)
|
||||||
return bindings
|
return bindings
|
||||||
|
|
||||||
|
|
||||||
def _resolve_task(task: str | None, dataset_ctx: Any | None) -> str | None:
|
def _resolve_task(
|
||||||
"""Return ``task`` if set, otherwise look it up on ``dataset_ctx``."""
|
task: str | None,
|
||||||
|
dataset_ctx: Any | None,
|
||||||
|
*,
|
||||||
|
persistent: Sequence[LanguageRow] = (),
|
||||||
|
sample_idx: int = 0,
|
||||||
|
) -> str | None:
|
||||||
|
"""Return the task string for ``sample_idx``.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
|
||||||
|
1. Explicit ``task`` override (caller-supplied) wins.
|
||||||
|
2. If ``persistent`` contains rows of style ``task_aug`` (role=user),
|
||||||
|
deterministically pick one by ``sample_idx`` so each frame of an
|
||||||
|
episode rotates through the available rephrasings across an epoch.
|
||||||
|
This realizes Xiao 2022 / CAST-style task-prompt diversity without
|
||||||
|
changing ``meta/tasks.parquet`` and without forcing recipes to opt
|
||||||
|
in: ``${task}`` automatically picks a rephrasing when one exists,
|
||||||
|
and falls back to the canonical task otherwise. Recipes that want
|
||||||
|
the literal canonical task can override the binding.
|
||||||
|
3. Otherwise read the canonical task from ``dataset_ctx`` (which is
|
||||||
|
backed by ``meta/tasks.parquet``).
|
||||||
|
"""
|
||||||
if task is not None:
|
if task is not None:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
aug_rows = [
|
||||||
|
r
|
||||||
|
for r in persistent
|
||||||
|
if r.get("style") == "task_aug" and r.get("role") == "user"
|
||||||
|
]
|
||||||
|
if aug_rows:
|
||||||
|
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||||
|
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||||
|
# is process-randomized).
|
||||||
|
digest = hashlib.blake2b(
|
||||||
|
f"task_aug:{sample_idx}".encode(), digest_size=8
|
||||||
|
).digest()
|
||||||
|
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||||
|
chosen = aug_rows[idx].get("content")
|
||||||
|
if chosen:
|
||||||
|
return str(chosen)
|
||||||
|
|
||||||
if dataset_ctx is None:
|
if dataset_ctx is None:
|
||||||
return None
|
return None
|
||||||
if isinstance(dataset_ctx, dict):
|
if isinstance(dataset_ctx, dict):
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def test_language_arrow_schema_has_expected_fields():
|
|||||||
|
|
||||||
|
|
||||||
def test_style_registry_routes_columns():
|
def test_style_registry_routes_columns():
|
||||||
assert {"subtask", "plan", "memory", "motion"} == PERSISTENT_STYLES
|
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
|
||||||
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
assert {"interjection", "vqa", "trace"} == EVENT_ONLY_STYLES
|
||||||
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
assert PERSISTENT_STYLES | EVENT_ONLY_STYLES <= STYLE_REGISTRY
|
||||||
|
|
||||||
@@ -51,6 +51,7 @@ def test_style_registry_routes_columns():
|
|||||||
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
assert column_for_style("plan") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
assert column_for_style("memory") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
assert column_for_style("motion") == LANGUAGE_PERSISTENT
|
||||||
|
assert column_for_style("task_aug") == LANGUAGE_PERSISTENT
|
||||||
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
assert column_for_style("interjection") == LANGUAGE_EVENTS
|
||||||
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
assert column_for_style("vqa") == LANGUAGE_EVENTS
|
||||||
assert column_for_style("trace") == LANGUAGE_EVENTS
|
assert column_for_style("trace") == LANGUAGE_EVENTS
|
||||||
|
|||||||
@@ -289,6 +289,87 @@ def test_per_camera_blend_renders_both_views():
|
|||||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||||
|
rephrasings = [
|
||||||
|
persistent_row("user", "tidy the kitchen", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "please clean up the kitchen", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "kitchen needs tidying", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "make the kitchen clean", "task_aug", 0.0),
|
||||||
|
]
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# No explicit task override → resolver consults persistent rows.
|
||||||
|
seen: set[str] = set()
|
||||||
|
for sample_idx in range(64):
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=sample_idx,
|
||||||
|
dataset_ctx={"task": "canonical kitchen task"},
|
||||||
|
)
|
||||||
|
seen.add(rendered["messages"][0]["content"])
|
||||||
|
# Every rephrasing should be reachable across enough samples.
|
||||||
|
assert seen == {r["content"] for r in rephrasings}
|
||||||
|
# Same sample_idx → same pick (determinism).
|
||||||
|
a = render_sample(
|
||||||
|
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
b = render_sample(
|
||||||
|
recipe=recipe, persistent=rephrasings, events=[], t=0.0, sample_idx=42,
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
assert a["messages"][0]["content"] == b["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_falls_back_to_canonical_without_rephrasings():
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=PERSISTENT, # no task_aug rows
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=0,
|
||||||
|
dataset_ctx={"task": "clean the kitchen"},
|
||||||
|
)
|
||||||
|
assert rendered["messages"][0]["content"] == "clean the kitchen"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_task_explicit_override_beats_rephrasings():
|
||||||
|
rephrasings = [
|
||||||
|
persistent_row("user", "rephrased one", "task_aug", 0.0),
|
||||||
|
persistent_row("user", "rephrased two", "task_aug", 0.0),
|
||||||
|
]
|
||||||
|
recipe = TrainingRecipe(
|
||||||
|
messages=[
|
||||||
|
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||||
|
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=recipe,
|
||||||
|
persistent=rephrasings,
|
||||||
|
events=[],
|
||||||
|
t=0.0,
|
||||||
|
sample_idx=0,
|
||||||
|
task="explicit override wins",
|
||||||
|
dataset_ctx={"task": "canonical"},
|
||||||
|
)
|
||||||
|
assert rendered["messages"][0]["content"] == "explicit override wins"
|
||||||
|
|
||||||
|
|
||||||
def test_canonical_recipe_can_render_low_level_branch():
|
def test_canonical_recipe_can_render_low_level_branch():
|
||||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
|
||||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||||
|
|||||||
Reference in New Issue
Block a user