mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
review: dedupe regex, centralize column names, harden collate, more tests
* **#2 — dedupe `_PLACEHOLDER_RE`.** The same regex was compiled in
`recipe.py` and `language_render.py`. Promote to module-level
`PLACEHOLDER_RE` in `recipe.py` (its primary owner — declares
template syntax) and import from `language_render.py`.
* **#3 — centralize language column names.** `io_utils.py` had
hardcoded `{"language_persistent", "language_events"}` literals at
two sites. Replace with `LANGUAGE_COLUMNS` import so a future column
rename can't silently desync.
* **#4 — defensive collate preserved-keys.** `lerobot_collate_fn`
silently filtered language fields from samples that didn't have
them, which would hand downstream consumers a preserved list
shorter than the tensor batch. Now: if any sample carries a key,
every sample in the batch must carry it; otherwise raise a
`ValueError` so the upstream rendering bug surfaces at the boundary.
* **#5 — `_scalar` rejects non-singleton lists.** Previously a zero-
or multi-element list fell through and triggered confusing
`float([])` errors downstream. Now raises `ValueError` with the
actual length.
* **#6 — refactor `_extract_complementary_data`.** Replace 11 lines
of `key = {... if ... else {}}` plus an 11-line splat dict with a
single `_COMPLEMENTARY_KEYS` tuple iterated once.
* **#7 — document `EXTENDED_STYLES`.** Was an empty `set()` with no
comment. Add a docstring explaining it's an intentional extension
point: downstream modules append project-local styles before
`column_for_style` is called.
* **#9 — `tools.mdx` notes the runtime layer is future work.** The
page referenced `src/lerobot/tools/`, `registry.py`, and
`get_tools(meta)` — none exist in this PR. Added a callout at the
start of "How to add your own tool" plus a note on the
implementations paragraph.
* **#10 — tests for YAML round-trip, malformed rows, blend
validation.** `test_recipe.py` grew from 1 case to 12 covering:
blend-or-messages exclusivity, target-turn requirement, blend
emptiness, weight presence/positivity, nested-blend rejection,
`from_dict` with nested blends, `from_yaml` / `load_recipe`
agreement, top-level non-mapping rejection. Added a malformed-row
test for `_normalize_rows` that asserts non-dict entries raise
`TypeError`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,7 +34,10 @@ DEFAULT_BINDINGS = {
|
||||
"vqa_query": "emitted_at(t, style=vqa, role=user)",
|
||||
}
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
"""``${name}`` placeholder pattern used by both recipe binding-reference
|
||||
discovery (here) and rendered-message substitution (in ``language_render``)."""
|
||||
|
||||
_VALID_ROLES = frozenset(get_args(MessageRole))
|
||||
_VALID_STREAMS = frozenset(get_args(MessageStream))
|
||||
|
||||
@@ -178,13 +181,13 @@ def _placeholders_in_content(content: str | list[dict[str, Any]] | None) -> set[
|
||||
if content is None:
|
||||
return set()
|
||||
if isinstance(content, str):
|
||||
return set(_PLACEHOLDER_RE.findall(content))
|
||||
return set(PLACEHOLDER_RE.findall(content))
|
||||
|
||||
names: set[str] = set()
|
||||
for block in content:
|
||||
for value in block.values():
|
||||
if isinstance(value, str):
|
||||
names.update(_PLACEHOLDER_RE.findall(value))
|
||||
names.update(PLACEHOLDER_RE.findall(value))
|
||||
return names
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from torchvision import transforms
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.utils import SuppressProgressBars, flatten_dict, unflatten_dict
|
||||
|
||||
from .language import LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
@@ -256,7 +257,7 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
dict: The batch with items converted to torch tensors.
|
||||
"""
|
||||
for key in items_dict:
|
||||
if key in {"language_persistent", "language_events"}:
|
||||
if key in LANGUAGE_COLUMNS:
|
||||
continue
|
||||
first_item = items_dict[key][0]
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
@@ -297,12 +298,9 @@ def item_to_torch(item: dict) -> dict:
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
skip_keys = {"task", *LANGUAGE_COLUMNS}
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray | list)) and key not in [
|
||||
"task",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in skip_keys:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
@@ -37,7 +37,13 @@ CORE_STYLES = {
|
||||
"trace",
|
||||
"task_aug",
|
||||
}
|
||||
EXTENDED_STYLES = set()
|
||||
# Project-local styles can be registered at import time by appending to
|
||||
# ``EXTENDED_STYLES`` before ``column_for_style`` is called. Anything added
|
||||
# here is treated as a known style alongside ``CORE_STYLES`` for resolver
|
||||
# validation. Empty by default — populate from a downstream module that
|
||||
# also extends ``PERSISTENT_STYLES`` or ``EVENT_ONLY_STYLES`` to declare
|
||||
# the new style's column.
|
||||
EXTENDED_STYLES: set[str] = set()
|
||||
STYLE_REGISTRY = CORE_STYLES | EXTENDED_STYLES
|
||||
|
||||
PERSISTENT_STYLES = {"subtask", "plan", "memory", "motion", "task_aug"}
|
||||
|
||||
@@ -22,7 +22,7 @@ import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, TrainingRecipe
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
@@ -30,7 +30,6 @@ LanguageRow = dict[str, Any]
|
||||
RenderedMessages = dict[str, list[Any]]
|
||||
|
||||
_RESOLVER_RE = re.compile(r"^(?P<name>[A-Za-z_][A-Za-z0-9_]*)\((?P<args>.*)\)$")
|
||||
_PLACEHOLDER_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
|
||||
def active_at(
|
||||
@@ -376,7 +375,7 @@ def _substitute(template: str, bindings: dict[str, LanguageRow | str | None]) ->
|
||||
return "" if content is None else str(content)
|
||||
return str(value)
|
||||
|
||||
return _PLACEHOLDER_RE.sub(replace, template)
|
||||
return PLACEHOLDER_RE.sub(replace, template)
|
||||
|
||||
|
||||
def _validate_rendered(rendered: RenderedMessages) -> None:
|
||||
|
||||
@@ -153,49 +153,30 @@ def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | An
|
||||
return x
|
||||
|
||||
|
||||
_COMPLEMENTARY_KEYS = (
|
||||
"task",
|
||||
"index",
|
||||
"task_index",
|
||||
"episode_index",
|
||||
"timestamp",
|
||||
"language_persistent",
|
||||
"language_events",
|
||||
"messages",
|
||||
"message_streams",
|
||||
"target_message_indices",
|
||||
)
|
||||
|
||||
|
||||
def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Extract complementary data from a batch dictionary.
|
||||
"""Extract complementary data from a batch dictionary.
|
||||
|
||||
This includes padding flags, task description, and indices.
|
||||
|
||||
Args:
|
||||
batch: The batch dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the extracted complementary data.
|
||||
Includes padding flags (any key containing ``_is_pad``) plus the fixed
|
||||
set of metadata / language keys defined in ``_COMPLEMENTARY_KEYS`` —
|
||||
each only when present in ``batch``.
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
timestamp_key = {"timestamp": batch["timestamp"]} if "timestamp" in batch else {}
|
||||
language_persistent_key = (
|
||||
{"language_persistent": batch["language_persistent"]} if "language_persistent" in batch else {}
|
||||
)
|
||||
language_events_key = {"language_events": batch["language_events"]} if "language_events" in batch else {}
|
||||
messages_key = {"messages": batch["messages"]} if "messages" in batch else {}
|
||||
message_streams_key = {"message_streams": batch["message_streams"]} if "message_streams" in batch else {}
|
||||
target_message_indices_key = (
|
||||
{"target_message_indices": batch["target_message_indices"]}
|
||||
if "target_message_indices" in batch
|
||||
else {}
|
||||
)
|
||||
|
||||
return {
|
||||
**pad_keys,
|
||||
**task_key,
|
||||
**index_key,
|
||||
**task_index_key,
|
||||
**episode_index_key,
|
||||
**timestamp_key,
|
||||
**language_persistent_key,
|
||||
**language_events_key,
|
||||
**messages_key,
|
||||
**message_streams_key,
|
||||
**target_message_indices_key,
|
||||
}
|
||||
extras = {k: batch[k] for k in _COMPLEMENTARY_KEYS if k in batch}
|
||||
return {**pad_keys, **extras}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -87,6 +87,8 @@ def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list) and len(value) == 1:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
@@ -36,11 +36,22 @@ def lerobot_collate_fn(batch: list[dict[str, Any] | None]) -> dict[str, Any] | N
|
||||
if not batch:
|
||||
return None
|
||||
|
||||
preserved = {
|
||||
key: [sample[key] for sample in batch if key in sample]
|
||||
for key in _PYTHON_LIST_KEYS
|
||||
if any(key in sample for sample in batch)
|
||||
}
|
||||
# All-or-nothing per key: a partial-presence batch (e.g. half the samples
|
||||
# carry `messages` and half don't) is a real bug in the upstream
|
||||
# rendering step — silently filtering would hand downstream consumers a
|
||||
# preserved list shorter than the tensor batch. Raise instead so the
|
||||
# mismatch surfaces at the boundary.
|
||||
preserved: dict[str, list[Any]] = {}
|
||||
for key in _PYTHON_LIST_KEYS:
|
||||
presence = [key in sample for sample in batch]
|
||||
if not any(presence):
|
||||
continue
|
||||
if not all(presence):
|
||||
raise ValueError(
|
||||
f"Inconsistent batch: {sum(presence)}/{len(batch)} samples carry {key!r}; "
|
||||
f"every sample in a batch must agree."
|
||||
)
|
||||
preserved[key] = [sample[key] for sample in batch]
|
||||
tensorizable = [
|
||||
{
|
||||
key: value
|
||||
|
||||
Reference in New Issue
Block a user