Address review: split persistent/event schemas, drop event timestamps

- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-27 13:38:23 +02:00
parent 8833d735a1
commit 2b71221194
10 changed files with 210 additions and 60 deletions
+20 -3
View File
@@ -19,7 +19,7 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"]
@@ -35,12 +35,21 @@ DEFAULT_BINDINGS = {
}
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = {"user", "assistant", "system", "tool"}
_VALID_STREAMS = {"high_level", "low_level"}
_VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass
class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole
content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None
@@ -71,6 +80,13 @@ class MessageTurn:
@dataclass
class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None
@@ -164,4 +180,5 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path)
+1 -1
View File
@@ -51,7 +51,7 @@ from .utils import (
)
from .video_utils import get_video_info
CODEBASE_VERSION = "v3.1"
CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata:
+11 -2
View File
@@ -22,7 +22,12 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import is_language_column, language_column_feature
from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -47,7 +52,11 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {}
for key, ft in features.items():
if is_language_column(key):
hf_features[key] = language_column_feature()
hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video":
continue
elif ft["dtype"] == "image":
+67 -15
View File
@@ -24,12 +24,12 @@ import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
LANGUAGE_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"}
EXTENDED_STYLES = set()
RESERVED_STYLES = {"motion", "trace"}
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES | RESERVED_STYLES
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory"}
EVENT_ONLY_STYLES = {"interjection", "vqa"}
@@ -37,43 +37,90 @@ EVENT_ONLY_STYLES = {"interjection", "vqa"}
LanguageColumn = Literal["language_persistent", "language_events"]
def language_row_arrow_type() -> pa.StructType:
json_type = pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_arrow_type() -> pa.DataType:
return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("tool_calls", pa.list_(json_type), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_persistent_arrow_type() -> pa.ListType:
return pa.list_(language_row_arrow_type())
"""Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType:
return pa.list_(language_row_arrow_type())
"""Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_row_feature() -> dict[str, object]:
json_feature = datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"tool_calls": datasets.List(json_feature),
"tool_calls": datasets.List(_json_feature()),
}
def language_column_feature() -> datasets.List:
return datasets.List(language_row_feature())
def language_event_row_feature() -> dict[str, object]:
"""Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
@@ -81,16 +128,21 @@ def language_feature_info() -> dict[str, dict]:
def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS
def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None:
return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS
if style in RESERVED_STYLES:
raise ValueError(f"Style {style!r} is registered but has no storage column yet.")
raise ValueError(f"Unknown language style: {style!r}")
+59 -14
View File
@@ -47,6 +47,13 @@ def active_at(
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
selector. ``events`` is accepted for resolver-signature uniformity but is
not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
matches = [row for row in matches if _timestamp(row) <= t]
@@ -62,14 +69,25 @@ def emitted_at(
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``.
"""
column = column_for_style(style)
rows = persistent if column == LANGUAGE_PERSISTENT else events
matches = [
row
for row in _matching_rows(rows, style=style, role=role, tool_name=tool_name)
if _timestamp(row) == t
]
return _select_exact(matches, style=style, role=role, tool_name=tool_name)
if column == LANGUAGE_PERSISTENT:
matches = [
row
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
if _timestamp(row) == t
]
return _select_one(
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
)
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
def nth_prev(
@@ -82,6 +100,12 @@ def nth_prev(
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions before the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
@@ -103,6 +127,12 @@ def nth_next(
role: str | None = None,
tool_name: str | None = None,
) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions after the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative(
t,
persistent=persistent,
@@ -124,6 +154,12 @@ def render_sample(
task: str | None = None,
dataset_ctx: Any | None = None,
) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx)
@@ -335,7 +371,10 @@ def _nth_relative(
if abs(offset) < 1:
raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = _sort_rows(_matching_rows(persistent, style=style, role=role, tool_name=tool_name))
rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
key=_persistent_sort_key,
)
if not rows:
return None
@@ -387,22 +426,24 @@ def _select_latest(
) -> LanguageRow | None:
if not rows:
return None
rows = _sort_rows(rows)
rows = sorted(rows, key=_persistent_sort_key)
latest_ts = _timestamp(rows[-1])
return _select_exact(
return _select_one(
[row for row in rows if _timestamp(row) == latest_ts],
style=style,
role=role,
tool_name=tool_name,
sort_key=_persistent_sort_key,
)
def _select_exact(
def _select_one(
rows: Sequence[LanguageRow],
*,
style: str | None,
role: str | None,
tool_name: str | None,
sort_key: Any,
) -> LanguageRow | None:
if not rows:
return None
@@ -410,11 +451,15 @@ def _select_exact(
raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
)
return _sort_rows(rows)[0]
return sorted(rows, key=sort_key)[0]
def _sort_rows(rows: Sequence[LanguageRow]) -> list[LanguageRow]:
return sorted(rows, key=lambda row: (_timestamp(row), row.get("style") or "", row.get("role") or ""))
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
return (row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
-2
View File
@@ -167,7 +167,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
@@ -187,7 +186,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
return {
**pad_keys,
**task_key,
**subtask_key,
**index_key,
**task_index_key,
**episode_index_key,
@@ -31,10 +31,19 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass
@ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe
dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or []
+6
View File
@@ -26,6 +26,12 @@ _PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None]
if not batch:
return None
+7 -5
View File
@@ -22,11 +22,14 @@ from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields():
row_type = language_persistent_arrow_type().value_type
persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(row_type, pa.StructType)
assert row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
assert language_events_arrow_type().value_type == row_type
assert isinstance(persistent_row_type, pa.StructType)
assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "tool_calls"]
def test_style_registry_routes_columns():
@@ -72,7 +75,6 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot
"role": "user",
"content": "what is visible?",
"style": "vqa",
"timestamp": 0.0,
"tool_calls": None,
}
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
+30 -18
View File
@@ -8,7 +8,7 @@ from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
def row(role, content, style, timestamp, tool_calls=None):
def persistent_row(role, content, style, timestamp, tool_calls=None):
return {
"role": role,
"content": content,
@@ -18,22 +18,32 @@ def row(role, content, style, timestamp, tool_calls=None):
}
def event_row(role, content, style, tool_calls=None):
return {
"role": role,
"content": content,
"style": style,
"tool_calls": tool_calls,
}
PERSISTENT = [
row("assistant", "plan 0", "plan", 0.0),
row("assistant", "memory 0", "memory", 0.0),
row("assistant", "subtask 0", "subtask", 0.0),
row("assistant", "memory 1", "memory", 1.0),
row("assistant", "subtask 1", "subtask", 1.0),
persistent_row("assistant", "plan 0", "plan", 0.0),
persistent_row("assistant", "memory 0", "memory", 0.0),
persistent_row("assistant", "subtask 0", "subtask", 0.0),
persistent_row("assistant", "memory 1", "memory", 1.0),
persistent_row("assistant", "subtask 1", "subtask", 1.0),
]
EVENTS = [
row("user", "what is visible?", "vqa", 1.0),
row("assistant", '{"count": 2}', "vqa", 1.0),
row("user", "skip wiping", "interjection", 2.0),
row(
EVENTS_AT_1 = [
event_row("user", "what is visible?", "vqa"),
event_row("assistant", '{"count": 2}', "vqa"),
]
EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant",
None,
None,
2.0,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
),
]
@@ -42,9 +52,9 @@ EVENTS = [
def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant") is None
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")["content"]
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}'
)
@@ -87,7 +97,7 @@ def test_substitution_if_present_multimodal_and_tool_calls():
rendered = render_sample(
recipe=recipe,
persistent=PERSISTENT,
events=EVENTS,
events=EVENTS_AT_2,
t=2.0,
sample_idx=0,
task="clean kitchen",
@@ -114,7 +124,9 @@ def test_exact_event_miss_returns_none_when_target_skips():
]
)
assert render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=0) is None
assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling():
@@ -138,10 +150,10 @@ def test_deterministic_blend_sampling():
)
first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x"
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x"
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
)
assert first == second