mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor(language): unify resolver dispatch and prune redundant test scaffolding
* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally. * Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes — the `(style, role)` tiebreaker is unchanged. * Inline `_select_latest` into `active_at` (its only caller). * Collapse `emitted_at`'s dual-branch into one `_select_one` call. * Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check. * Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks. Net -98 LOC; behavior, public resolver names, and test expectations unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,12 +24,7 @@ from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||
|
||||
from .language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
)
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
@@ -42,7 +37,6 @@ def active_at(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
role: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
@@ -52,16 +46,23 @@ def active_at(
|
||||
|
||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||
``camera`` selector. ``events`` is accepted for resolver-signature
|
||||
uniformity but is not consulted: only persistent styles are valid here.
|
||||
``camera`` selector. Only valid for persistent styles.
|
||||
"""
|
||||
_validate_persistent_resolver("active_at", style)
|
||||
matches = _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
matches = [row for row in matches if _timestamp(row) <= t]
|
||||
return _select_latest(
|
||||
matches, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) <= t
|
||||
]
|
||||
if not matches:
|
||||
return None
|
||||
latest_ts = max(_timestamp(row) for row in matches)
|
||||
return _select_one(
|
||||
[row for row in matches if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,41 +85,21 @@ def emitted_at(
|
||||
filters by the row's ``camera`` field — required to disambiguate when
|
||||
multiple view-dependent rows share ``(t, role)`` across cameras.
|
||||
"""
|
||||
column = column_for_style(style)
|
||||
if column == LANGUAGE_PERSISTENT:
|
||||
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||
matches = [
|
||||
row
|
||||
for row in _matching_rows(
|
||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
if _timestamp(row) == t
|
||||
]
|
||||
return _select_one(
|
||||
matches,
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
matches = _matching_rows(
|
||||
events, style=style, role=role, tool_name=tool_name, camera=camera
|
||||
)
|
||||
return _select_one(
|
||||
matches,
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_event_sort_key,
|
||||
)
|
||||
else:
|
||||
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||
|
||||
|
||||
def nth_prev(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
@@ -132,23 +113,13 @@ def nth_prev(
|
||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=-offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
resolver_name="nth_prev",
|
||||
)
|
||||
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def nth_next(
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
events: Sequence[LanguageRow] | None = None,
|
||||
style: str | None = None,
|
||||
offset: int = 1,
|
||||
role: str | None = None,
|
||||
@@ -162,16 +133,7 @@ def nth_next(
|
||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||
persistent styles.
|
||||
"""
|
||||
return _nth_relative(
|
||||
t,
|
||||
persistent=persistent,
|
||||
style=style,
|
||||
offset=offset,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
resolver_name="nth_next",
|
||||
)
|
||||
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||
|
||||
|
||||
def render_sample(
|
||||
@@ -239,9 +201,7 @@ def _resolve_bindings(
|
||||
) -> dict[str, LanguageRow | str | None]:
|
||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||
bindings: dict[str, LanguageRow | str | None] = {
|
||||
"task": _resolve_task(
|
||||
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
||||
),
|
||||
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||
}
|
||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||
for name, spec in specs.items():
|
||||
@@ -275,18 +235,12 @@ def _resolve_task(
|
||||
if task is not None:
|
||||
return task
|
||||
|
||||
aug_rows = [
|
||||
r
|
||||
for r in persistent
|
||||
if r.get("style") == "task_aug" and r.get("role") == "user"
|
||||
]
|
||||
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||
if aug_rows:
|
||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||
# is process-randomized).
|
||||
digest = hashlib.blake2b(
|
||||
f"task_aug:{sample_idx}".encode(), digest_size=8
|
||||
).digest()
|
||||
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||
chosen = aug_rows[idx].get("content")
|
||||
if chosen:
|
||||
@@ -314,15 +268,15 @@ def _resolve_spec(
|
||||
kwargs = _parse_resolver_args(match.group("args"))
|
||||
kwargs.pop("t_arg", None)
|
||||
|
||||
resolvers = {
|
||||
"active_at": active_at,
|
||||
"emitted_at": emitted_at,
|
||||
"nth_prev": nth_prev,
|
||||
"nth_next": nth_next,
|
||||
}
|
||||
if name not in resolvers:
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "emitted_at":
|
||||
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||
if name == "active_at":
|
||||
return active_at(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_prev":
|
||||
return nth_prev(t, persistent=persistent, **kwargs)
|
||||
if name == "nth_next":
|
||||
return nth_next(t, persistent=persistent, **kwargs)
|
||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||
|
||||
|
||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||
@@ -444,24 +398,23 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
|
||||
def _nth_relative(
|
||||
name: str,
|
||||
t: float,
|
||||
*,
|
||||
persistent: Sequence[LanguageRow],
|
||||
style: str | None,
|
||||
offset: int,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
resolver_name: str,
|
||||
) -> LanguageRow | None:
|
||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||
_validate_persistent_resolver(resolver_name, style)
|
||||
_validate_persistent_resolver(name, style)
|
||||
if abs(offset) < 1:
|
||||
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
||||
raise ValueError(f"{name} offset must be non-zero.")
|
||||
|
||||
rows = sorted(
|
||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||
key=_persistent_sort_key,
|
||||
key=_row_sort_key,
|
||||
)
|
||||
if not rows:
|
||||
return None
|
||||
@@ -480,14 +433,12 @@ def _nth_relative(
|
||||
return rows[target_idx]
|
||||
|
||||
|
||||
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
|
||||
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||
if style is None:
|
||||
raise ValueError(f"{resolver_name} requires a persistent style.")
|
||||
if style in EVENT_ONLY_STYLES:
|
||||
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
|
||||
if style not in PERSISTENT_STYLES:
|
||||
column_for_style(style)
|
||||
raise ValueError(f"{name} requires a persistent style.")
|
||||
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||
|
||||
|
||||
def _matching_rows(
|
||||
@@ -509,29 +460,6 @@ def _matching_rows(
|
||||
]
|
||||
|
||||
|
||||
def _select_latest(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
style: str | None,
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
|
||||
if not rows:
|
||||
return None
|
||||
rows = sorted(rows, key=_persistent_sort_key)
|
||||
latest_ts = _timestamp(rows[-1])
|
||||
return _select_one(
|
||||
[row for row in rows if _timestamp(row) == latest_ts],
|
||||
style=style,
|
||||
role=role,
|
||||
tool_name=tool_name,
|
||||
camera=camera,
|
||||
sort_key=_persistent_sort_key,
|
||||
)
|
||||
|
||||
|
||||
def _select_one(
|
||||
rows: Sequence[LanguageRow],
|
||||
*,
|
||||
@@ -539,9 +467,13 @@ def _select_one(
|
||||
role: str | None,
|
||||
tool_name: str | None,
|
||||
camera: str | None,
|
||||
sort_key: Any,
|
||||
) -> LanguageRow | None:
|
||||
"""Return the single matching row, or raise if the selectors are ambiguous."""
|
||||
"""Return the single matching row, or raise if the selectors are ambiguous.
|
||||
|
||||
Ties are broken deterministically by ``_row_sort_key`` so that
|
||||
multiple rows with identical ``(style, role, tool_name, camera)`` still
|
||||
resolve to a stable choice.
|
||||
"""
|
||||
if not rows:
|
||||
return None
|
||||
if len(rows) > 1 and role is None and tool_name is None and camera is None:
|
||||
@@ -549,17 +481,21 @@ def _select_one(
|
||||
f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
|
||||
f"or camera=... to disambiguate."
|
||||
)
|
||||
return sorted(rows, key=sort_key)[0]
|
||||
return sorted(rows, key=_row_sort_key)[0]
|
||||
|
||||
|
||||
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
|
||||
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
|
||||
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
"""Stable sort key for both persistent and event rows.
|
||||
|
||||
|
||||
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
|
||||
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
|
||||
return (row.get("style") or "", row.get("role") or "")
|
||||
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||
to ``0.0`` — within a single frame all event rows share the same sort
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = (
|
||||
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
|
||||
)
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
|
||||
@@ -199,84 +199,50 @@ def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
|
||||
)
|
||||
|
||||
|
||||
def test_per_camera_blend_renders_both_views():
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"top": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
|
||||
return TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
|
||||
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
"wrist": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.wrist"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
rendered_top = render_sample(
|
||||
recipe=recipe.blend["top"],
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
rendered_wrist = render_sample(
|
||||
recipe=recipe.blend["wrist"],
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("camera", "expected_query", "expected_answer"),
|
||||
[
|
||||
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
|
||||
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
|
||||
],
|
||||
)
|
||||
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
|
||||
rendered = render_sample(
|
||||
recipe=_vqa_subrecipe(camera),
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
|
||||
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
|
||||
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
|
||||
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
|
||||
|
||||
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
|
||||
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
|
||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||
assert rendered["messages"][0]["content"][0]["feature"] == camera
|
||||
assert rendered["messages"][0]["content"][1]["text"] == expected_query
|
||||
assert rendered["messages"][1]["content"] == expected_answer
|
||||
|
||||
|
||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
|
||||
Reference in New Issue
Block a user