review: address CarolinePascal feedback

- language timestamps: float64 -> float32 to match LeRobotDataset frame
  timestamps (Arrow struct + HF feature)
- dataset_metadata: hoist `.language` imports to module top — language.py
  has no lerobot imports, so there is no circular-import risk
- dataset_metadata: add a `meta.tools` setter that persists the catalog to
  info.json and reloads `meta.info`
- feature_utils: validate the `language` dtype instead of returning "" —
  warn (non-fatal) when a non-empty value is written at record time
- centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`,
  shared by render_messages_processor and language_render
- docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections,
  which describe recipe bindings rather than dataset layout
- language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change
  on a human-action timescale, not the camera frame rate

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-18 11:04:55 +02:00
parent bce5387e04
commit 949a0505a1
9 changed files with 168 additions and 46 deletions
+18 -18
View File
@@ -40,7 +40,7 @@ frame the row sits on already provides it):
role: string
content: string | null
style: string | null
timestamp: float64 # persistent rows only
timestamp: float32 # persistent rows only
camera: string | null # observation.images.* feature key, view-dependent rows only
tool_calls: list[Json] | null
```
@@ -64,6 +64,23 @@ The language stack itself has three internal modules backing layer 1:
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
### Temporal semantics
Persistent styles are active after emission until replaced:
@@ -112,23 +129,6 @@ ask_vqa_top:
Add one such sub-recipe per camera the dataset records.
## Layer 2 — recipe anatomy
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
declare which annotation rows to pull (via `bindings`) and how to compose them
into chat turns (`messages`).
```yaml
messages:
- { role: user, content: "${task}", stream: high_level }
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
```
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
time, exactly one branch is selected deterministically from the sample index,
so different frames train different objectives (e.g. memory updates vs.
low-level execution vs. VQA) without any Python wiring.
## Layer 3 — training format
Rendered samples use HF-style chat messages plus LeRobot sidecars:
+15 -4
View File
@@ -39,6 +39,7 @@ from .io_utils import (
write_stats,
write_tasks,
)
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
from .utils import (
DEFAULT_EPISODES_PATH,
check_version_compatibility,
@@ -323,8 +324,6 @@ class LeRobotDatasetMetadata:
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property
@@ -342,13 +341,25 @@ class LeRobotDatasetMetadata:
Implementations live under :mod:`lerobot.tools` (one file per
tool); see ``docs/source/tools.mdx`` for the authoring guide.
"""
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
declared = self.info.tools
if declared:
return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS]
@tools.setter
def tools(self, value: list[dict] | None) -> None:
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
Writes ``value`` into the on-disk ``info.json`` (or clears the
``tools`` key when ``value`` is ``None`` or empty), then reloads
``self.info`` so the in-memory metadata matches what's on disk.
Saves callers from hand-editing ``info.json`` and re-instantiating
the metadata object.
"""
self.info.tools = [dict(t) for t in value] if value else None
write_info(self.info, self.root)
self.info = load_info(self.root)
@property
def names(self) -> dict[str, list | dict]:
"""Names of the various dimensions of vector modalities."""
+26 -1
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pprint import pformat
import datasets
@@ -255,7 +256,7 @@ def validate_feature_dtype_and_shape(
elif expected_dtype == "string":
return validate_feature_string(name, value)
elif expected_dtype == "language":
return ""
return validate_feature_language(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
@@ -335,6 +336,30 @@ def validate_feature_string(name: str, value: str) -> str:
return ""
def validate_feature_language(name: str, value) -> str:
"""Validate a feature that is expected to hold language annotations.
Language columns (``language_persistent`` / ``language_events``) are
populated after recording by the annotation pipeline, not at record time.
Any value supplied here is dropped before the frame is written, so a
non-empty value almost certainly signals a mistake. We warn rather than
fail to keep recording resilient.
Args:
name (str): The name of the feature.
value: The value to validate.
Returns:
str: Always an empty string language values are non-fatal.
"""
if value is not None:
logging.warning(
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
f"not at record time. The provided value will be dropped."
)
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
"""Validate the episode buffer before it's written to disk.
+4 -2
View File
@@ -79,13 +79,15 @@ def language_persistent_row_arrow_type() -> pa.StructType:
Persistent rows carry their own ``timestamp`` because they represent a state
that became active at a specific moment and remains active until superseded.
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
uses for frame data.
"""
return pa.struct(
[
pa.field("role", pa.string(), nullable=False),
pa.field("content", pa.string(), nullable=True),
pa.field("style", pa.string(), nullable=True),
pa.field("timestamp", pa.float64(), nullable=False),
pa.field("timestamp", pa.float32(), nullable=False),
pa.field("camera", pa.string(), nullable=True),
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
]
@@ -125,7 +127,7 @@ def language_persistent_row_feature() -> dict[str, object]:
"role": datasets.Value("string"),
"content": datasets.Value("string"),
"style": datasets.Value("string"),
"timestamp": datasets.Value("float64"),
"timestamp": datasets.Value("float32"),
"camera": datasets.Value("string"),
"tool_calls": datasets.List(_json_feature()),
}
+10 -8
View File
@@ -23,6 +23,7 @@ from collections.abc import Sequence
from typing import Any
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
from lerobot.utils.utils import unwrap_scalar
from .language import LANGUAGE_PERSISTENT, column_for_style
@@ -67,12 +68,16 @@ def active_at(
EMITTED_AT_TOLERANCE_S = 0.1
"""Half-window for matching persistent rows to a frame timestamp in
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
is also a float64 from parquet, so in the ideal hot path an exact match
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
is also a float32 from parquet, so in the ideal hot path an exact match
would suffice but any caller that derives ``t`` arithmetically (e.g.
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
common arithmetic drift without admitting frames that are visibly far
apart at typical control rates (30100 Hz)."""
apart at typical control rates (30100 Hz). This does mean two persistent
rows of the same selector emitted within 0.1 s of each other cannot be
told apart by ``emitted_at`` acceptable because persistent annotations
(subtask / plan / memory transitions) change on a human-action timescale,
not at the camera frame rate."""
def emitted_at(
@@ -506,16 +511,13 @@ def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
bucket and are tiebroken by ``(style, role)``.
"""
timestamp = row.get("timestamp")
ts = (
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
)
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
return (ts, row.get("style") or "", row.get("role") or "")
def _timestamp(row: LanguageRow) -> float:
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
value = row["timestamp"]
return float(value.item() if hasattr(value, "item") else value)
return float(unwrap_scalar(row["timestamp"]))
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
@@ -24,6 +24,7 @@ from lerobot.configs.recipe import TrainingRecipe
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
from lerobot.datasets.language_render import render_sample
from lerobot.types import EnvTransition, TransitionKey
from lerobot.utils.utils import unwrap_scalar
from .pipeline import ProcessorStep, ProcessorStepRegistry
@@ -60,8 +61,8 @@ class RenderMessagesStep(ProcessorStep):
recipe=self.recipe,
persistent=persistent,
events=events,
t=_scalar(timestamp),
sample_idx=int(_scalar(sample_idx)),
t=unwrap_scalar(timestamp),
sample_idx=int(unwrap_scalar(sample_idx)),
task=complementary_data.get("task"),
dataset_ctx=self.dataset_ctx,
)
@@ -81,14 +82,3 @@ class RenderMessagesStep(ProcessorStep):
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Pass features through unchanged; rendering only touches complementary data."""
return features
def _scalar(value: Any) -> float | int:
"""Unwrap a tensor/array/single-element list into a Python scalar."""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return _scalar(value[0])
return value
+19
View File
@@ -160,6 +160,25 @@ def has_method(cls: object, method_name: str) -> bool:
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
def unwrap_scalar(value: Any) -> Any:
"""Unwrap a tensor / numpy scalar / single-element list into a Python scalar.
Tensors and numpy scalars expose ``.item()``; single-element lists are
unwrapped recursively. Anything else is returned unchanged. Centralized
here so the language renderer and processor steps share one definition.
Raises:
ValueError: If ``value`` is a list with zero or multiple elements.
"""
if hasattr(value, "item"):
return value.item()
if isinstance(value, list):
if len(value) != 1:
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
return unwrap_scalar(value[0])
return value
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
"""
Return True if a given string can be converted to a numpy dtype.
+56
View File
@@ -466,3 +466,59 @@ def test_tools_round_trip_through_dataset_info(tmp_path):
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]
def test_tools_setter_persists_to_info_json_and_reloads(tmp_path):
"""Assigning meta.tools writes info.json and reloads meta.info."""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "set_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/set_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
meta.tools = [custom_tool]
# In-memory metadata reflects the new catalog ...
assert meta.tools == [custom_tool]
assert meta.info.tools == [custom_tool]
# ... and a fresh read from disk agrees.
assert load_info(root).tools == [custom_tool]
def test_tools_setter_clears_key_when_set_to_none(tmp_path):
"""Setting meta.tools back to None drops the key and restores the default."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "clear_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/clear_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
meta.tools = [{"type": "function", "function": {"name": "say"}}]
meta.tools = None
assert meta.tools == DEFAULT_TOOLS
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
+17
View File
@@ -45,6 +45,23 @@ def test_language_arrow_schema_has_expected_fields():
assert isinstance(event_row_type, pa.StructType)
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
# Persistent-row timestamps use float32, matching LeRobotDataset frame timestamps.
assert persistent_row_type.field("timestamp").type == pa.float32()
def test_validate_feature_language_warns_only_on_non_empty_value(caplog):
from lerobot.datasets.feature_utils import validate_feature_language
# None (the expected record-time value) is silent and non-fatal.
with caplog.at_level("WARNING"):
assert validate_feature_language("language_persistent", None) == ""
assert caplog.records == []
# A stray non-empty value is dropped later, so we warn rather than fail.
with caplog.at_level("WARNING"):
assert validate_feature_language("language_persistent", [{"role": "user"}]) == ""
assert any("language_persistent" in r.message for r in caplog.records)
def test_style_registry_routes_columns():
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES