mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
review: address CarolinePascal feedback
- language timestamps: float64 -> float32 to match LeRobotDataset frame timestamps (Arrow struct + HF feature) - dataset_metadata: hoist `.language` imports to module top — language.py has no lerobot imports, so there is no circular-import risk - dataset_metadata: add a `meta.tools` setter that persists the catalog to info.json and reloads `meta.info` - feature_utils: validate the `language` dtype instead of returning "" — warn (non-fatal) when a non-empty value is written at record time - centralize the scalar-unwrap helper as `lerobot.utils.utils.unwrap_scalar`, shared by render_messages_processor and language_render - docs: move `## Layer 2 — recipe anatomy` ahead of the resolver sections, which describe recipe bindings rather than dataset layout - language_render: note in EMITTED_AT_TOLERANCE_S that persistent rows change on a human-action timescale, not the camera frame rate Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,7 +40,7 @@ frame the row sits on already provides it):
|
||||
role: string
|
||||
content: string | null
|
||||
style: string | null
|
||||
timestamp: float64 # persistent rows only
|
||||
timestamp: float32 # persistent rows only
|
||||
camera: string | null # observation.images.* feature key, view-dependent rows only
|
||||
tool_calls: list[Json] | null
|
||||
```
|
||||
@@ -64,6 +64,23 @@ The language stack itself has three internal modules backing layer 1:
|
||||
|
||||
`LeRobotDataset` stays recipe-agnostic. It passes `language_persistent` and `language_events` through when present, and unannotated datasets keep their existing behavior.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
### Temporal semantics
|
||||
|
||||
Persistent styles are active after emission until replaced:
|
||||
@@ -112,23 +129,6 @@ ask_vqa_top:
|
||||
|
||||
Add one such sub-recipe per camera the dataset records.
|
||||
|
||||
## Layer 2 — recipe anatomy
|
||||
|
||||
Recipes are YAML files backed by `TrainingRecipe` and `MessageTurn`. They
|
||||
declare which annotation rows to pull (via `bindings`) and how to compose them
|
||||
into chat turns (`messages`).
|
||||
|
||||
```yaml
|
||||
messages:
|
||||
- { role: user, content: "${task}", stream: high_level }
|
||||
- { role: assistant, content: "${subtask}", stream: low_level, target: true }
|
||||
```
|
||||
|
||||
A recipe can also branch into a weighted **blend** of sub-recipes. At sample
|
||||
time, exactly one branch is selected deterministically from the sample index,
|
||||
so different frames train different objectives (e.g. memory updates vs.
|
||||
low-level execution vs. VQA) without any Python wiring.
|
||||
|
||||
## Layer 3 — training format
|
||||
|
||||
Rendered samples use HF-style chat messages plus LeRobot sidecars:
|
||||
|
||||
@@ -39,6 +39,7 @@ from .io_utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from .language import DEFAULT_TOOLS, LANGUAGE_COLUMNS
|
||||
from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
check_version_compatibility,
|
||||
@@ -323,8 +324,6 @@ class LeRobotDatasetMetadata:
|
||||
Used to gate language-aware code paths (collate, render step) so
|
||||
unannotated datasets keep PyTorch's default collate behavior.
|
||||
"""
|
||||
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||
|
||||
@property
|
||||
@@ -342,13 +341,25 @@ class LeRobotDatasetMetadata:
|
||||
Implementations live under :mod:`lerobot.tools` (one file per
|
||||
tool); see ``docs/source/tools.mdx`` for the authoring guide.
|
||||
"""
|
||||
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||
|
||||
declared = self.info.tools
|
||||
if declared:
|
||||
return [dict(t) for t in declared]
|
||||
return [dict(t) for t in DEFAULT_TOOLS]
|
||||
|
||||
@tools.setter
|
||||
def tools(self, value: list[dict] | None) -> None:
|
||||
"""Persist a tool catalog to ``meta/info.json`` and reload metadata.
|
||||
|
||||
Writes ``value`` into the on-disk ``info.json`` (or clears the
|
||||
``tools`` key when ``value`` is ``None`` or empty), then reloads
|
||||
``self.info`` so the in-memory metadata matches what's on disk.
|
||||
Saves callers from hand-editing ``info.json`` and re-instantiating
|
||||
the metadata object.
|
||||
"""
|
||||
self.info.tools = [dict(t) for t in value] if value else None
|
||||
write_info(self.info, self.root)
|
||||
self.info = load_info(self.root)
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import datasets
|
||||
@@ -255,7 +256,7 @@ def validate_feature_dtype_and_shape(
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
elif expected_dtype == "language":
|
||||
return ""
|
||||
return validate_feature_language(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
@@ -335,6 +336,30 @@ def validate_feature_string(name: str, value: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def validate_feature_language(name: str, value) -> str:
|
||||
"""Validate a feature that is expected to hold language annotations.
|
||||
|
||||
Language columns (``language_persistent`` / ``language_events``) are
|
||||
populated after recording by the annotation pipeline, not at record time.
|
||||
Any value supplied here is dropped before the frame is written, so a
|
||||
non-empty value almost certainly signals a mistake. We warn rather than
|
||||
fail to keep recording resilient.
|
||||
|
||||
Args:
|
||||
name (str): The name of the feature.
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
str: Always an empty string — language values are non-fatal.
|
||||
"""
|
||||
if value is not None:
|
||||
logging.warning(
|
||||
f"The feature '{name}' is a 'language' column populated by the annotation pipeline, "
|
||||
f"not at record time. The provided value will be dropped."
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None:
|
||||
"""Validate the episode buffer before it's written to disk.
|
||||
|
||||
|
||||
@@ -79,13 +79,15 @@ def language_persistent_row_arrow_type() -> pa.StructType:
|
||||
|
||||
Persistent rows carry their own ``timestamp`` because they represent a state
|
||||
that became active at a specific moment and remains active until superseded.
|
||||
``timestamp`` is ``float32`` to match the timestamp dtype LeRobotDataset
|
||||
uses for frame data.
|
||||
"""
|
||||
return pa.struct(
|
||||
[
|
||||
pa.field("role", pa.string(), nullable=False),
|
||||
pa.field("content", pa.string(), nullable=True),
|
||||
pa.field("style", pa.string(), nullable=True),
|
||||
pa.field("timestamp", pa.float64(), nullable=False),
|
||||
pa.field("timestamp", pa.float32(), nullable=False),
|
||||
pa.field("camera", pa.string(), nullable=True),
|
||||
pa.field("tool_calls", pa.list_(_json_arrow_type()), nullable=True),
|
||||
]
|
||||
@@ -125,7 +127,7 @@ def language_persistent_row_feature() -> dict[str, object]:
|
||||
"role": datasets.Value("string"),
|
||||
"content": datasets.Value("string"),
|
||||
"style": datasets.Value("string"),
|
||||
"timestamp": datasets.Value("float64"),
|
||||
"timestamp": datasets.Value("float32"),
|
||||
"camera": datasets.Value("string"),
|
||||
"tool_calls": datasets.List(_json_feature()),
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.recipe import DEFAULT_BINDINGS, PLACEHOLDER_RE, TrainingRecipe
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .language import LANGUAGE_PERSISTENT, column_for_style
|
||||
|
||||
@@ -67,12 +68,16 @@ def active_at(
|
||||
|
||||
EMITTED_AT_TOLERANCE_S = 0.1
|
||||
"""Half-window for matching persistent rows to a frame timestamp in
|
||||
``emitted_at``. Persistent timestamps come from parquet (float64) and ``t``
|
||||
is also a float64 from parquet, so in the ideal hot path an exact match
|
||||
``emitted_at``. Persistent timestamps come from parquet (float32) and ``t``
|
||||
is also a float32 from parquet, so in the ideal hot path an exact match
|
||||
would suffice — but any caller that derives ``t`` arithmetically (e.g.
|
||||
``frame_idx / fps``) breaks bit-equality. A 0.1 s tolerance covers
|
||||
common arithmetic drift without admitting frames that are visibly far
|
||||
apart at typical control rates (30–100 Hz)."""
|
||||
apart at typical control rates (30–100 Hz). This does mean two persistent
|
||||
rows of the same selector emitted within 0.1 s of each other cannot be
|
||||
told apart by ``emitted_at`` — acceptable because persistent annotations
|
||||
(subtask / plan / memory transitions) change on a human-action timescale,
|
||||
not at the camera frame rate."""
|
||||
|
||||
|
||||
def emitted_at(
|
||||
@@ -506,16 +511,13 @@ def _row_sort_key(row: LanguageRow) -> tuple[float, str, str]:
|
||||
bucket and are tiebroken by ``(style, role)``.
|
||||
"""
|
||||
timestamp = row.get("timestamp")
|
||||
ts = (
|
||||
float(timestamp.item() if hasattr(timestamp, "item") else timestamp) if timestamp is not None else 0.0
|
||||
)
|
||||
ts = float(unwrap_scalar(timestamp)) if timestamp is not None else 0.0
|
||||
return (ts, row.get("style") or "", row.get("role") or "")
|
||||
|
||||
|
||||
def _timestamp(row: LanguageRow) -> float:
|
||||
"""Extract a row's ``timestamp`` as a Python float (unwrapping numpy scalars)."""
|
||||
value = row["timestamp"]
|
||||
return float(value.item() if hasattr(value, "item") else value)
|
||||
return float(unwrap_scalar(row["timestamp"]))
|
||||
|
||||
|
||||
def _row_has_tool_name(row: LanguageRow, tool_name: str) -> bool:
|
||||
|
||||
@@ -24,6 +24,7 @@ from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.datasets.language import LANGUAGE_EVENTS, LANGUAGE_PERSISTENT
|
||||
from lerobot.datasets.language_render import render_sample
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import unwrap_scalar
|
||||
|
||||
from .pipeline import ProcessorStep, ProcessorStepRegistry
|
||||
|
||||
@@ -60,8 +61,8 @@ class RenderMessagesStep(ProcessorStep):
|
||||
recipe=self.recipe,
|
||||
persistent=persistent,
|
||||
events=events,
|
||||
t=_scalar(timestamp),
|
||||
sample_idx=int(_scalar(sample_idx)),
|
||||
t=unwrap_scalar(timestamp),
|
||||
sample_idx=int(unwrap_scalar(sample_idx)),
|
||||
task=complementary_data.get("task"),
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
@@ -81,14 +82,3 @@ class RenderMessagesStep(ProcessorStep):
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Pass features through unchanged; rendering only touches complementary data."""
|
||||
return features
|
||||
|
||||
|
||||
def _scalar(value: Any) -> float | int:
|
||||
"""Unwrap a tensor/array/single-element list into a Python scalar."""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return _scalar(value[0])
|
||||
return value
|
||||
|
||||
@@ -160,6 +160,25 @@ def has_method(cls: object, method_name: str) -> bool:
|
||||
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
|
||||
|
||||
|
||||
def unwrap_scalar(value: Any) -> Any:
|
||||
"""Unwrap a tensor / numpy scalar / single-element list into a Python scalar.
|
||||
|
||||
Tensors and numpy scalars expose ``.item()``; single-element lists are
|
||||
unwrapped recursively. Anything else is returned unchanged. Centralized
|
||||
here so the language renderer and processor steps share one definition.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``value`` is a list with zero or multiple elements.
|
||||
"""
|
||||
if hasattr(value, "item"):
|
||||
return value.item()
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Expected a scalar, got list of length {len(value)}: {value!r}")
|
||||
return unwrap_scalar(value[0])
|
||||
return value
|
||||
|
||||
|
||||
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
|
||||
"""
|
||||
Return True if a given string can be converted to a numpy dtype.
|
||||
|
||||
@@ -466,3 +466,59 @@ def test_tools_round_trip_through_dataset_info(tmp_path):
|
||||
info = DatasetInfo.from_dict(raw)
|
||||
assert info.tools == raw["tools"]
|
||||
assert info.to_dict()["tools"] == raw["tools"]
|
||||
|
||||
|
||||
def test_tools_setter_persists_to_info_json_and_reloads(tmp_path):
|
||||
"""Assigning meta.tools writes info.json and reloads meta.info."""
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
|
||||
root = tmp_path / "set_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/set_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
custom_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a still image.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"label": {"type": "string"}},
|
||||
"required": ["label"],
|
||||
},
|
||||
},
|
||||
}
|
||||
meta.tools = [custom_tool]
|
||||
|
||||
# In-memory metadata reflects the new catalog ...
|
||||
assert meta.tools == [custom_tool]
|
||||
assert meta.info.tools == [custom_tool]
|
||||
# ... and a fresh read from disk agrees.
|
||||
assert load_info(root).tools == [custom_tool]
|
||||
|
||||
|
||||
def test_tools_setter_clears_key_when_set_to_none(tmp_path):
|
||||
"""Setting meta.tools back to None drops the key and restores the default."""
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||
|
||||
root = tmp_path / "clear_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/clear_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
meta.tools = [{"type": "function", "function": {"name": "say"}}]
|
||||
meta.tools = None
|
||||
|
||||
assert meta.tools == DEFAULT_TOOLS
|
||||
with open(root / INFO_PATH) as f:
|
||||
info_on_disk = json.load(f)
|
||||
assert "tools" not in info_on_disk
|
||||
|
||||
@@ -45,6 +45,23 @@ def test_language_arrow_schema_has_expected_fields():
|
||||
assert isinstance(event_row_type, pa.StructType)
|
||||
assert event_row_type.names == ["role", "content", "style", "camera", "tool_calls"]
|
||||
|
||||
# Persistent-row timestamps use float32, matching LeRobotDataset frame timestamps.
|
||||
assert persistent_row_type.field("timestamp").type == pa.float32()
|
||||
|
||||
|
||||
def test_validate_feature_language_warns_only_on_non_empty_value(caplog):
|
||||
from lerobot.datasets.feature_utils import validate_feature_language
|
||||
|
||||
# None (the expected record-time value) is silent and non-fatal.
|
||||
with caplog.at_level("WARNING"):
|
||||
assert validate_feature_language("language_persistent", None) == ""
|
||||
assert caplog.records == []
|
||||
|
||||
# A stray non-empty value is dropped later, so we warn rather than fail.
|
||||
with caplog.at_level("WARNING"):
|
||||
assert validate_feature_language("language_persistent", [{"role": "user"}]) == ""
|
||||
assert any("language_persistent" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_style_registry_routes_columns():
|
||||
assert {"subtask", "plan", "memory", "motion", "task_aug"} == PERSISTENT_STYLES
|
||||
|
||||
Reference in New Issue
Block a user