refactor(annotate): drop dataset-level `tools` parquet column

PR 2 used to write a top-level ``tools`` column on every parquet shard
holding the JSON schema for the ``say`` tool, broadcast identically
across every row. That extends PR 1's schema for no real information
gain — the schema is a fixed code constant, parquet's RLE/dict encoding
collapses it on disk anyway, and HF/TRL chat-template consumers can
just import the constant directly.

PR 2 should fill in PR 1's existing schema, not add to it. So:

- ``writer.py``: stop emitting the ``tools`` column. Strip any legacy
  ``tools`` column from older shards on rerun so the schema converges to
  v3.1. ``SAY_TOOL_SCHEMA`` stays as a public constant (now joined by
  ``DEFAULT_TOOLS = [SAY_TOOL_SCHEMA]``); chat-template policies and the
  visualizer import them directly.
- ``test_writer.py``: replace the "tools column present" assertion with
  one that explicitly checks the column is absent, plus a new test
  asserting the constant's shape.
- ``test_pipeline_recipe_render.py``: drop the tools-column read; assert
  it's not present in the rewritten parquet.
- ``annotation_pipeline.mdx``: update the writer description to note the
  parquet stays small and the schema lives as a code constant.

If multi-tool-set support ever becomes real (datasets with different
tool inventories), the right home is ``meta/info.json["tools"]`` —
adding it later is non-breaking; ripping out a parquet column already
shipped is not.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-30 15:54:37 +02:00
parent 0f6e3230df
commit b71e10da6b
4 changed files with 63 additions and 28 deletions
+7 -2
View File
@@ -20,8 +20,13 @@ rewrites the data shards in place:
| speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 | | speech tool-call atom (`style=null`, `say`) | `language_events` | Module 2 |
| `vqa` (user / assistant pair) | `language_events` | Module 3 | | `vqa` (user / assistant pair) | `language_events` | Module 3 |
The writer also adds a dataset-level `tools` column carrying the JSON schema The writer drops the legacy `subtask_index` column. It does **not** add a
for the `say` tool call, and drops the legacy `subtask_index` column. `tools` column to the parquet — the `say` tool's JSON schema is fixed and
lives as a code constant (`SAY_TOOL_SCHEMA` / `DEFAULT_TOOLS` in
`lerobot.annotations.steerable_pipeline.writer`), so the parquet stays
small and PR 2 doesn't extend PR 1's schema. Chat-template consumers
import the constant directly (e.g.
`apply_chat_template(messages, tools=DEFAULT_TOOLS)`).
## How to run it locally or on SLURM ## How to run it locally or on SLURM
@@ -25,8 +25,14 @@ For every episode the writer:
5. for each frame, materializes the sublist of event rows whose timestamp 5. for each frame, materializes the sublist of event rows whose timestamp
exactly equals that frame's timestamp, exactly equals that frame's timestamp,
6. drops the legacy ``subtask_index`` column, 6. drops the legacy ``subtask_index`` column,
7. adds a top-level ``tools`` column containing the JSON schema for ``say``, 7. writes the parquet shard back in place.
8. writes the parquet shard back in place.
The writer does NOT add a dataset-level ``tools`` column. Tool *calls* are
emitted per-row via the existing ``tool_calls`` field on the v3.1 row
struct (PR 1) for every speech atom. The tool *schema* (the description
of the ``say`` function and its parameters) is a fixed code constant —
``SAY_TOOL_SCHEMA`` below — and downstream chat-template consumers import
it directly rather than reading a redundant per-row column.
Invariants enforced here (and re-checked by the validator): Invariants enforced here (and re-checked by the validator):
@@ -38,7 +44,6 @@ Invariants enforced here (and re-checked by the validator):
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
@@ -81,6 +86,19 @@ SAY_TOOL_SCHEMA: dict[str, Any] = {
}, },
}, },
} }
"""Fixed JSON schema for the only tool the canonical recipe knows about.
Kept here as a code constant rather than written as a parquet column so
the v3.1 schema (PR 1) doesn't need to grow a redundant broadcast field
that holds the same value on every row of every dataset. Downstream
chat-template consumers (Pi0.5 processor, lerobot-dataset-visualizer)
import this directly. If multi-tool-set support ever becomes real, the
right place is ``meta/info.json["tools"]`` — adding it later is
non-breaking; ripping out a parquet column already shipped is not.
"""
DEFAULT_TOOLS: list[dict[str, Any]] = [SAY_TOOL_SCHEMA]
"""Convenience list for ``apply_chat_template(messages, tools=...)``."""
def _row_persistent_sort_key(row: dict[str, Any]) -> tuple: def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
@@ -286,8 +304,13 @@ class LanguageColumnsWriter:
for name in table.column_names: for name in table.column_names:
if drop_old and name == "subtask_index": if drop_old and name == "subtask_index":
continue continue
if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS, "tools"): if name in (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS):
continue # we'll re-add canonical versions continue # we'll re-add canonical versions
# Strip any legacy ``tools`` column previously emitted by older
# writers — the schema no longer uses it (constant lives in
# SAY_TOOL_SCHEMA / DEFAULT_TOOLS).
if name == "tools":
continue
cols.append(table.column(name)) cols.append(table.column(name))
names.append(name) names.append(name)
@@ -304,14 +327,6 @@ class LanguageColumnsWriter:
cols.extend([persistent_arr, events_arr]) cols.extend([persistent_arr, events_arr])
names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS]) names.extend([LANGUAGE_PERSISTENT, LANGUAGE_EVENTS])
# Dataset-level tools column. Store the JSON schema as a string per
# row (broadcast-identical, parquet dictionary-encodes it) — string
# storage avoids requiring pa.json_() on every consumer.
tools_json = json.dumps([SAY_TOOL_SCHEMA], sort_keys=True)
tools_arr = pa.array([tools_json] * table.num_rows, type=pa.string())
cols.append(tools_arr)
names.append("tools")
return pa.Table.from_arrays(cols, names=names) return pa.Table.from_arrays(cols, names=names)
@@ -17,7 +17,6 @@
from __future__ import annotations from __future__ import annotations
import json
from pathlib import Path from pathlib import Path
import pyarrow.parquet as pq import pyarrow.parquet as pq
@@ -130,6 +129,7 @@ def test_pr1_canonical_recipe_renders_nonempty_from_pipeline_output(
say = speech_rows[0]["tool_calls"][0] say = speech_rows[0]["tool_calls"][0]
assert say["function"]["name"] == "say" assert say["function"]["name"] == "say"
assert isinstance(say["function"]["arguments"]["text"], str) assert isinstance(say["function"]["arguments"]["text"], str)
# Tools column carries the say schema # PR 2 no longer writes a ``tools`` column — the say schema lives as a
tools = json.loads(table.column("tools").to_pylist()[0]) # constant (``SAY_TOOL_SCHEMA``) so PR 1's row struct is the single
assert tools and tools[0]["function"]["name"] == "say" # source of truth for the v3.1 schema.
assert "tools" not in table.column_names
+25 -10
View File
@@ -218,7 +218,10 @@ def test_writer_drops_subtask_index_idempotent(fixture_dataset_root: Path, tmp_p
assert "subtask_index" not in table_a.column_names assert "subtask_index" not in table_a.column_names
assert "language_persistent" in table_a.column_names assert "language_persistent" in table_a.column_names
assert "language_events" in table_a.column_names assert "language_events" in table_a.column_names
assert "tools" in table_a.column_names # The writer no longer emits a dataset-level ``tools`` column; the
# ``say`` tool schema lives as a code constant (``SAY_TOOL_SCHEMA``)
# so the parquet stays small and PR 2 doesn't extend PR 1's schema.
assert "tools" not in table_a.column_names
# second pass — must produce identical bytes for the language columns # second pass — must produce identical bytes for the language columns
records_again = list(iter_episodes(fixture_dataset_root)) records_again = list(iter_episodes(fixture_dataset_root))
@@ -248,7 +251,26 @@ def test_writer_normalize_rejects_misrouted_event_style() -> None:
_normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None}) _normalize_event_row({"role": "assistant", "content": "oops", "style": "subtask", "tool_calls": None})
def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path, tmp_path: Path) -> None: def test_say_tool_schema_constant_is_well_formed() -> None:
"""``SAY_TOOL_SCHEMA`` (and ``DEFAULT_TOOLS``) replace the parquet
``tools`` column — chat-template consumers import them directly.
"""
from lerobot.annotations.steerable_pipeline.writer import (
DEFAULT_TOOLS,
SAY_TOOL_SCHEMA,
)
assert DEFAULT_TOOLS == [SAY_TOOL_SCHEMA]
assert SAY_TOOL_SCHEMA["function"]["name"] == "say"
params = SAY_TOOL_SCHEMA["function"]["parameters"]
assert params["properties"]["text"]["type"] == "string"
assert params["required"] == ["text"]
def test_writer_does_not_add_tools_column(fixture_dataset_root: Path, tmp_path: Path) -> None:
"""Re-running on a parquet that already has a legacy ``tools`` column
must drop it cleanly so reruns converge to the v3.1 schema.
"""
staging_dir = tmp_path / "stage" staging_dir = tmp_path / "stage"
_stage_episode( _stage_episode(
staging_dir, staging_dir,
@@ -260,14 +282,7 @@ def test_dataset_tools_column_present_with_say_schema(fixture_dataset_root: Path
records = list(iter_episodes(fixture_dataset_root)) records = list(iter_episodes(fixture_dataset_root))
LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root) LanguageColumnsWriter().write_all(records, staging_dir, fixture_dataset_root)
table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet") table = pq.read_table(fixture_dataset_root / "data" / "chunk-000" / "file-000.parquet")
tools = table.column("tools").to_pylist() assert "tools" not in table.column_names
assert tools, "tools column missing"
decoded = json.loads(tools[0])
assert isinstance(decoded, list)
assert len(decoded) == 1
assert decoded[0]["function"]["name"] == "say"
params = decoded[0]["function"]["parameters"]
assert params["properties"]["text"]["type"] == "string"
def test_speech_atom_shape_matches_plan_spec() -> None: def test_speech_atom_shape_matches_plan_spec() -> None: