mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Address review: split persistent/event schemas, drop event timestamps
- recipe.py: derive _VALID_ROLES/_VALID_STREAMS from MessageRole/MessageStream Literals - dataset_metadata.py: keep CODEBASE_VERSION at v3.0 - language.py: remove RESERVED_STYLES; split arrow/feature schemas into persistent (with timestamp) and event (without timestamp); add docstrings - language_render.py: events use frame-row timestamp implicitly; no per-event timestamp filtering or sorting - converters.py: drop unused subtask_key passthrough - add docstrings to new public APIs (recipe, render_messages_processor, collate) - update tests for split schemas; revert uv.lock Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
MessageRole = Literal["user", "assistant", "system", "tool"]
|
||||
MessageStream = Literal["high_level", "low_level"]
|
||||
@@ -35,12 +35,21 @@ DEFAULT_BINDINGS = {
|
||||
}
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
_VALID_ROLES = {"user", "assistant", "system", "tool"}
|
||||
_VALID_STREAMS = {"high_level", "low_level"}
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageTurn:
|
||||
"""A single chat-style turn in a recipe template.
|
||||
|
||||
``content`` may be a plain string, a list of HF-style multimodal blocks, or
|
||||
``None`` when ``tool_calls_from`` supplies tool-call payloads instead.
|
||||
``stream`` tags the turn for downstream filtering, ``target`` flags it as a
|
||||
training target, and ``if_present`` skips the turn when the named binding
|
||||
resolves to ``None``.
|
||||
"""
|
||||
|
||||
role: MessageRole
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
stream: MessageStream | None = None
|
||||
@@ -71,6 +80,13 @@ class MessageTurn:
|
||||
|
||||
@dataclass
|
||||
class TrainingRecipe:
|
||||
"""A recipe describing how to render training samples from language rows.
|
||||
|
||||
A recipe is either a *message recipe* (``messages`` plus optional
|
||||
``bindings``) or a *blend recipe* (``blend`` mapping names to weighted
|
||||
sub-recipes). ``weight`` is only meaningful inside a blend.
|
||||
"""
|
||||
|
||||
messages: list[MessageTurn] | None = None
|
||||
bindings: dict[str, str] | None = None
|
||||
blend: dict[str, TrainingRecipe] | None = None
|
||||
@@ -164,4 +180,5 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
|
||||
|
||||
|
||||
def load_recipe(path: str | Path) -> TrainingRecipe:
|
||||
"""Load a :class:`TrainingRecipe` from a YAML file at ``path``."""
|
||||
return TrainingRecipe.from_yaml(path)
|
||||
|
||||
@@ -51,7 +51,7 @@ from .utils import (
|
||||
)
|
||||
from .video_utils import get_video_info
|
||||
|
||||
CODEBASE_VERSION = "v3.1"
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
|
||||
@@ -22,7 +22,12 @@ from PIL import Image as PILImage
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from lerobot.utils.utils import is_valid_numpy_dtype_string
|
||||
|
||||
from .language import is_language_column, language_column_feature
|
||||
from .language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
is_language_column,
|
||||
language_events_column_feature,
|
||||
language_persistent_column_feature,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
@@ -47,7 +52,11 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if is_language_column(key):
|
||||
hf_features[key] = language_column_feature()
|
||||
hf_features[key] = (
|
||||
language_persistent_column_feature()
|
||||
if key == LANGUAGE_PERSISTENT
|
||||
else language_events_column_feature()
|
||||
)
|
||||
elif ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
|
||||
@@ -24,12 +24,12 @@ import pyarrow as pa
|
||||
LANGUAGE_PERSISTENT = "language_persistent"
|
||||
LANGUAGE_EVENTS = "language_events"
|
||||
LANGUAGE_COLUMNS = (LANGUAGE_PERSISTENT, LANGUAGE_EVENTS)
|
||||
LANGUAGE_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
|
||||
PERSISTENT_ROW_FIELDS = ("role", "content", "style", "timestamp", "tool_calls")
|
||||
EVENT_ROW_FIELDS = ("role", "content", "style", "tool_calls")
|
||||
|
||||
CORE_STYLES = {"subtask", "plan", "memory", "interjection", "vqa"}
|
||||
EXTENDED_STYLES = set()
|
||||
RESERVED_STYLES = {"motion", "trace"}
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES | RESERVED_STYLES
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory"}
|
||||
EVENT_ONLY_STYLES = {"interjection", "vqa"}
|
||||
@@ -37,43 +37,90 @@ EVENT_ONLY_STYLES = {"interjection", "vqa"}
|
||||
LanguageColumn = Literal["language_persistent", "language_events"]
|
||||
|
||||
|
||||
def language_row_arrow_type() -> pa.StructType:
|
||||
json_type = pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
def _json_arrow_type() -> pa.DataType:
|
||||
return pa.json_() if hasattr(pa, "json_") else pa.string()
|
||||
|
||||
|
||||
def _json_feature() -> object:
|
||||
return datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
|
||||
|
||||
def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single persistent language row.
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("tool_calls", pa.list_(json_type), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_event_row_arrow_type() -> pa.StructType:
|
||||
"""Return the Arrow struct type for a single event language row.
|
||||
|
||||
Event rows have no ``timestamp`` field: each event is stored on the dataset
|
||||
row whose frame timestamp is the event's firing time.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def language_persistent_arrow_type() -> pa.ListType:
|
||||
return pa.list_(language_row_arrow_type())
|
||||
"""Return the Arrow list type for the ``language_persistent`` column."""
|
||||
return pa.list_(language_persistent_row_arrow_type())
|
||||
|
||||
|
||||
def language_events_arrow_type() -> pa.ListType:
|
||||
return pa.list_(language_row_arrow_type())
|
||||
"""Return the Arrow list type for the ``language_events`` column."""
|
||||
return pa.list_(language_event_row_arrow_type())
|
||||
|
||||
|
||||
def language_row_feature() -> dict[str, object]:
|
||||
json_feature = datasets.Json() if hasattr(datasets, "Json") else datasets.Value("string")
|
||||
def language_persistent_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for a persistent language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"tool_calls": datasets.List(json_feature),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_column_feature() -> datasets.List:
|
||||
return datasets.List(language_row_feature())
|
||||
def language_event_row_feature() -> dict[str, object]:
|
||||
"""Return the HF ``datasets`` feature mapping for an event language row."""
|
||||
return {
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
|
||||
def language_persistent_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_persistent`` column."""
|
||||
return datasets.List(language_persistent_row_feature())
|
||||
|
||||
|
||||
def language_events_column_feature() -> datasets.List:
|
||||
"""Return the HF ``datasets`` feature for the ``language_events`` column."""
|
||||
return datasets.List(language_event_row_feature())
|
||||
|
||||
|
||||
def language_feature_info() -> dict[str, dict]:
|
||||
"""Return the ``info["features"]`` entries for both language columns."""
|
||||
return {
|
||||
LANGUAGE_PERSISTENT: {"dtype": "language", "shape": (1,), "names": None},
|
||||
LANGUAGE_EVENTS: {"dtype": "language", "shape": (1,), "names": None},
|
||||
@@ -81,16 +128,21 @@ def language_feature_info() -> dict[str, dict]:
|
||||
|
||||
|
||||
def is_language_column(key: str) -> bool:
|
||||
"""Return ``True`` if ``key`` is one of the dataset's language column names."""
|
||||
return key in LANGUAGE_COLUMNS
|
||||
|
||||
|
||||
def column_for_style(style: str | None) -> LanguageColumn:
|
||||
"""Map a language style to the column where rows of that style are stored.
|
||||
|
||||
Styles in :data:`PERSISTENT_STYLES` route to :data:`LANGUAGE_PERSISTENT`.
|
||||
Styles in :data:`EVENT_ONLY_STYLES` and the implicit ``None`` style route
|
||||
to :data:`LANGUAGE_EVENTS`.
|
||||
"""
|
||||
if style is None:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in PERSISTENT_STYLES:
|
||||
return LANGUAGE_PERSISTENT
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
return LANGUAGE_EVENTS
|
||||
if style in RESERVED_STYLES:
|
||||
raise ValueError(f"Style {style!r} is registered but has no storage column yet.")
|
||||
raise ValueError(f"Unknown language style: {style!r}")
|
||||
|
||||
@@ -47,6 +47,13 @@ def active_at(
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row of ``style`` that is active at time ``t``.
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``
|
||||
selector. ``events`` is accepted for resolver-signature uniformity but is
|
||||
not consulted: only persistent styles are valid here.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
|
||||
matches = [row for row in matches if _timestamp(row) <= t]
|
||||
@@ -62,14 +69,25 @@ def emitted_at(
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row of ``style`` emitted at exactly time ``t``.
|
||||
|
||||
For persistent styles, this matches persistent rows whose own ``timestamp``
|
||||
equals ``t``. For event styles, the ``events`` list is assumed to come from
|
||||
the dataset row at frame ``t`` (event rows carry no timestamp of their own),
|
||||
so all matching event rows are considered emitted at ``t``.
|
||||
"""
|
||||
column = column_for_style(style)
|
||||
rows = persistent if column == LANGUAGE_PERSISTENT else events
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(rows, style=style, role=role, tool_name=tool_name)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_exact(matches, style=style, role=role, tool_name=tool_name)
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_one(
|
||||
matches, style=style, role=role, tool_name=tool_name, sort_key=_persistent_sort_key
|
||||
)
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, sort_key=_event_sort_key)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
@@ -82,6 +100,12 @@ def nth_prev(
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that was active ``offset`` steps before ``t``.
|
||||
|
||||
Walks back through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
|
||||
positions before the row active at ``t``. Only valid for persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
@@ -103,6 +127,12 @@ def nth_next(
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the persistent row that becomes active ``offset`` steps after ``t``.
|
||||
|
||||
Walks forward through chronologically sorted persistent rows of ``style``
|
||||
(filtered by optional ``role``/``tool_name``) and returns the one ``offset``
|
||||
positions after the row active at ``t``. Only valid for persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
@@ -124,6 +154,12 @@ def render_sample(
|
||||
task: str | None = None,
|
||||
dataset_ctx: Any | None = None,
|
||||
) -> RenderedMessages | None:
|
||||
"""Render the chat-style messages for a single dataset sample.
|
||||
|
||||
Resolves the recipe's bindings against ``persistent`` and ``events`` rows
|
||||
at frame timestamp ``t``, then expands the recipe's message templates.
|
||||
Returns ``None`` if the resolved sample contains no target message.
|
||||
"""
|
||||
persistent_rows = _normalize_rows(persistent or [])
|
||||
event_rows = _normalize_rows(events or [])
|
||||
selected_recipe = _select_recipe(recipe, sample_idx)
|
||||
@@ -335,7 +371,10 @@ def _nth_relative(
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
||||
|
||||
rows = _sort_rows(_matching_rows(persistent, style=style, role=role, tool_name=tool_name))
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name),
|
||||
key=_persistent_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
@@ -387,22 +426,24 @@ def _select_latest(
|
||||
) -> LanguageRow | None:
|
||||
if not rows:
|
||||
return None
|
||||
rows = _sort_rows(rows)
|
||||
rows = sorted(rows, key=_persistent_sort_key)
|
||||
latest_ts = _timestamp(rows[-1])
|
||||
return _select_exact(
|
||||
return _select_one(
|
||||
[row for row in rows if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
|
||||
|
||||
def _select_exact(
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
sort_key: Any,
|
||||
) -> LanguageRow | None:
|
||||
if not rows:
|
||||
return None
|
||||
@@ -410,11 +451,15 @@ def _select_exact(
|
||||
raise ValueError(
|
||||
f"Ambiguous resolver for style={style!r}; add role=... or tool_name=... to disambiguate."
|
||||
)
|
||||
return _sort_rows(rows)[0]
|
||||
return sorted(rows, key=sort_key)[0]
|
||||
|
||||
|
||||
def _sort_rows(rows: Sequence[LanguageRow]) -> list[LanguageRow]:
|
||||
return sorted(rows, key=lambda row: (_timestamp(row), row.get("style") or "", row.get("role") or ""))
|
||||
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
|
||||
return (row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
|
||||
@@ -167,7 +167,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
@@ -187,7 +186,6 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
**pad_keys,
|
||||
**task_key,
|
||||
**subtask_key,
|
||||
**index_key,
|
||||
**task_index_key,
|
||||
**episode_index_key,
|
||||
|
||||
@@ -31,10 +31,19 @@ from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="render_messages_processor")
|
||||
class RenderMessagesStep(ProcessorStep):
|
||||
"""Processor step that turns raw language columns into rendered chat messages.
|
||||
|
||||
Reads ``language_persistent`` and ``language_events`` from the transition's
|
||||
complementary data, renders them through ``recipe`` at the sample timestamp,
|
||||
and replaces the raw columns with the resulting ``messages`` /
|
||||
``message_streams`` / ``target_message_indices`` keys.
|
||||
"""
|
||||
|
||||
recipe: TrainingRecipe
|
||||
dataset_ctx: Any | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition | None:
|
||||
"""Render messages for a single transition; return ``None`` to drop it."""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
persistent = complementary_data.get(LANGUAGE_PERSISTENT) or []
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
@@ -26,6 +26,12 @@ _PYTHON_LIST_KEYS = {"messages", "message_streams", "target_message_indices"}
|
||||
|
||||
|
||||
def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | None:
|
||||
"""Collate function that preserves Python-list and language fields as lists.
|
||||
|
||||
Drops ``None`` samples (e.g. recipes that yielded no target message), keeps
|
||||
rendered-message and language fields as plain Python lists, and delegates
|
||||
every other key to PyTorch's ``default_collate``.
|
||||
"""
|
||||
batch = [sample for sample in batch if sample is not None]
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
@@ -22,11 +22,14 @@ from lerobot.datasets.utils import DEFAULT_DATA_PATH
|
||||
|
||||
|
||||
def test_language_arrow_schema_has_expected_fields():
|
||||
row_type = language_persistent_arrow_type().value_type
|
||||
persistent_row_type = language_persistent_arrow_type().value_type
|
||||
event_row_type = language_events_arrow_type().value_type
|
||||
|
||||
assert isinstance(row_type, pa.StructType)
|
||||
assert row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
|
||||
assert language_events_arrow_type().value_type == row_type
|
||||
assert isinstance(persistent_row_type, pa.StructType)
|
||||
assert persistent_row_type.names == ["role", "content", "style", "timestamp", "tool_calls"]
|
||||
|
||||
assert isinstance(event_row_type, pa.StructType)
|
||||
assert event_row_type.names == ["role", "content", "style", "tool_calls"]
|
||||
|
||||
|
||||
def test_style_registry_routes_columns():
|
||||
@@ -72,7 +75,6 @@ def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot
|
||||
"role": "user",
|
||||
"content": "what is visible?",
|
||||
"style": "vqa",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
|
||||
|
||||
@@ -8,7 +8,7 @@ from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
|
||||
|
||||
|
||||
def row(role, content, style, timestamp, tool_calls=None):
|
||||
def persistent_row(role, content, style, timestamp, tool_calls=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
@@ -18,22 +18,32 @@ def row(role, content, style, timestamp, tool_calls=None):
|
||||
}
|
||||
|
||||
|
||||
def event_row(role, content, style, tool_calls=None):
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"style": style,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
|
||||
|
||||
PERSISTENT = [
|
||||
row("assistant", "plan 0", "plan", 0.0),
|
||||
row("assistant", "memory 0", "memory", 0.0),
|
||||
row("assistant", "subtask 0", "subtask", 0.0),
|
||||
row("assistant", "memory 1", "memory", 1.0),
|
||||
row("assistant", "subtask 1", "subtask", 1.0),
|
||||
persistent_row("assistant", "plan 0", "plan", 0.0),
|
||||
persistent_row("assistant", "memory 0", "memory", 0.0),
|
||||
persistent_row("assistant", "subtask 0", "subtask", 0.0),
|
||||
persistent_row("assistant", "memory 1", "memory", 1.0),
|
||||
persistent_row("assistant", "subtask 1", "subtask", 1.0),
|
||||
]
|
||||
EVENTS = [
|
||||
row("user", "what is visible?", "vqa", 1.0),
|
||||
row("assistant", '{"count": 2}', "vqa", 1.0),
|
||||
row("user", "skip wiping", "interjection", 2.0),
|
||||
row(
|
||||
EVENTS_AT_1 = [
|
||||
event_row("user", "what is visible?", "vqa"),
|
||||
event_row("assistant", '{"count": 2}', "vqa"),
|
||||
]
|
||||
EVENTS_AT_2 = [
|
||||
event_row("user", "skip wiping", "interjection"),
|
||||
event_row(
|
||||
"assistant",
|
||||
None,
|
||||
None,
|
||||
2.0,
|
||||
[{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
|
||||
),
|
||||
]
|
||||
@@ -42,9 +52,9 @@ EVENTS = [
|
||||
def test_resolver_temporal_semantics():
|
||||
assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
|
||||
assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
|
||||
assert emitted_at(0.5, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant") is None
|
||||
assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
|
||||
assert (
|
||||
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")["content"]
|
||||
emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
|
||||
== '{"count": 2}'
|
||||
)
|
||||
|
||||
@@ -87,7 +97,7 @@ def test_substitution_if_present_multimodal_and_tool_calls():
|
||||
rendered = render_sample(
|
||||
recipe=recipe,
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS,
|
||||
events=EVENTS_AT_2,
|
||||
t=2.0,
|
||||
sample_idx=0,
|
||||
task="clean kitchen",
|
||||
@@ -114,7 +124,9 @@ def test_exact_event_miss_returns_none_when_target_skips():
|
||||
]
|
||||
)
|
||||
|
||||
assert render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=0) is None
|
||||
assert (
|
||||
render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
|
||||
)
|
||||
|
||||
|
||||
def test_deterministic_blend_sampling():
|
||||
@@ -138,10 +150,10 @@ def test_deterministic_blend_sampling():
|
||||
)
|
||||
|
||||
first = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x"
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
second = render_sample(
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x"
|
||||
recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
|
||||
)
|
||||
assert first == second
|
||||
|
||||
|
||||
Reference in New Issue
Block a user