Address review: split persistent/event schemas, drop event timestamps

- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals
- dataset_metadata.py: keep CODEBASE_VERSION at v3.0
- language.py: remove RESERVED_STYLES; split arrow/feature schemas into
  persistent (with timestamp) and event (without timestamp); add docstrings
- language_render.py: events use frame-row timestamp implicitly; no
  per-event timestamp filtering or sorting
- converters.py: drop unused subtask_key passthrough
- add docstrings to new public APIs (recipe, render_messages_processor, collate)
- update tests for split schemas; revert uv.lock

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-04-27 13:38:23 +02:00
parent 8833d735a1
commit 2b71221194
10 changed files with 210 additions and 60 deletions
+20 -3
View File
@@ -19,7 +19,7 @@ from __future__ import annotations
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal, get_args
MessageRole = Literal["user", "assistant", "system", "tool"] MessageRole = Literal["user", "assistant", "system", "tool"]
MessageStream = Literal["high_level", "low_level"] MessageStream = Literal["high_level", "low_level"]
@@ -35,12 +35,21 @@ DEFAULT_BINDINGS = {
} }
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") _PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
_VALID_ROLES = {"user", "assistant", "system", "tool"} _VALID_ROLES = frozenset(get_args(MessageRole))
_VALID_STREAMS = {"high_level", "low_level"} _VALID_STREAMS = frozenset(get_args(MessageStream))
@dataclass @dataclass
class MessageTurn: class MessageTurn:
"""A single chat-style turn in a recipe template.
``content`` may be a plain string, a list of HF-style multimodal blocks, or
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
training target, and ``if_present`` skips the turn when the named binding
resolves to ``None``.
"""
role: MessageRole role: MessageRole
content: str | list[dict[str, Any]] | None = None content: str | list[dict[str, Any]] | None = None
stream: MessageStream | None = None stream: MessageStream | None = None
@@ -71,6 +80,13 @@ class MessageTurn:
@dataclass @dataclass
class TrainingRecipe: class TrainingRecipe:
"""A recipe describing how to render training samples from language rows.
A recipe is either a *message recipe* (``messages`` plus optional
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
sub-recipes). ``weight`` is only meaningful inside a blend.
"""
messages: list[MessageTurn] | None = None messages: list[MessageTurn] | None = None
bindings: dict[str, str] | None = None bindings: dict[str, str] | None = None
blend: dict[str, TrainingRecipe] | None = None blend: dict[str, TrainingRecipe] | None = None
@@ -164,4 +180,5 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
def load_recipe(path: str | Path) -> TrainingRecipe: def load_recipe(path: str | Path) -> TrainingRecipe:
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
return TrainingRecipe.from_yaml(path) return TrainingRecipe.from_yaml(path)
+1 -1
View File
@@ -51,7 +51,7 @@ from .utils import (
) )
from .video_utils import get_video_info from .video_utils import get_video_info
CODEBASE_VERSION = "v3.1" CODEBASE_VERSION = "v3.0"
class LeRobotDatasetMetadata: class LeRobotDatasetMetadata:
+11 -2
View File
@@ -22,7 +22,12 @@ from PIL import Image as PILImage
from lerobot.utils.constants import DEFAULT_FEATURES from lerobot.utils.constants import DEFAULT_FEATURES
from lerobot.utils.utils import is_valid_numpy_dtype_string from lerobot.utils.utils import is_valid_numpy_dtype_string
from .language import is_language_column, language_column_feature from .language import (
LANGUAGE_PERSISTENT,
is_language_column,
language_events_column_feature,
language_persistent_column_feature,
)
from .utils import ( from .utils import (
DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_FILE_SIZE_IN_MB,
@@ -47,7 +52,11 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features = {} hf_features = {}
for key, ft in features.items(): for key, ft in features.items():
if is_language_column(key): if is_language_column(key):
hf_features[key] = language_column_feature() hf_features[key] = (
language_persistent_column_feature()
if key == LANGUAGE_PERSISTENT
else language_events_column_feature()
)
elif ft["dtype"] == "video": elif ft["dtype"] == "video":
continue continue
elif ft["dtype"] == "image": elif ft["dtype"] == "image":
+67 -15
View File
@@ -24,12 +24,12 @@ import pyarrow as pa
LANGUAGE_PERSISTENT = "language_persistent" LANGUAGE_PERSISTENT = "language_persistent"
LANGUAGE_EVENTS = "language_events" LANGUAGE_EVENTS = "language_events"
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS) LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
LANGUAGE_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls") PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"} CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"}
EXTENDED_STYLES = set() EXTENDED_STYLES = set()
RESERVED_STYLES = {"motion", "trace"} STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES | RESERVED_STYLES
PERSISTENT_STYLES = {"subtask", "plan", "memory"} PERSISTENT_STYLES = {"subtask", "plan", "memory"}
EVENT_ONLY_STYLES = {"interjection", "vqa"} EVENT_ONLY_STYLES = {"interjection", "vqa"}
@@ -37,43 +37,90 @@ EVENT_ONLY_STYLES = {"interjection", "vqa"}
LanguageColumn = Literal["language_persistent", "language_events"] LanguageColumn = Literal["language_persistent", "language_events"]
def language_row_arrow_type() -> pa.StructType: def _json_arrow_type() -> pa.DataType:
json_type = pa.json_() if hasattr(pa, "json_") else pa.string() return pa.json_() if hasattr(pa, "json_") else pa.string()
def _json_feature() -> object:
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
def language_persistent_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single persistent language row.
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
"""
return pa.struct( return pa.struct(
[ [
pa.field("role", pa.string(), nullable=False), pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True), pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False), pa.field("timestamp", pa.float64(), nullable=False),
pa.field("tool_calls", pa.list_(json_type), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
)
def language_event_row_arrow_type() -> pa.StructType:
"""Return the Arrow struct type for a single event language row.
Event rows have no ``timestamp`` field: each event is stored on the dataset
row whose frame timestamp is the event's firing time.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
] ]
) )
def language_persistent_arrow_type() -> pa.ListType: def language_persistent_arrow_type() -> pa.ListType:
return pa.list_(language_row_arrow_type()) """Return the Arrow list type for the ``language_persistent`` column."""
return pa.list_(language_persistent_row_arrow_type())
def language_events_arrow_type() -> pa.ListType: def language_events_arrow_type() -> pa.ListType:
return pa.list_(language_row_arrow_type()) """Return the Arrow list type for the ``language_events`` column."""
return pa.list_(language_event_row_arrow_type())
def language_row_feature() -> dict[str, object]: def language_persistent_row_feature() -> dict[str, object]:
json_feature = datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string") """Return the HF ``datasets`` feature mapping for a persistent language row."""
return { return {
"role": datasets.Value("string"), "role": datasets.Value("string"),
"content": datasets.Value("string"), "content": datasets.Value("string"),
"style": datasets.Value("string"), "style": datasets.Value("string"),
"timestamp": datasets.Value("float64"), "timestamp": datasets.Value("float64"),
"tool_calls": datasets.List(json_feature), "tool_calls": datasets.List(_json_feature()),
} }
def language_column_feature() -> datasets.List: def language_event_row_feature() -> dict[str, object]:
return datasets.List(language_row_feature()) """Return the HF ``datasets`` feature mapping for an event language row."""
return {
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
def language_persistent_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
return datasets.List(language_persistent_row_feature())
def language_events_column_feature() -> datasets.List:
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
return datasets.List(language_event_row_feature())
def language_feature_info() -> dict[str, dict]: def language_feature_info() -> dict[str, dict]:
"""Return the ``info["features"]`` entries for both language columns."""
return { return {
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None}, LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None}, LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
@@ -81,16 +128,21 @@ def language_feature_info() -> dict[str, dict]:
def is_language_column(key: str) -> bool: def is_language_column(key: str) -> bool:
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
return key in LANGUAGE_COLUMNS return key in LANGUAGE_COLUMNS
def column_for_style(style: str | None) -> LanguageColumn: def column_for_style(style: str | None) -> LanguageColumn:
"""Map a language style to the column where rows of that style are stored.
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
to :data:`LANGUAGE_EVENTS`.
"""
if style is None: if style is None:
return LANGUAGE_EVENTS return LANGUAGE_EVENTS
if style in PERSISTENT_STYLES: if style in PERSISTENT_STYLES:
return LANGUAGE_PERSISTENT return LANGUAGE_PERSISTENT
if style in EVENT_ONLY_STYLES: if style in EVENT_ONLY_STYLES:
return LANGUAGE_EVENTS return LANGUAGE_EVENTS
if style in RESERVED_STYLES:
raise ValueError(f"Style {style!r} is registered but has no storage column yet.")
raise ValueError(f"Unknown language style: {style!r}") raise ValueError(f"Unknown language style: {style!r}")
+59 -14
View File
@@ -47,6 +47,13 @@ def active_at(
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row of ``style`` that is active at time ``t``.
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
selector. ``events`` is accepted for resolver-signature uniformity but is
not consulted: only persistent styles are valid here.
"""
_validate_persistent_resolver("active_at", style) _validate_persistent_resolver("active_at", style)
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name) matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
matches = [row for row in matches if _timestamp(row) <= t] matches = [row for row in matches if _timestamp(row) <= t]
@@ -62,14 +69,25 @@ def emitted_at(
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the row of ``style`` emitted at exactly time ``t``.
For persistent styles, this matches persistent rows whose own ``timestamp``
equals ``t``. For event styles, the ``events`` list is assumed to come from
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
so all matching event rows are considered emitted at ``t``.
"""
column = column_for_style(style) column = column_for_style(style)
rows = persistent if column == LANGUAGE_PERSISTENT else events if column == LANGUAGE_PERSISTENT:
matches = [ matches = [
row row
for row in _matching_rows(rows, style=style, role=role, tool_name=tool_name) for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
if _timestamp(row) == t if _timestamp(row) == t
] ]
return _select_exact(matches, style=style, role=role, tool_name=tool_name) return _select_one(
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
)
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
def nth_prev( def nth_prev(
@@ -82,6 +100,12 @@ def nth_prev(
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row that was active ``offset`` steps before ``t``.
Walks back through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions before the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative( return _nth_relative(
t, t,
persistent=persistent, persistent=persistent,
@@ -103,6 +127,12 @@ def nth_next(
role: str | None = None, role: str | None = None,
tool_name: str | None = None, tool_name: str | None = None,
) -> LanguageRow | None: ) -> LanguageRow | None:
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
Walks forward through chronologically sorted persistent rows of ``style``
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
positions after the row active at ``t``. Only valid for persistent styles.
"""
return _nth_relative( return _nth_relative(
t, t,
persistent=persistent, persistent=persistent,
@@ -124,6 +154,12 @@ def render_sample(
task: str | None = None, task: str | None = None,
dataset_ctx: Any | None = None, dataset_ctx: Any | None = None,
) -> RenderedMessages | None: ) -> RenderedMessages | None:
"""Render the chat-style messages for a single dataset sample.
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
at frame timestamp ``t``, then expands the recipe's message templates.
Returns ``None`` if the resolved sample contains no target message.
"""
persistent_rows = _normalize_rows(persistent or []) persistent_rows = _normalize_rows(persistent or [])
event_rows = _normalize_rows(events or []) event_rows = _normalize_rows(events or [])
selected_recipe = _select_recipe(recipe, sample_idx) selected_recipe = _select_recipe(recipe, sample_idx)
@@ -335,7 +371,10 @@ def _nth_relative(
if abs(offset) < 1: if abs(offset) < 1:
raise ValueError(f"{resolver_name} offset must be non-zero.") raise ValueError(f"{resolver_name} offset must be non-zero.")
rows = _sort_rows(_matching_rows(persistent, style=style, role=role, tool_name=tool_name)) rows = sorted(
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
key=_persistent_sort_key,
)
if not rows: if not rows:
return None return None
@@ -387,22 +426,24 @@ def _select_latest(
) -> LanguageRow | None: ) -> LanguageRow | None:
if not rows: if not rows:
return None return None
rows = _sort_rows(rows) rows = sorted(rows, key=_persistent_sort_key)
latest_ts = _timestamp(rows[-1]) latest_ts = _timestamp(rows[-1])
return _select_exact( return _select_one(
[row for row in rows if _timestamp(row) == latest_ts], [row for row in rows if _timestamp(row) == latest_ts],
style=style, style=style,
role=role, role=role,
tool_name=tool_name, tool_name=tool_name,
sort_key=_persistent_sort_key,
) )
def _select_exact( def _select_one(
rows: Sequence[LanguageRow], rows: Sequence[LanguageRow],
*, *,
style: str | None, style: str | None,
role: str | None, role: str | None,
tool_name: str | None, tool_name: str | None,
sort_key: Any,
) -> LanguageRow | None: ) -> LanguageRow | None:
if not rows: if not rows:
return None return None
@@ -410,11 +451,15 @@ def _select_exact(
raise ValueError( raise ValueError(
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate." f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
) )
return _sort_rows(rows)[0] return sorted(rows, key=sort_key)[0]
def _sort_rows(rows: Sequence[LanguageRow]) -> list[LanguageRow]: def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
return sorted(rows, key=lambda row: (_timestamp(row), row.get("style") or "", row.get("role") or "")) return (_timestamp(row), row.get("style") or "", row.get("role") or "")
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
return (row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float: def _timestamp(row: LanguageRow) -> float:
-2
View File
@@ -167,7 +167,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
""" """
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {} task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
@@ -187,7 +186,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
return { return {
**pad_keys, **pad_keys,
**task_key, **task_key,
**subtask_key,
**index_key, **index_key,
**task_index_key, **task_index_key,
**episode_index_key, **episode_index_key,
@@ -31,10 +31,19 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
@dataclass @dataclass
@ProcessorStepRegistry.register(name="render_messages_processor") @ProcessorStepRegistry.register(name="render_messages_processor")
class RenderMessagesStep(ProcessorStep): class RenderMessagesStep(ProcessorStep):
"""Processor step that turns raw language columns into rendered chat messages.
Reads ``language_persistent`` and ``language_events`` from the transition's
complementary data, renders them through ``recipe`` at the sample timestamp,
and replaces the raw columns with the resulting ``messages`` /
``message_streams`` / ``target_message_indices`` keys.
"""
recipe: TrainingRecipe recipe: TrainingRecipe
dataset_ctx: Any | None = None dataset_ctx: Any | None = None
def __call__(self, transition: EnvTransition) -> EnvTransition | None: def __call__(self, transition: EnvTransition) -> EnvTransition | None:
"""Render messages for a single transition; return ``None`` to drop it."""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or [] persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
events = complementary_data.get(LANGUAGE_EVENTS) or [] events = complementary_data.get(LANGUAGE_EVENTS) or []
+6
View File
@@ -26,6 +26,12 @@ _PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None: def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
"""Collate function that preserves Python-list and language fields as lists.
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
rendered-message and language fields as plain Python lists, and delegates
every other key to PyTorch's ``default_collate``.
"""
batch = [sample for sample in batch if sample is not None] batch = [sample for sample in batch if sample is not None]
if not batch: if not batch:
return None return None
+7 -5
View File
@@ -22,11 +22,14 @@ from lerobot.datasets.utils import DEFAULT_DATA_PATH
def test_language_arrow_schema_has_expected_fields(): def test_language_arrow_schema_has_expected_fields():
row_type = language_persistent_arrow_type().value_type persistent_row_type = language_persistent_arrow_type().value_type
event_row_type = language_events_arrow_type().value_type
assert isinstance(row_type, pa.StructType) assert isinstance(persistent_row_type, pa.StructType)
assert row_type.names == ["role", "content", "style", "timestamp", "tool_calls"] assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
assert language_events_arrow_type().value_type == row_type
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "tool_calls"]
def test_style_registry_routes_columns(): def test_style_registry_routes_columns():
@@ -72,7 +75,6 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot
"role": "user", "role": "user",
"content": "what is visible?", "content": "what is visible?",
"style": "vqa", "style": "vqa",
"timestamp": 0.0,
"tool_calls": None, "tool_calls": None,
} }
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
+30 -18
View File
@@ -8,7 +8,7 @@ from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
def row(role, content, style, timestamp, tool_calls=None): def persistent_row(role, content, style, timestamp, tool_calls=None):
return { return {
"role": role, "role": role,
"content": content, "content": content,
@@ -18,22 +18,32 @@ def row(role, content, style, timestamp, tool_calls=None):
} }
def event_row(role, content, style, tool_calls=None):
return {
"role": role,
"content": content,
"style": style,
"tool_calls": tool_calls,
}
PERSISTENT = [ PERSISTENT = [
row("assistant", "plan 0", "plan", 0.0), persistent_row("assistant", "plan 0", "plan", 0.0),
row("assistant", "memory 0", "memory", 0.0), persistent_row("assistant", "memory 0", "memory", 0.0),
row("assistant", "subtask 0", "subtask", 0.0), persistent_row("assistant", "subtask 0", "subtask", 0.0),
row("assistant", "memory 1", "memory", 1.0), persistent_row("assistant", "memory 1", "memory", 1.0),
row("assistant", "subtask 1", "subtask", 1.0), persistent_row("assistant", "subtask 1", "subtask", 1.0),
] ]
EVENTS = [ EVENTS_AT_1 = [
row("user", "what is visible?", "vqa", 1.0), event_row("user", "what is visible?", "vqa"),
row("assistant", '{"count": 2}', "vqa", 1.0), event_row("assistant", '{"count": 2}', "vqa"),
row("user", "skip wiping", "interjection", 2.0), ]
row( EVENTS_AT_2 = [
event_row("user", "skip wiping", "interjection"),
event_row(
"assistant", "assistant",
None, None,
None, None,
2.0,
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}], [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
), ),
] ]
@@ -42,9 +52,9 @@ EVENTS = [
def test_resolver_temporal_semantics(): def test_resolver_temporal_semantics():
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0" assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1" assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
assert emitted_at(0.5, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant") is None assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
assert ( assert (
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")["content"] emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
== '{"count": 2}' == '{"count": 2}'
) )
@@ -87,7 +97,7 @@ def test_substitution_if_present_multimodal_and_tool_calls():
rendered = render_sample( rendered = render_sample(
recipe=recipe, recipe=recipe,
persistent=PERSISTENT, persistent=PERSISTENT,
events=EVENTS, events=EVENTS_AT_2,
t=2.0, t=2.0,
sample_idx=0, sample_idx=0,
task="clean kitchen", task="clean kitchen",
@@ -114,7 +124,9 @@ def test_exact_event_miss_returns_none_when_target_skips():
] ]
) )
assert render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=0) is None assert (
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
)
def test_deterministic_blend_sampling(): def test_deterministic_blend_sampling():
@@ -138,10 +150,10 @@ def test_deterministic_blend_sampling():
) )
first = render_sample( first = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x" recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
) )
second = render_sample( second = render_sample(
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x" recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
) )
assert first == second assert first == second