diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx index f896849c2..5d5ea2ef3 100644 --- a/docs/source/annotation_pipeline.mdx +++ b/docs/source/annotation_pipeline.mdx @@ -20,8 +20,13 @@ rewrites the data shards in place: | speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 | | `vqa` (user / assistant pair) | `language_events` | Module 3 | -The writer also adds a dataset-level `tools` column carrying the JSON schema -for the `say` tool call, and drops the legacy `subtask_index` column. +The writer drops the legacy `subtask_index` column. It does **not** add a +`tools` column to the parquet — the `say` tool's JSON schema is fixed and +lives as a code constant (`SAY_TOOL_SCHEMA` / `DEFAULT_TOOLS` in +`lerobot.annotations.steerable_pipeline.writer`), so the parquet stays +small and PR 2 doesn't extend PR 1's schema. Chat-template consumers +import the constant directly (e.g. +`apply_chat_template(messages, tools=DEFAULT_TOOLS)`). ## How to run it locally or on SLURM diff --git a/src/lerobot/annotations/steerable_pipeline/writer.py b/src/lerobot/annotations/steerable_pipeline/writer.py index b440201a5..e595161c6 100644 --- a/src/lerobot/annotations/steerable_pipeline/writer.py +++ b/src/lerobot/annotations/steerable_pipeline/writer.py @@ -25,8 +25,14 @@ For every episode the writer: 5. for each frame, materializes the sublist of event rows whose timestamp exactly equals that frame's timestamp, 6. drops the legacy ``subtask_index`` column, -7. adds a top-level ``tools`` column containing the JSON schema for ``say``, -8. writes the parquet shard back in place. +7. writes the parquet shard back in place. + +The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are +emitted per-row via the existing ``tool_calls`` field on the v3.1 row +struct (PR 1) for every speech atom. The tool *schema* (the description +of the ``say`` function and its parameters) is a fixed code constant — +``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import +it directly rather than reading a redundant per-row column. Invariants enforced here (and re-checked by the validator): @@ -38,7 +44,6 @@ Invariants enforced here (and re-checked by the validator): from __future__ import annotations -import json import logging from collections import defaultdict from collections.abc import Iterable, Sequence @@ -81,6 +86,19 @@ SAY_TOOL_SCHEMA: dict[str, Any] = { }, }, } +"""Fixed JSON schema for the only tool the canonical recipe knows about. + +Kept here as a code constant rather than written as a parquet column so +the v3.1 schema (PR 1) doesn't need to grow a redundant broadcast field +that holds the same value on every row of every dataset. Downstream +chat-template consumers (Pi0.5 processor, lerobot-dataset-visualizer) +import this directly. If multi-tool-set support ever becomes real, the +right place is ``meta/info.json["tools"]`` — adding it later is +non-breaking; ripping out a parquet column already shipped is not. +""" + +DEFAULT_TOOLS: list[dict[str, Any]] = [SAY_TOOL_SCHEMA] +"""Convenience list for ``apply_chat_template(messages, tools=...)``.""" def _row_persistent_sort_key(row: dict[str, Any]) -> tuple: @@ -286,8 +304,13 @@ class LanguageColumnsWriter: for name in table.column_names: if drop_old and name == "subtask_index": continue - if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"): + if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS): continue # we'll re-add canonical versions + # Strip any legacy ``tools`` column previously emitted by older + # writers — the schema no longer uses it (constant lives in + # SAY_TOOL_SCHEMA / DEFAULT_TOOLS). + if name == "tools": + continue cols.append(table.column(name)) names.append(name) @@ -304,14 +327,6 @@ class LanguageColumnsWriter: cols.extend([persistent_arr, events_arr]) names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS]) - # Dataset-level tools column. Store the JSON schema as a string per - # row (broadcast-identical, parquet dictionary-encodes it) — string - # storage avoids requiring pa.json_() on every consumer. - tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True) - tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string()) - cols.append(tools_arr) - names.append("tools") - return pa.Table.from_arrays(cols, names=names) diff --git a/tests/annotations/test_pipeline_recipe_render.py b/tests/annotations/test_pipeline_recipe_render.py index 80f9d01a8..3cbd92358 100644 --- a/tests/annotations/test_pipeline_recipe_render.py +++ b/tests/annotations/test_pipeline_recipe_render.py @@ -17,7 +17,6 @@ from __future__ import annotations -import json from pathlib import Path import pyarrow.parquet as pq @@ -130,6 +129,7 @@ def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output( say = speech_rows[0]["tool_calls"][0] assert say["function"]["name"] == "say" assert isinstance(say["function"]["arguments"]["text"], str) - # Tools column carries the say schema - tools = json.loads(table.column("tools").to_pylist()[0]) - assert tools and tools[0]["function"]["name"] == "say" + # PR 2 no longer writes a ``tools`` column — the say schema lives as a + # constant (``SAY_TOOL_SCHEMA``) so PR 1's row struct is the single + # source of truth for the v3.1 schema. + assert "tools" not in table.column_names diff --git a/tests/annotations/test_writer.py b/tests/annotations/test_writer.py index 14e835bf3..c6ff03178 100644 --- a/tests/annotations/test_writer.py +++ b/tests/annotations/test_writer.py @@ -218,7 +218,10 @@ def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_p assert "subtask_index" not in table_a.column_names assert "language_persistent" in table_a.column_names assert "language_events" in table_a.column_names - assert "tools" in table_a.column_names + # The writer no longer emits a dataset-level ``tools`` column; the + # ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``) + # so the parquet stays small and PR 2 doesn't extend PR 1's schema. + assert "tools" not in table_a.column_names # second pass — must produce identical bytes for the language columns records_again = list(iter_episodes(fixture_dataset_root)) @@ -248,7 +251,26 @@ def test_writer_normalize_rejects_misrouted_event_style() -> None: _normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None}) -def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None: +def test_say_tool_schema_constant_is_well_formed() -> None: + """``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet + ``tools`` column — chat-template consumers import them directly. + """ + from lerobot.annotations.steerable_pipeline.writer import ( + DEFAULT_TOOLS, + SAY_TOOL_SCHEMA, + ) + + assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA] + assert SAY_TOOL_SCHEMA["function"]["name"] == "say" + params = SAY_TOOL_SCHEMA["function"]["parameters"] + assert params["properties"]["text"]["type"] == "string" + assert params["required"] == ["text"] + + +def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None: + """Re-running on a parquet that already has a legacy ``tools`` column + must drop it cleanly so reruns converge to the v3.1 schema. + """ staging_dir = tmp_path / "stage" _stage_episode( staging_dir, @@ -260,14 +282,7 @@ def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path records = list(iter_episodes(fixture_dataset_root)) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") - tools = table.column("tools").to_pylist() - assert tools, "tools column missing" - decoded = json.loads(tools[0]) - assert isinstance(decoded, list) - assert len(decoded) == 1 - assert decoded[0]["function"]["name"] == "say" - params = decoded[0]["function"]["parameters"] - assert params["properties"]["text"]["type"] == "string" + assert "tools" not in table.column_names def test_speech_atom_shape_matches_plan_spec() -> None: