mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor(language): unify resolver dispatch and prune redundant test scaffolding
* Drop the unused `events` kwarg from `active_at`/`nth_prev`/`nth_next`; only `emitted_at` actually consults events. The dispatcher in `_resolve_spec` now passes events conditionally.
* Replace the dual `_persistent_sort_key`/`_event_sort_key` pair with a single `_row_sort_key` and drop the `sort_key` parameter from `_select_one`. Event rows lack `timestamp` (it is implicit in the frame) and now default to `0.0` for sort purposes — the `(style, role)` tiebreaker is unchanged.
* Inline `_select_latest` into `active_at` (its only caller).
* Collapse `emitted_at`'s dual-branch into one `_select_one` call.
* Tighten `_validate_persistent_resolver` to a single `column_for_style(style) != LANGUAGE_PERSISTENT` check.
* Parameterize `test_per_camera_blend_renders_both_views` over the two cameras and factor the sub-recipe builder into `_vqa_subrecipe` so the test no longer hand-rolls two near-identical recipe blocks.

Net -98 LOC; behavior, public resolver names, and test expectations unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -24,12 +24,7 @@ from typing import Any
|
|||||||
|
|
||||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||||
|
|
||||||
from .language import (
|
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||||
EVENT_ONLY_STYLES,
|
|
||||||
LANGUAGE_PERSISTENT,
|
|
||||||
PERSISTENT_STYLES,
|
|
||||||
column_for_style,
|
|
||||||
)
|
|
||||||
|
|
||||||
LanguageRow = dict[str, Any]
|
LanguageRow = dict[str, Any]
|
||||||
RenderedMessages = dict[str, list[Any]]
|
RenderedMessages = dict[str, list[Any]]
|
||||||
@@ -42,7 +37,6 @@ def active_at(
|
|||||||
t: float,
|
t: float,
|
||||||
*,
|
*,
|
||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
events: Sequence[LanguageRow] | None = None,
|
|
||||||
style: str | None = None,
|
style: str | None = None,
|
||||||
role: str | None = None,
|
role: str | None = None,
|
||||||
tool_name: str | None = None,
|
tool_name: str | None = None,
|
||||||
@@ -52,16 +46,23 @@ def active_at(
|
|||||||
|
|
||||||
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
A persistent row is "active" at ``t`` when its own ``timestamp`` is the
|
||||||
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
most recent one ``<= t`` for the given ``style``/``role``/``tool_name``/
|
||||||
``camera`` selector. ``events`` is accepted for resolver-signature
|
``camera`` selector. Only valid for persistent styles.
|
||||||
uniformity but is not consulted: only persistent styles are valid here.
|
|
||||||
"""
|
"""
|
||||||
_validate_persistent_resolver("active_at", style)
|
_validate_persistent_resolver("active_at", style)
|
||||||
matches = _matching_rows(
|
matches = [
|
||||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
row
|
||||||
)
|
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
matches = [row for row in matches if _timestamp(row) <= t]
|
if _timestamp(row) <= t
|
||||||
return _select_latest(
|
]
|
||||||
matches, style=style, role=role, tool_name=tool_name, camera=camera
|
if not matches:
|
||||||
|
return None
|
||||||
|
latest_ts = max(_timestamp(row) for row in matches)
|
||||||
|
return _select_one(
|
||||||
|
[row for row in matches if _timestamp(row) == latest_ts],
|
||||||
|
style=style,
|
||||||
|
role=role,
|
||||||
|
tool_name=tool_name,
|
||||||
|
camera=camera,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,41 +85,21 @@ def emitted_at(
|
|||||||
filters by the row's ``camera`` field — required to disambiguate when
|
filters by the row's ``camera`` field — required to disambiguate when
|
||||||
multiple view-dependent rows share ``(t, role)`` across cameras.
|
multiple view-dependent rows share ``(t, role)`` across cameras.
|
||||||
"""
|
"""
|
||||||
column = column_for_style(style)
|
if column_for_style(style) == LANGUAGE_PERSISTENT:
|
||||||
if column == LANGUAGE_PERSISTENT:
|
|
||||||
matches = [
|
matches = [
|
||||||
row
|
row
|
||||||
for row in _matching_rows(
|
for row in _matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
persistent, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
if _timestamp(row) == t
|
if _timestamp(row) == t
|
||||||
]
|
]
|
||||||
return _select_one(
|
else:
|
||||||
matches,
|
matches = _matching_rows(events, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
style=style,
|
return _select_one(matches, style=style, role=role, tool_name=tool_name, camera=camera)
|
||||||
role=role,
|
|
||||||
tool_name=tool_name,
|
|
||||||
camera=camera,
|
|
||||||
sort_key=_persistent_sort_key,
|
|
||||||
)
|
|
||||||
matches = _matching_rows(
|
|
||||||
events, style=style, role=role, tool_name=tool_name, camera=camera
|
|
||||||
)
|
|
||||||
return _select_one(
|
|
||||||
matches,
|
|
||||||
style=style,
|
|
||||||
role=role,
|
|
||||||
tool_name=tool_name,
|
|
||||||
camera=camera,
|
|
||||||
sort_key=_event_sort_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def nth_prev(
|
def nth_prev(
|
||||||
t: float,
|
t: float,
|
||||||
*,
|
*,
|
||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
events: Sequence[LanguageRow] | None = None,
|
|
||||||
style: str | None = None,
|
style: str | None = None,
|
||||||
offset: int = 1,
|
offset: int = 1,
|
||||||
role: str | None = None,
|
role: str | None = None,
|
||||||
@@ -132,23 +113,13 @@ def nth_prev(
|
|||||||
one ``offset`` positions before the row active at ``t``. Only valid for
|
one ``offset`` positions before the row active at ``t``. Only valid for
|
||||||
persistent styles.
|
persistent styles.
|
||||||
"""
|
"""
|
||||||
return _nth_relative(
|
return _nth_relative("nth_prev", t, persistent, style, -offset, role, tool_name, camera)
|
||||||
t,
|
|
||||||
persistent=persistent,
|
|
||||||
style=style,
|
|
||||||
offset=-offset,
|
|
||||||
role=role,
|
|
||||||
tool_name=tool_name,
|
|
||||||
camera=camera,
|
|
||||||
resolver_name="nth_prev",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def nth_next(
|
def nth_next(
|
||||||
t: float,
|
t: float,
|
||||||
*,
|
*,
|
||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
events: Sequence[LanguageRow] | None = None,
|
|
||||||
style: str | None = None,
|
style: str | None = None,
|
||||||
offset: int = 1,
|
offset: int = 1,
|
||||||
role: str | None = None,
|
role: str | None = None,
|
||||||
@@ -162,16 +133,7 @@ def nth_next(
|
|||||||
one ``offset`` positions after the row active at ``t``. Only valid for
|
one ``offset`` positions after the row active at ``t``. Only valid for
|
||||||
persistent styles.
|
persistent styles.
|
||||||
"""
|
"""
|
||||||
return _nth_relative(
|
return _nth_relative("nth_next", t, persistent, style, offset, role, tool_name, camera)
|
||||||
t,
|
|
||||||
persistent=persistent,
|
|
||||||
style=style,
|
|
||||||
offset=offset,
|
|
||||||
role=role,
|
|
||||||
tool_name=tool_name,
|
|
||||||
camera=camera,
|
|
||||||
resolver_name="nth_next",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def render_sample(
|
def render_sample(
|
||||||
@@ -239,9 +201,7 @@ def _resolve_bindings(
|
|||||||
) -> dict[str, LanguageRow | str | None]:
|
) -> dict[str, LanguageRow | str | None]:
|
||||||
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
"""Resolve every binding in ``recipe`` (plus ``task``) at time ``t``."""
|
||||||
bindings: dict[str, LanguageRow | str | None] = {
|
bindings: dict[str, LanguageRow | str | None] = {
|
||||||
"task": _resolve_task(
|
"task": _resolve_task(task, dataset_ctx, persistent=persistent, sample_idx=sample_idx),
|
||||||
task, dataset_ctx, persistent=persistent, sample_idx=sample_idx
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
specs = {**DEFAULT_BINDINGS, **(recipe.bindings or {})}
|
||||||
for name, spec in specs.items():
|
for name, spec in specs.items():
|
||||||
@@ -275,18 +235,12 @@ def _resolve_task(
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
aug_rows = [
|
aug_rows = [r for r in persistent if r.get("style") == "task_aug" and r.get("role") == "user"]
|
||||||
r
|
|
||||||
for r in persistent
|
|
||||||
if r.get("style") == "task_aug" and r.get("role") == "user"
|
|
||||||
]
|
|
||||||
if aug_rows:
|
if aug_rows:
|
||||||
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
# Deterministic, blake2b-based pick keyed on sample_idx so the
|
||||||
# rotation is reproducible across runs (Python's built-in ``hash``
|
# rotation is reproducible across runs (Python's built-in ``hash``
|
||||||
# is process-randomized).
|
# is process-randomized).
|
||||||
digest = hashlib.blake2b(
|
digest = hashlib.blake2b(f"task_aug:{sample_idx}".encode(), digest_size=8).digest()
|
||||||
f"task_aug:{sample_idx}".encode(), digest_size=8
|
|
||||||
).digest()
|
|
||||||
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
idx = int.from_bytes(digest, "big") % len(aug_rows)
|
||||||
chosen = aug_rows[idx].get("content")
|
chosen = aug_rows[idx].get("content")
|
||||||
if chosen:
|
if chosen:
|
||||||
@@ -314,15 +268,15 @@ def _resolve_spec(
|
|||||||
kwargs = _parse_resolver_args(match.group("args"))
|
kwargs = _parse_resolver_args(match.group("args"))
|
||||||
kwargs.pop("t_arg", None)
|
kwargs.pop("t_arg", None)
|
||||||
|
|
||||||
resolvers = {
|
if name == "emitted_at":
|
||||||
"active_at": active_at,
|
return emitted_at(t, persistent=persistent, events=events, **kwargs)
|
||||||
"emitted_at": emitted_at,
|
if name == "active_at":
|
||||||
"nth_prev": nth_prev,
|
return active_at(t, persistent=persistent, **kwargs)
|
||||||
"nth_next": nth_next,
|
if name == "nth_prev":
|
||||||
}
|
return nth_prev(t, persistent=persistent, **kwargs)
|
||||||
if name not in resolvers:
|
if name == "nth_next":
|
||||||
raise ValueError(f"Unknown language resolver: {name!r}")
|
return nth_next(t, persistent=persistent, **kwargs)
|
||||||
return resolvers[name](t, persistent=persistent, events=events, **kwargs)
|
raise ValueError(f"Unknown language resolver: {name!r}")
|
||||||
|
|
||||||
|
|
||||||
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
def _parse_resolver_args(args: str) -> dict[str, Any]:
|
||||||
@@ -444,24 +398,23 @@ def _validate_rendered(rendered: RenderedMessages) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _nth_relative(
|
def _nth_relative(
|
||||||
|
name: str,
|
||||||
t: float,
|
t: float,
|
||||||
*,
|
|
||||||
persistent: Sequence[LanguageRow],
|
persistent: Sequence[LanguageRow],
|
||||||
style: str | None,
|
style: str | None,
|
||||||
offset: int,
|
offset: int,
|
||||||
role: str | None,
|
role: str | None,
|
||||||
tool_name: str | None,
|
tool_name: str | None,
|
||||||
camera: str | None,
|
camera: str | None,
|
||||||
resolver_name: str,
|
|
||||||
) -> LanguageRow | None:
|
) -> LanguageRow | None:
|
||||||
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
"""Shared body for ``nth_prev`` / ``nth_next`` with signed ``offset``."""
|
||||||
_validate_persistent_resolver(resolver_name, style)
|
_validate_persistent_resolver(name, style)
|
||||||
if abs(offset) < 1:
|
if abs(offset) < 1:
|
||||||
raise ValueError(f"{resolver_name} offset must be non-zero.")
|
raise ValueError(f"{name} offset must be non-zero.")
|
||||||
|
|
||||||
rows = sorted(
|
rows = sorted(
|
||||||
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
_matching_rows(persistent, style=style, role=role, tool_name=tool_name, camera=camera),
|
||||||
key=_persistent_sort_key,
|
key=_row_sort_key,
|
||||||
)
|
)
|
||||||
if not rows:
|
if not rows:
|
||||||
return None
|
return None
|
||||||
@@ -480,14 +433,12 @@ def _nth_relative(
|
|||||||
return rows[target_idx]
|
return rows[target_idx]
|
||||||
|
|
||||||
|
|
||||||
def _validate_persistent_resolver(resolver_name: str, style: str | None) -> None:
|
def _validate_persistent_resolver(name: str, style: str | None) -> None:
|
||||||
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
"""Reject calls with missing or event-only ``style`` for persistent resolvers."""
|
||||||
if style is None:
|
if style is None:
|
||||||
raise ValueError(f"{resolver_name} requires a persistent style.")
|
raise ValueError(f"{name} requires a persistent style.")
|
||||||
if style in EVENT_ONLY_STYLES:
|
if column_for_style(style) != LANGUAGE_PERSISTENT:
|
||||||
raise ValueError(f"{resolver_name} cannot be used with event-only style {style!r}.")
|
raise ValueError(f"{name} cannot be used with event-only style {style!r}.")
|
||||||
if style not in PERSISTENT_STYLES:
|
|
||||||
column_for_style(style)
|
|
||||||
|
|
||||||
|
|
||||||
def _matching_rows(
|
def _matching_rows(
|
||||||
@@ -509,29 +460,6 @@ def _matching_rows(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _select_latest(
|
|
||||||
rows: Sequence[LanguageRow],
|
|
||||||
*,
|
|
||||||
style: str | None,
|
|
||||||
role: str | None,
|
|
||||||
tool_name: str | None,
|
|
||||||
camera: str | None,
|
|
||||||
) -> LanguageRow | None:
|
|
||||||
"""Return the row tied for the latest ``timestamp`` (disambiguated by selectors)."""
|
|
||||||
if not rows:
|
|
||||||
return None
|
|
||||||
rows = sorted(rows, key=_persistent_sort_key)
|
|
||||||
latest_ts = _timestamp(rows[-1])
|
|
||||||
return _select_one(
|
|
||||||
[row for row in rows if _timestamp(row) == latest_ts],
|
|
||||||
style=style,
|
|
||||||
role=role,
|
|
||||||
tool_name=tool_name,
|
|
||||||
camera=camera,
|
|
||||||
sort_key=_persistent_sort_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _select_one(
|
def _select_one(
|
||||||
rows: Sequence[LanguageRow],
|
rows: Sequence[LanguageRow],
|
||||||
*,
|
*,
|
||||||
@@ -539,9 +467,13 @@ def _select_one(
|
|||||||
role: str | None,
|
role: str | None,
|
||||||
tool_name: str | None,
|
tool_name: str | None,
|
||||||
camera: str | None,
|
camera: str | None,
|
||||||
sort_key: Any,
|
|
||||||
) -> LanguageRow | None:
|
) -> LanguageRow | None:
|
||||||
"""Return the single matching row, or raise if the selectors are ambiguous."""
|
"""Return the single matching row, or raise if the selectors are ambiguous.
|
||||||
|
|
||||||
|
Ties are broken deterministically by ``_row_sort_key`` so that
|
||||||
|
multiple rows with identical ``(style, role, tool_name, camera)`` still
|
||||||
|
resolve to a stable choice.
|
||||||
|
"""
|
||||||
if not rows:
|
if not rows:
|
||||||
return None
|
return None
|
||||||
if len(rows) > 1 and role is None and tool_name is None and camera is None:
|
if len(rows) > 1 and role is None and tool_name is None and camera is None:
|
||||||
@@ -549,17 +481,21 @@ def _select_one(
|
|||||||
f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
|
f"Ambiguous resolver for style={style!r}; add role=..., tool_name=..., "
|
||||||
f"or camera=... to disambiguate."
|
f"or camera=... to disambiguate."
|
||||||
)
|
)
|
||||||
return sorted(rows, key=sort_key)[0]
|
return sorted(rows, key=_row_sort_key)[0]
|
||||||
|
|
||||||
|
|
||||||
def _persistent_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||||
"""Sort key for persistent rows: ``(timestamp, style, role)``."""
|
"""Stable sort key for both persistent and event rows.
|
||||||
return (_timestamp(row), row.get("style") or "", row.get("role") or "")
|
|
||||||
|
|
||||||
|
Event rows lack ``timestamp`` (it is implicit in the frame), so default
|
||||||
def _event_sort_key(row: LanguageRow) -> tuple[str, str]:
|
to ``0.0`` — within a single frame all event rows share the same sort
|
||||||
"""Sort key for event rows: ``(style, role)`` (timestamp is implicit in the frame)."""
|
bucket and are tiebroken by ``(style, role)``.
|
||||||
return (row.get("style") or "", row.get("role") or "")
|
"""
|
||||||
|
timestamp = row.get("timestamp")
|
||||||
|
ts = (
|
||||||
|
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
|
||||||
|
)
|
||||||
|
return (ts, row.get("style") or "", row.get("role") or "")
|
||||||
|
|
||||||
|
|
||||||
def _timestamp(row: LanguageRow) -> float:
|
def _timestamp(row: LanguageRow) -> float:
|
||||||
|
|||||||
@@ -199,84 +199,50 @@ def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_per_camera_blend_renders_both_views():
|
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
|
||||||
recipe = TrainingRecipe(
|
return TrainingRecipe(
|
||||||
blend={
|
weight=1.0,
|
||||||
"top": TrainingRecipe(
|
bindings={
|
||||||
weight=1.0,
|
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
|
||||||
bindings={
|
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
|
||||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
},
|
||||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
messages=[
|
||||||
},
|
MessageTurn(
|
||||||
messages=[
|
role="user",
|
||||||
MessageTurn(
|
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
|
||||||
role="user",
|
stream="high_level",
|
||||||
content=[
|
if_present="vqa_query",
|
||||||
{"type": "image", "feature": "observation.images.top"},
|
|
||||||
{"type": "text", "text": "${vqa_query}"},
|
|
||||||
],
|
|
||||||
stream="high_level",
|
|
||||||
if_present="vqa_query",
|
|
||||||
),
|
|
||||||
MessageTurn(
|
|
||||||
role="assistant",
|
|
||||||
content="${vqa}",
|
|
||||||
stream="high_level",
|
|
||||||
target=True,
|
|
||||||
if_present="vqa",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
"wrist": TrainingRecipe(
|
MessageTurn(
|
||||||
weight=1.0,
|
role="assistant",
|
||||||
bindings={
|
content="${vqa}",
|
||||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
stream="high_level",
|
||||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
target=True,
|
||||||
},
|
if_present="vqa",
|
||||||
messages=[
|
|
||||||
MessageTurn(
|
|
||||||
role="user",
|
|
||||||
content=[
|
|
||||||
{"type": "image", "feature": "observation.images.wrist"},
|
|
||||||
{"type": "text", "text": "${vqa_query}"},
|
|
||||||
],
|
|
||||||
stream="high_level",
|
|
||||||
if_present="vqa_query",
|
|
||||||
),
|
|
||||||
MessageTurn(
|
|
||||||
role="assistant",
|
|
||||||
content="${vqa}",
|
|
||||||
stream="high_level",
|
|
||||||
target=True,
|
|
||||||
if_present="vqa",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
}
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
rendered_top = render_sample(
|
|
||||||
recipe=recipe.blend["top"],
|
@pytest.mark.parametrize(
|
||||||
persistent=PERSISTENT,
|
("camera", "expected_query", "expected_answer"),
|
||||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
[
|
||||||
t=3.0,
|
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
|
||||||
sample_idx=0,
|
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
|
||||||
)
|
],
|
||||||
rendered_wrist = render_sample(
|
)
|
||||||
recipe=recipe.blend["wrist"],
|
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
|
||||||
|
rendered = render_sample(
|
||||||
|
recipe=_vqa_subrecipe(camera),
|
||||||
persistent=PERSISTENT,
|
persistent=PERSISTENT,
|
||||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||||
t=3.0,
|
t=3.0,
|
||||||
sample_idx=0,
|
sample_idx=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
|
assert rendered["messages"][0]["content"][0]["feature"] == camera
|
||||||
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
|
assert rendered["messages"][0]["content"][1]["text"] == expected_query
|
||||||
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
|
assert rendered["messages"][1]["content"] == expected_answer
|
||||||
|
|
||||||
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
|
|
||||||
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
|
|
||||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||||
|
|||||||
Reference in New Issue
Block a user