From 949a0505a10475b4db986e38692a3ab83fd9cbfd Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 18 May 2026 11:04:55 +0200 Subject: [PATCH] review: address CarolinePascal feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - language timestamps: float64 -> float32 to match LeRobotDataset frame timestamps (Arrow struct + HF feature) - dataset_metadata: hoist `.language` imports to module top — language.py has no lerobot imports, so there is no circular-import risk - dataset_metadata: add a `meta.tools` setter that persists the catalog to info.json and reloads `meta.info` - feature_utils: validate the `language` dtype instead of returning "" — warn (non-fatal) when a non-empty value is written at record time - centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`, shared by render_messages_processor and language_render - docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections, which describe recipe bindings rather than dataset layout - language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change on a human-action timescale, not the camera frame rate Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/source/language_and_recipes.mdx | 36 ++++++------ src/lerobot/datasets/dataset_metadata.py | 19 +++++-- src/lerobot/datasets/feature_utils.py | 27 ++++++++- src/lerobot/datasets/language.py | 6 +- src/lerobot/datasets/language_render.py | 18 +++--- .../processor/render_messages_processor.py | 16 +----- src/lerobot/utils/utils.py | 19 +++++++ tests/datasets/test_dataset_metadata.py | 56 +++++++++++++++++++ tests/datasets/test_language.py | 17 ++++++ 9 files changed, 168 insertions(+), 46 deletions(-) diff --git a/docs/source/language_and_recipes.mdx b/docs/source/language_and_recipes.mdx index 7eb588ce8..4181dbe34 100644 --- a/docs/source/language_and_recipes.mdx +++ b/docs/source/language_and_recipes.mdx @@ -40,7 +40,7 @@ frame the row sits on already provides it): role: string content: string | null style: string | null -timestamp: float64 # persistent rows only +timestamp: float32 # persistent rows only camera: string | null # observation.images.* feature key, view-dependent rows only tool_calls: list[Json] | null ``` @@ -64,6 +64,23 @@ The language stack itself has three internal modules backing layer 1: `LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior. +## Layer 2 — recipe anatomy + +Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They +declare which annotation rows to pull (via `bindings`) and how to compose them +into chat turns (`messages`). + +```yaml +messages: + - { role: user, content: "${task}", stream: high_level } + - { role: assistant, content: "${subtask}", stream: low_level, target: true } +``` + +A recipe can also branch into a weighted **blend** of sub-recipes. At sample +time, exactly one branch is selected deterministically from the sample index, +so different frames train different objectives (e.g. memory updates vs. +low-level execution vs. VQA) without any Python wiring. + ### Temporal semantics Persistent styles are active after emission until replaced: @@ -112,23 +129,6 @@ ask_vqa_top: Add one such sub-recipe per camera the dataset records. -## Layer 2 — recipe anatomy - -Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They -declare which annotation rows to pull (via `bindings`) and how to compose them -into chat turns (`messages`). - -```yaml -messages: - - { role: user, content: "${task}", stream: high_level } - - { role: assistant, content: "${subtask}", stream: low_level, target: true } -``` - -A recipe can also branch into a weighted **blend** of sub-recipes. At sample -time, exactly one branch is selected deterministically from the sample index, -so different frames train different objectives (e.g. memory updates vs. -low-level execution vs. VQA) without any Python wiring. - ## Layer 3 — training format Rendered samples use HF-style chat messages plus LeRobot sidecars: diff --git a/src/lerobot/datasets/dataset_metadata.py b/src/lerobot/datasets/dataset_metadata.py index 81c12c090..5ffebaf35 100644 --- a/src/lerobot/datasets/dataset_metadata.py +++ b/src/lerobot/datasets/dataset_metadata.py @@ -39,6 +39,7 @@ from .io_utils import ( write_stats, write_tasks, ) +from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS from .utils import ( DEFAULT_EPISODES_PATH, check_version_compatibility, @@ -323,8 +324,6 @@ class LeRobotDatasetMetadata: Used to gate language-aware code paths (collate, render step) so unannotated datasets keep PyTorch's default collate behavior. """ - from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import) - return any(col in self.features for col in LANGUAGE_COLUMNS) @property @@ -342,13 +341,25 @@ class LeRobotDatasetMetadata: Implementations live under :mod:`lerobot.tools` (one file per tool); see ``docs/source/tools.mdx`` for the authoring guide. """ - from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import) - declared = self.info.tools if declared: return [dict(t) for t in declared] return [dict(t) for t in DEFAULT_TOOLS] + @tools.setter + def tools(self, value: list[dict] | None) -> None: + """Persist a tool catalog to ``meta/info.json`` and reload metadata. + + Writes ``value`` into the on-disk ``info.json`` (or clears the + ``tools`` key when ``value`` is ``None`` or empty), then reloads + ``self.info`` so the in-memory metadata matches what's on disk. + Saves callers from hand-editing ``info.json`` and re-instantiating + the metadata object. + """ + self.info.tools = [dict(t) for t in value] if value else None + write_info(self.info, self.root) + self.info = load_info(self.root) + @property def names(self) -> dict[str, list | dict]: """Names of the various dimensions of vector modalities.""" diff --git a/src/lerobot/datasets/feature_utils.py b/src/lerobot/datasets/feature_utils.py index a1993bc8c..0be315257 100644 --- a/src/lerobot/datasets/feature_utils.py +++ b/src/lerobot/datasets/feature_utils.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from pprint import pformat import datasets @@ -255,7 +256,7 @@ def validate_feature_dtype_and_shape( elif expected_dtype == "string": return validate_feature_string(name, value) elif expected_dtype == "language": - return "" + return validate_feature_language(name, value) else: raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") @@ -335,6 +336,30 @@ def validate_feature_string(name: str, value: str) -> str: return "" +def validate_feature_language(name: str, value) -> str: + """Validate a feature that is expected to hold language annotations. + + Language columns (``language_persistent`` / ``language_events``) are + populated after recording by the annotation pipeline, not at record time. + Any value supplied here is dropped before the frame is written, so a + non-empty value almost certainly signals a mistake. We warn rather than + fail to keep recording resilient. + + Args: + name (str): The name of the feature. + value: The value to validate. + + Returns: + str: Always an empty string — language values are non-fatal. + """ + if value is not None: + logging.warning( + f"The feature '{name}' is a 'language' column populated by the annotation pipeline, " + f"not at record time. The provided value will be dropped." + ) + return "" + + def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: """Validate the episode buffer before it's written to disk. diff --git a/src/lerobot/datasets/language.py b/src/lerobot/datasets/language.py index 27c26ac06..124c25221 100644 --- a/src/lerobot/datasets/language.py +++ b/src/lerobot/datasets/language.py @@ -79,13 +79,15 @@ def language_persistent_row_arrow_type() -> pa.StructType: Persistent rows carry their own ``timestamp`` because they represent a state that became active at a specific moment and remains active until superseded. + ``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset + uses for frame data. """ return pa.struct( [ pa.field("role", pa.string(), nullable=False), pa.field("content", pa.string(), nullable=True), pa.field("style", pa.string(), nullable=True), - pa.field("timestamp", pa.float64(), nullable=False), + pa.field("timestamp", pa.float32(), nullable=False), pa.field("camera", pa.string(), nullable=True), pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True), ] @@ -125,7 +127,7 @@ def language_persistent_row_feature() -> dict[str, object]: "role": datasets.Value("string"), "content": datasets.Value("string"), "style": datasets.Value("string"), - "timestamp": datasets.Value("float64"), + "timestamp": datasets.Value("float32"), "camera": datasets.Value("string"), "tool_calls": datasets.List(_json_feature()), } diff --git a/src/lerobot/datasets/language_render.py b/src/lerobot/datasets/language_render.py index bc99ef0fd..999fa19ad 100644 --- a/src/lerobot/datasets/language_render.py +++ b/src/lerobot/datasets/language_render.py @@ -23,6 +23,7 @@ from collections.abc import Sequence from typing import Any from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe +from lerobot.utils.utils import unwrap_scalar from .language import LANGUAGE_PERSISTENT, column_for_style @@ -67,12 +68,16 @@ def active_at( EMITTED_AT_TOLERANCE_S = 0.1 """Half-window for matching persistent rows to a frame timestamp in -``emitted_at``. Persistent timestamps come from parquet (float64) and ``t`` -is also a float64 from parquet, so in the ideal hot path an exact match +``emitted_at``. Persistent timestamps come from parquet (float32) and ``t`` +is also a float32 from parquet, so in the ideal hot path an exact match would suffice — but any caller that derives ``t`` arithmetically (e.g. ``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers common arithmetic drift without admitting frames that are visibly far -apart at typical control rates (30–100 Hz).""" +apart at typical control rates (30–100 Hz). This does mean two persistent +rows of the same selector emitted within 0.1 s of each other cannot be +told apart by ``emitted_at`` — acceptable because persistent annotations +(subtask / plan / memory transitions) change on a human-action timescale, +not at the camera frame rate.""" def emitted_at( @@ -506,16 +511,13 @@ def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]: bucket and are tiebroken by ``(style, role)``. """ timestamp = row.get("timestamp") - ts = ( - float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0 - ) + ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0 return (ts, row.get("style") or "", row.get("role") or "") def _timestamp(row: LanguageRow) -> float: """Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars).""" - value = row["timestamp"] - return float(value.item() if hasattr(value, "item") else value) + return float(unwrap_scalar(row["timestamp"])) def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool: diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py index b5bb81f9d..140592f0e 100644 --- a/src/lerobot/processor/render_messages_processor.py +++ b/src/lerobot/processor/render_messages_processor.py @@ -24,6 +24,7 @@ from lerobot.configs.recipe import TrainingRecipe from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT from lerobot.datasets.language_render import render_sample from lerobot.types import EnvTransition, TransitionKey +from lerobot.utils.utils import unwrap_scalar from .pipeline import ProcessorStep, ProcessorStepRegistry @@ -60,8 +61,8 @@ class RenderMessagesStep(ProcessorStep): recipe=self.recipe, persistent=persistent, events=events, - t=_scalar(timestamp), - sample_idx=int(_scalar(sample_idx)), + t=unwrap_scalar(timestamp), + sample_idx=int(unwrap_scalar(sample_idx)), task=complementary_data.get("task"), dataset_ctx=self.dataset_ctx, ) @@ -81,14 +82,3 @@ class RenderMessagesStep(ProcessorStep): ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """Pass features through unchanged; rendering only touches complementary data.""" return features - - -def _scalar(value: Any) -> float | int: - """Unwrap a tensor/array/single-element list into a Python scalar.""" - if hasattr(value, "item"): - return value.item() - if isinstance(value, list): - if len(value) != 1: - raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}") - return _scalar(value[0]) - return value diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 2574f1fa3..6aad0c503 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -160,6 +160,25 @@ def has_method(cls: object, method_name: str) -> bool: return hasattr(cls, method_name) and callable(getattr(cls, method_name)) +def unwrap_scalar(value: Any) -> Any: + """Unwrap a tensor / numpy scalar / single-element list into a Python scalar. + + Tensors and numpy scalars expose ``.item()``; single-element lists are + unwrapped recursively. Anything else is returned unchanged. Centralized + here so the language renderer and processor steps share one definition. + + Raises: + ValueError: If ``value`` is a list with zero or multiple elements. + """ + if hasattr(value, "item"): + return value.item() + if isinstance(value, list): + if len(value) != 1: + raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}") + return unwrap_scalar(value[0]) + return value + + def is_valid_numpy_dtype_string(dtype_str: str) -> bool: """ Return True if a given string can be converted to a numpy dtype. diff --git a/tests/datasets/test_dataset_metadata.py b/tests/datasets/test_dataset_metadata.py index 7235db4a3..171d8af8b 100644 --- a/tests/datasets/test_dataset_metadata.py +++ b/tests/datasets/test_dataset_metadata.py @@ -466,3 +466,59 @@ def test_tools_round_trip_through_dataset_info(tmp_path): info = DatasetInfo.from_dict(raw) assert info.tools == raw["tools"] assert info.to_dict()["tools"] == raw["tools"] + + +def test_tools_setter_persists_to_info_json_and_reloads(tmp_path): + """Assigning meta.tools writes info.json and reloads meta.info.""" + from lerobot.datasets.io_utils import load_info + + root = tmp_path / "set_tools" + meta = LeRobotDatasetMetadata.create( + repo_id="test/set_tools", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=root, + use_videos=False, + ) + + custom_tool = { + "type": "function", + "function": { + "name": "record_observation", + "description": "Capture a still image.", + "parameters": { + "type": "object", + "properties": {"label": {"type": "string"}}, + "required": ["label"], + }, + }, + } + meta.tools = [custom_tool] + + # In-memory metadata reflects the new catalog ... + assert meta.tools == [custom_tool] + assert meta.info.tools == [custom_tool] + # ... and a fresh read from disk agrees. + assert load_info(root).tools == [custom_tool] + + +def test_tools_setter_clears_key_when_set_to_none(tmp_path): + """Setting meta.tools back to None drops the key and restores the default.""" + from lerobot.datasets.language import DEFAULT_TOOLS + + root = tmp_path / "clear_tools" + meta = LeRobotDatasetMetadata.create( + repo_id="test/clear_tools", + fps=DEFAULT_FPS, + features=SIMPLE_FEATURES, + root=root, + use_videos=False, + ) + + meta.tools = [{"type": "function", "function": {"name": "say"}}] + meta.tools = None + + assert meta.tools == DEFAULT_TOOLS + with open(root / INFO_PATH) as f: + info_on_disk = json.load(f) + assert "tools" not in info_on_disk diff --git a/tests/datasets/test_language.py b/tests/datasets/test_language.py index f108f86c9..52c7b3708 100644 --- a/tests/datasets/test_language.py +++ b/tests/datasets/test_language.py @@ -45,6 +45,23 @@ def test_language_arrow_schema_has_expected_fields(): assert isinstance(event_row_type, pa.StructType) assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"] + # Persistent-row timestamps use float32, matching LeRobotDataset frame timestamps. + assert persistent_row_type.field("timestamp").type == pa.float32() + + +def test_validate_feature_language_warns_only_on_non_empty_value(caplog): + from lerobot.datasets.feature_utils import validate_feature_language + + # None (the expected record-time value) is silent and non-fatal. + with caplog.at_level("WARNING"): + assert validate_feature_language("language_persistent", None) == "" + assert caplog.records == [] + + # A stray non-empty value is dropped later, so we warn rather than fail. + with caplog.at_level("WARNING"): + assert validate_feature_language("language_persistent", [{"role": "user"}]) == "" + assert any("language_persistent" in r.message for r in caplog.records) + def test_style_registry_routes_columns(): assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES