Merge remote-tracking branch 'origin/main' into feat/language-columns

This commit is contained in:
Pepijn
2026-05-06 12:09:13 +02:00
146 changed files with 9361 additions and 4180 deletions
+3 -3
View File
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
"""Test that metadata is correctly aggregated."""
# Test basic info
assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
"Robot type should be the same"
)
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
video_keys = list(
filter(
lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
aggr_ds.meta.info["features"].keys(),
lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
aggr_ds.meta.info.features.keys(),
)
)
+1 -1
View File
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
assert meta.total_episodes == 3
assert meta.total_frames == 150
assert meta.fps == info["fps"]
assert meta.fps == info.fps
# ── Property accessors ───────────────────────────────────────────────
+21 -5
View File
@@ -1,7 +1,5 @@
#!/usr/bin/env python
from pathlib import Path
import pytest
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
@@ -370,9 +368,27 @@ def test_resolve_task_explicit_override_beats_rephrasings():
assert rendered["messages"][0]["content"] == "explicit override wins"
def test_canonical_recipe_can_render_low_level_branch():
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
def test_low_level_branch_renders_active_subtask():
low_level = TrainingRecipe(
blend={
"low": TrainingRecipe(
weight=1.0,
messages=[
MessageTurn(
role="user",
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
stream="high_level",
),
MessageTurn(
role="assistant",
content="${subtask}",
stream="low_level",
target=True,
),
],
)
}
)
rendered = render_sample(
recipe=low_level,
+17 -5
View File
@@ -80,18 +80,18 @@ def _write_dataset_tree(
)
tasks = tasks_factory(total_tasks=1)
episodes = episodes_factory(
features=info["features"],
fps=info["fps"],
features=info.features,
fps=info.fps,
total_episodes=1,
total_frames=3,
tasks=tasks,
)
stats = stats_factory(features=info["features"])
stats = stats_factory(features=info.features)
hf_dataset = hf_dataset_factory(
features=info["features"],
features=info.features,
tasks=tasks,
episodes=episodes,
fps=info["fps"],
fps=info.fps,
)
create_info(root, info)
@@ -416,6 +416,18 @@ def test_create_initial_counts_zero(tmp_path):
assert dataset.num_frames == 0
def test_create_propagates_video_files_size_in_mb(tmp_path):
"""video_files_size_in_mb passed to create() is reflected in the dataset metadata."""
dataset = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID,
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=tmp_path / "ds",
video_files_size_in_mb=42.0,
)
assert dataset.meta.video_files_size_in_mb == 42.0
def test_add_frame_works_in_write_mode(tmp_path):
"""add_frame() succeeds on a dataset created via create()."""
dataset = LeRobotDataset.create(