mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)
* refactor(dataset): split reader and writer * chore(dataset): remove proxies * refactor(dataset): better reader & writer encapsulation * refactor(datasets): clean API + reduce leaky implementations * refactor(dataset): API cleaning for writer, reader and meta * refactor(dataset): expose writer & reader + other minor improvements * refactor(dataset): improve teardown routine * refactor(dataset): add hf_dataset property at the facade level * chore(dataset): add init for dataset module * docs(dataset): add docstrings for public API of the dataset classes * tests(dataset): add tests for new classes * fix(dataset): remove circular dependency
This commit is contained in:
@@ -0,0 +1,385 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contract tests for LeRobotDatasetMetadata."""
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import INFO_PATH
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_ROBOT_TYPE
|
||||
|
||||
# ── helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
# Minimal non-visual feature set: two independent 6-dimensional float32 vectors.
SIMPLE_FEATURES = {
    name: {"dtype": "float32", "shape": (6,), "names": None} for name in ("state", "action")
}

# SIMPLE_FEATURES plus one camera stream stored as video.
VIDEO_FEATURES = dict(SIMPLE_FEATURES)
VIDEO_FEATURES["observation.images.laptop"] = {
    "dtype": "video",
    "shape": (64, 96, 3),
    "names": ["height", "width", "channels"],
    "info": None,
}

# SIMPLE_FEATURES plus the same camera stream stored as individual images.
IMAGE_FEATURES = dict(SIMPLE_FEATURES)
IMAGE_FEATURES["observation.images.laptop"] = {
    "dtype": "image",
    "shape": (64, 96, 3),
    "names": ["height", "width", "channels"],
    "info": None,
}
|
||||
|
||||
|
||||
def _make_dummy_stats(features: dict) -> dict:
|
||||
"""Create minimal episode stats matching the given features."""
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
stats[key] = {
|
||||
"max": np.ones(ft["shape"], dtype=np.float32),
|
||||
"mean": np.full(ft["shape"], 0.5, dtype=np.float32),
|
||||
"min": np.zeros(ft["shape"], dtype=np.float32),
|
||||
"std": np.full(ft["shape"], 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
# ── Construction contracts ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_produces_valid_info_on_disk(tmp_path):
    """create() persists info.json, and the returned metadata mirrors its arguments."""
    dataset_root = tmp_path / "new_ds"
    create_kwargs = {
        "repo_id": "test/meta",
        "fps": DEFAULT_FPS,
        "features": SIMPLE_FEATURES,
        "robot_type": DUMMY_ROBOT_TYPE,
        "root": dataset_root,
        "use_videos": False,
    }
    metadata = LeRobotDatasetMetadata.create(**create_kwargs)

    # The info file must have been written under the dataset root.
    info_file = dataset_root / INFO_PATH
    assert info_file.exists()
    with open(info_file) as fh:
        persisted_info = json.load(fh)

    # In-memory view matches what was requested ...
    assert metadata.fps == DEFAULT_FPS
    assert metadata.robot_type == DUMMY_ROBOT_TYPE
    for feature_name in ("state", "action"):
        assert feature_name in metadata.features
    # ... and so does the persisted copy.
    assert persisted_info["fps"] == DEFAULT_FPS
|
||||
|
||||
|
||||
def test_create_starts_with_zero_counts(tmp_path):
    """Fresh metadata reports zero episodes/frames/tasks and has nothing loaded yet."""
    fresh_meta = LeRobotDatasetMetadata.create(
        repo_id="test/empty",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "empty_ds",
        use_videos=False,
    )

    # Every numeric counter starts at zero.
    for counter_name in ("total_episodes", "total_frames", "total_tasks"):
        assert getattr(fresh_meta, counter_name) == 0

    # Lazily populated tables are still unset.
    for table_name in ("tasks", "episodes", "stats"):
        assert getattr(fresh_meta, table_name) is None
|
||||
|
||||
|
||||
def test_create_with_videos_sets_video_path(tmp_path):
    """With a video-dtype feature and use_videos=True, create() yields a video_path and one video key."""
    video_meta = LeRobotDatasetMetadata.create(
        repo_id="test/video",
        fps=DEFAULT_FPS,
        features=VIDEO_FEATURES,
        root=tmp_path / "video_ds",
        use_videos=True,
    )

    assert video_meta.video_path is not None
    # Exactly the one camera declared in VIDEO_FEATURES is reported as a video key.
    assert len(video_meta.video_keys) == 1
    assert "observation.images.laptop" in video_meta.video_keys
|
||||
|
||||
|
||||
def test_create_without_videos_has_no_video_path(tmp_path):
    """Without video features and with use_videos=False, there is no video_path and no video keys."""
    plain_meta = LeRobotDatasetMetadata.create(
        repo_id="test/novid",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "no_video",
        use_videos=False,
    )

    assert plain_meta.video_path is None
    assert plain_meta.video_keys == []
|
||||
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
    """create() refuses to reuse a root directory that is already present."""
    occupied_root = tmp_path / "existing"
    occupied_root.mkdir()

    create_kwargs = {
        "repo_id": "test/exists",
        "fps": DEFAULT_FPS,
        "features": SIMPLE_FEATURES,
        "root": occupied_root,
        "use_videos": False,
    }
    with pytest.raises(FileExistsError):
        LeRobotDatasetMetadata.create(**create_kwargs)
|
||||
|
||||
|
||||
def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory, info_factory):
    """Metadata built over pre-existing files exposes the counts and fps stored in info."""
    expected_info = info_factory(total_episodes=3, total_frames=150, total_tasks=1, use_videos=False)
    loaded_meta = lerobot_dataset_metadata_factory(root=tmp_path / "load_test", info=expected_info)

    assert loaded_meta.total_episodes == 3
    assert loaded_meta.total_frames == 150
    assert loaded_meta.fps == expected_info["fps"]
|
||||
|
||||
|
||||
# ── Property accessors ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_property_accessors_reflect_info(tmp_path):
    """Properties return values consistent with the arguments given to create()."""
    root = tmp_path / "props_ds"
    meta = LeRobotDatasetMetadata.create(
        repo_id="test/props",
        fps=DEFAULT_FPS,
        features=IMAGE_FEATURES,
        robot_type=DUMMY_ROBOT_TYPE,
        root=root,
        use_videos=False,
    )

    assert meta.fps == DEFAULT_FPS
    assert meta.robot_type == DUMMY_ROBOT_TYPE

    # Every shape is exposed as a tuple; only the values are of interest here.
    for shape in meta.shapes.values():
        assert isinstance(shape, tuple)

    # The image-dtype camera is reported as an image key.
    assert "observation.images.laptop" in meta.image_keys

    # camera_keys is exactly the union of image_keys and video_keys (not merely a superset).
    assert set(meta.image_keys + meta.video_keys) == set(meta.camera_keys)
|
||||
|
||||
|
||||
def test_data_path_is_formattable(tmp_path):
    """data_path is a template whose placeholders are all filled by chunk_index and file_index."""
    root = tmp_path / "fmt_ds"
    meta = LeRobotDatasetMetadata.create(
        repo_id="test/fmt", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
    )

    template = meta.data_path
    formatted = template.format(chunk_index=0, file_index=0)

    # Formatting must actually substitute something ...
    assert formatted != template
    # ... and must not leave any placeholder braces behind.
    assert "{" not in formatted
    assert "}" not in formatted
|
||||
|
||||
|
||||
# ── Task management ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_save_episode_tasks_creates_tasks_dataframe(tmp_path):
    """The first save_episode_tasks() call on fresh metadata populates .tasks."""
    task_name = "Pick up the cube"
    fresh_meta = LeRobotDatasetMetadata.create(
        repo_id="test/task",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "task_ds",
        use_videos=False,
    )
    # Precondition: nothing registered yet.
    assert fresh_meta.tasks is None

    fresh_meta.save_episode_tasks([task_name])

    registered = fresh_meta.tasks
    assert registered is not None
    assert len(registered) == 1
    assert task_name in registered.index
|
||||
|
||||
|
||||
def test_save_episode_tasks_is_additive(tmp_path):
    """Registering more tasks appends the new ones without re-indexing known ones."""
    task_meta = LeRobotDatasetMetadata.create(
        repo_id="test/add",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "additive_ds",
        use_videos=False,
    )

    task_meta.save_episode_tasks(["Task A"])
    index_before = task_meta.get_task_index("Task A")

    # Second call repeats "Task A" and introduces "Task B".
    task_meta.save_episode_tasks(["Task A", "Task B"])

    assert task_meta.get_task_index("Task A") == index_before
    assert task_meta.get_task_index("Task B") is not None
    assert len(task_meta.tasks) == 2
|
||||
|
||||
|
||||
def test_get_task_index_returns_none_for_unknown(tmp_path):
    """get_task_index() gives 0 for the first registered task and None for an unregistered one."""
    task_meta = LeRobotDatasetMetadata.create(
        repo_id="test/unknown",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "unknown_ds",
        use_videos=False,
    )
    task_meta.save_episode_tasks(["Known task"])

    known_index = task_meta.get_task_index("Known task")
    unknown_index = task_meta.get_task_index("Unknown task")

    assert known_index == 0
    assert unknown_index is None
|
||||
|
||||
|
||||
def test_save_episode_tasks_rejects_duplicates(tmp_path):
    """Passing the same task string twice in one call raises ValueError."""
    task_meta = LeRobotDatasetMetadata.create(
        repo_id="test/dup",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "dup_ds",
        use_videos=False,
    )
    duplicated_tasks = ["Same task", "Same task"]

    with pytest.raises(ValueError):
        task_meta.save_episode_tasks(duplicated_tasks)
|
||||
|
||||
|
||||
# ── Episode saving ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_save_episode_increments_counters(tmp_path):
    """Saving one 10-frame episode brings total_episodes to 1 and total_frames to 10."""
    counter_meta = LeRobotDatasetMetadata.create(
        repo_id="test/ep",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ep_ds",
        use_videos=False,
    )
    counter_meta.save_episode_tasks(["Task 1"])

    episode_kwargs = {
        "episode_index": 0,
        "episode_length": 10,
        "episode_tasks": ["Task 1"],
        "episode_stats": _make_dummy_stats(counter_meta.features),
        "episode_metadata": {},
    }
    counter_meta.save_episode(**episode_kwargs)

    assert counter_meta.total_episodes == 1
    assert counter_meta.total_frames == 10
|
||||
|
||||
|
||||
def test_save_episode_updates_stats(tmp_path):
    """Saving an episode populates .stats with an entry for each user-defined feature."""
    stats_meta = LeRobotDatasetMetadata.create(
        repo_id="test/stats",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "stats_ds",
        use_videos=False,
    )
    stats_meta.save_episode_tasks(["Task 1"])

    episode_kwargs = {
        "episode_index": 0,
        "episode_length": 5,
        "episode_tasks": ["Task 1"],
        "episode_stats": _make_dummy_stats(stats_meta.features),
        "episode_metadata": {},
    }
    stats_meta.save_episode(**episode_kwargs)

    aggregated = stats_meta.stats
    assert aggregated is not None
    # Only the user-defined features are required; extra keys are tolerated.
    for feature_name in SIMPLE_FEATURES:
        assert feature_name in stats_meta.stats
|
||||
|
||||
|
||||
# ── Chunk settings ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_update_chunk_settings_persists(tmp_path):
    """update_chunk_settings() changes the in-memory value and rewrites info.json."""
    root = tmp_path / "chunk_ds"
    meta = LeRobotDatasetMetadata.create(
        repo_id="test/chunk", fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root, use_videos=False
    )
    original = meta.get_chunk_settings()

    # Pick a target guaranteed to differ from the current setting, so the test
    # cannot pass when the update is a silent no-op.
    new_chunks_size = 500 if original["chunks_size"] != 500 else 501

    meta.update_chunk_settings(chunks_size=new_chunks_size)
    assert meta.chunks_size == new_chunks_size
    assert meta.chunks_size != original["chunks_size"]

    # The new value must also have been written to disk.
    with open(root / INFO_PATH) as f:
        info_on_disk = json.load(f)
    assert info_on_disk["chunks_size"] == new_chunks_size
|
||||
|
||||
|
||||
def test_update_chunk_settings_rejects_non_positive(tmp_path):
    """Zero or negative chunk settings are rejected with ValueError."""
    chunk_meta = LeRobotDatasetMetadata.create(
        repo_id="test/bad",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "bad_chunk",
        use_videos=False,
    )

    # One zero value and one negative value, each for a different setting.
    invalid_updates = ({"chunks_size": 0}, {"data_files_size_in_mb": -1})
    for bad_kwargs in invalid_updates:
        with pytest.raises(ValueError):
            chunk_meta.update_chunk_settings(**bad_kwargs)
|
||||
|
||||
|
||||
# ── Finalization ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_finalize_is_idempotent(tmp_path):
    """finalize() may be called repeatedly on metadata without raising."""
    fin_meta = LeRobotDatasetMetadata.create(
        repo_id="test/fin",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "fin_ds",
        use_videos=False,
    )

    # The second iteration is the interesting one: it must be a harmless repeat.
    for _ in range(2):
        fin_meta.finalize()
|
||||
|
||||
|
||||
def test_finalize_flushes_buffered_metadata(tmp_path):
    """Episodes buffered before finalize() end up in parquet files under meta/episodes."""
    dataset_root = tmp_path / "flush_ds"
    buffered_meta = LeRobotDatasetMetadata.create(
        repo_id="test/flush",
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=dataset_root,
        use_videos=False,
        # Buffer far larger than the number of episodes saved below.
        metadata_buffer_size=100,
    )
    buffered_meta.save_episode_tasks(["Task 1"])
    dummy_stats = _make_dummy_stats(buffered_meta.features)

    # Three episodes, well under the buffer size configured above.
    for episode_idx in range(3):
        buffered_meta.save_episode(
            episode_index=episode_idx,
            episode_length=5,
            episode_tasks=["Task 1"],
            episode_stats=dummy_stats,
            episode_metadata={},
        )

    # The parquet may or may not exist at this point; finalize() must write it.
    buffered_meta.finalize()

    episodes_dir = dataset_root / "meta" / "episodes"
    assert episodes_dir.exists()
    written_parquets = list(episodes_dir.rglob("*.parquet"))
    assert len(written_parquets) > 0
|
||||
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contract tests for DatasetReader."""
|
||||
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factory):
    """A reader pointed at a fully written dataset loads it and exposes hf_dataset."""
    written = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=2,
        total_frames=20,
        use_videos=False,
    )
    reader_kwargs = {
        "meta": written.meta,
        "root": written.root,
        "episodes": None,
        "tolerance_s": 1e-4,
        "video_backend": get_safe_default_codec(),
        "delta_timestamps": None,
        "image_transforms": None,
    }
    reader = DatasetReader(**reader_kwargs)

    assert reader.try_load() is True
    assert reader.hf_dataset is not None
|
||||
|
||||
|
||||
def test_try_load_returns_false_when_no_data(tmp_path):
    """With metadata only and no data parquets, try_load() is False and hf_dataset stays None."""
    from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata

    single_feature = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    meta_only = LeRobotDatasetMetadata.create(
        repo_id="test/meta_only",
        fps=30,
        features=single_feature,
        root=tmp_path / "meta_only",
        use_videos=False,
    )

    reader_kwargs = {
        "meta": meta_only,
        "root": meta_only.root,
        "episodes": None,
        "tolerance_s": 1e-4,
        "video_backend": get_safe_default_codec(),
        "delta_timestamps": None,
        "image_transforms": None,
    }
    reader = DatasetReader(**reader_kwargs)

    assert reader.try_load() is False
    assert reader.hf_dataset is None
|
||||
|
||||
|
||||
# ── Counts ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_num_frames_without_filter(tmp_path, lerobot_dataset_factory):
    """Without an episode filter, the reader's num_frames matches meta.total_frames."""
    factory_kwargs = {
        "root": tmp_path / "ds",
        "total_episodes": 3,
        "total_frames": 60,
        "use_videos": False,
    }
    unfiltered = lerobot_dataset_factory(**factory_kwargs)

    assert unfiltered.reader.num_frames == unfiltered.meta.total_frames
|
||||
|
||||
|
||||
def test_num_episodes_without_filter(tmp_path, lerobot_dataset_factory):
    """Without an episode filter, the reader's num_episodes matches meta.total_episodes."""
    factory_kwargs = {
        "root": tmp_path / "ds",
        "total_episodes": 3,
        "total_frames": 60,
        "use_videos": False,
    }
    unfiltered = lerobot_dataset_factory(**factory_kwargs)

    assert unfiltered.reader.num_episodes == unfiltered.meta.total_episodes
|
||||
|
||||
|
||||
def test_num_frames_with_episode_filter(tmp_path, lerobot_dataset_factory):
    """Selecting episodes [0, 2] out of 5 yields two episodes and no more frames than the total."""
    selected_episodes = [0, 2]
    filtered = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=5,
        total_frames=100,
        episodes=selected_episodes,
        use_videos=False,
    )

    # The filtered frame count can never exceed the dataset-wide total.
    assert filtered.reader.num_frames <= filtered.meta.total_frames
    assert filtered.reader.num_episodes == 2
|
||||
|
||||
|
||||
# ── get_item ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_item_returns_expected_keys(tmp_path, lerobot_dataset_factory):
    """The first item carries every bookkeeping key a frame is expected to have."""
    required_keys = ["index", "episode_index", "frame_index", "timestamp", "task_index", "task"]
    single_episode = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=1,
        total_frames=10,
        use_videos=False,
    )

    first_item = single_episode.reader.get_item(0)

    # Checked one at a time so a failure names the missing key.
    for key in required_keys:
        assert key in first_item, f"Missing key: {key}"
|
||||
|
||||
|
||||
def test_get_item_values_are_correct(tmp_path, lerobot_dataset_factory):
    """The very first item has global index 0 and belongs to episode 0."""
    two_episodes = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=2,
        total_frames=20,
        use_videos=False,
    )

    first_item = two_episodes.reader.get_item(0)

    # Both fields are read via .item(), i.e. as scalars.
    assert first_item["index"].item() == 0
    assert first_item["episode_index"].item() == 0
|
||||
|
||||
|
||||
# ── Transforms ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_image_transforms_are_applied(tmp_path, lerobot_dataset_factory):
    """Indexing a dataset built with image_transforms invokes the transform on camera data."""
    call_count = 0

    def sentinel_transform(img):
        # Identity transform that only records that it was invoked.
        nonlocal call_count
        call_count += 1
        return img

    dataset = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=1,
        total_frames=5,
        use_videos=False,
        image_transforms=sentinel_transform,
    )

    # Only the side effect on call_count matters; the fetched item is discarded.
    _ = dataset[0]

    # With no camera features there is nothing to transform, so the check only
    # applies when the factory produced at least one camera key.
    if dataset.meta.camera_keys:
        assert call_count >= 1
|
||||
|
||||
|
||||
# ── File paths ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_episodes_file_paths_returns_data_paths(tmp_path, lerobot_dataset_factory):
    """get_episodes_file_paths() returns paths that include entries under a data directory."""
    from pathlib import PurePath

    dataset = lerobot_dataset_factory(
        root=tmp_path / "ds", total_episodes=2, total_frames=20, use_videos=False
    )
    paths = dataset.reader.get_episodes_file_paths()

    assert len(paths) > 0
    # Compare path components rather than searching for "data/", so the check
    # does not depend on the platform's path separator.
    assert any("data" in PurePath(p).parts for p in paths)
|
||||
|
||||
|
||||
def test_get_episodes_file_paths_includes_video_paths(tmp_path, lerobot_dataset_factory):
    """If the dataset reports video keys, at least one returned path mentions "video"."""
    video_dataset = lerobot_dataset_factory(
        root=tmp_path / "ds",
        total_episodes=2,
        total_frames=20,
        use_videos=True,
    )

    # Nothing to verify when the factory produced no video keys.
    if len(video_dataset.meta.video_keys) == 0:
        return

    file_paths = video_dataset.reader.get_episodes_file_paths()
    # Case-insensitive match on the path text.
    assert any("video" in str(path).lower() for path in file_paths)
|
||||
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contract tests for DatasetWriter."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.datasets.dataset_writer import _encode_video_worker
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
|
||||
# Minimal feature set for writer tests: two independent 6-dimensional float32 vectors.
SIMPLE_FEATURES = {
    name: {"dtype": "float32", "shape": (6,), "names": None} for name in ("state", "action")
}
|
||||
|
||||
|
||||
def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
"""Build a valid frame dict for the given features."""
|
||||
frame = {"task": task}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
frame[key] = np.random.randint(0, 256, size=ft["shape"], dtype=np.uint8)
|
||||
elif ft["dtype"] in ("float32", "float64"):
|
||||
frame[key] = torch.randn(ft["shape"])
|
||||
elif ft["dtype"] == "int64":
|
||||
frame[key] = torch.zeros(ft["shape"], dtype=torch.int64)
|
||||
return frame
|
||||
|
||||
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def _run_encode_worker_capturing_kwargs(tmp_path, **worker_kwargs) -> dict:
    """Run _encode_video_worker with encode_video_frames mocked and return the kwargs it received.

    A single dummy frame image is written for episode 0 of one camera key before
    the worker runs. The mock creates an empty file at the requested video path
    instead of encoding anything.

    Args:
        tmp_path: Directory used as the dataset root.
        **worker_kwargs: Extra keyword arguments forwarded to _encode_video_worker.

    Returns:
        The keyword arguments the mocked encode_video_frames was called with.
    """
    video_key = "observation.images.laptop"
    fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
    img_dir = tmp_path / Path(fpath).parent
    img_dir.mkdir(parents=True, exist_ok=True)
    Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")

    captured_kwargs = {}

    def mock_encode(imgs_dir, video_path, fps, **kwargs):
        captured_kwargs.update(kwargs)
        # Leave an (empty) output file behind in place of a real encode.
        Path(video_path).parent.mkdir(parents=True, exist_ok=True)
        Path(video_path).touch()

    with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
        _encode_video_worker(video_key, 0, tmp_path, fps=30, **worker_kwargs)

    return captured_kwargs


def test_encode_video_worker_forwards_vcodec(tmp_path):
    """_encode_video_worker correctly forwards the vcodec parameter."""
    captured_kwargs = _run_encode_worker_capturing_kwargs(tmp_path, vcodec="h264")

    assert captured_kwargs["vcodec"] == "h264"


def test_encode_video_worker_default_vcodec(tmp_path):
    """_encode_video_worker uses libsvtav1 as the default codec."""
    captured_kwargs = _run_encode_worker_capturing_kwargs(tmp_path)

    assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
# ── add_frame contracts ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_add_frame_increments_buffer_size(tmp_path):
    """episode_buffer['size'] starts at 0 and grows by one per add_frame() call."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )
    assert recording.writer.episode_buffer["size"] == 0

    # After the k-th frame, the buffer must report exactly k frames.
    for expected_size in (1, 2):
        recording.add_frame(_make_frame(SIMPLE_FEATURES))
        assert recording.writer.episode_buffer["size"] == expected_size
|
||||
|
||||
|
||||
def test_add_frame_rejects_missing_feature(tmp_path):
    """A frame lacking a declared feature is rejected with a 'Missing features' ValueError."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )
    # 'action' is deliberately left out of the frame.
    incomplete_frame = {"task": "Dummy task", "state": torch.randn(6)}

    with pytest.raises(ValueError, match="Missing features"):
        recording.add_frame(incomplete_frame)
|
||||
|
||||
|
||||
# ── save_episode contracts ───────────────────────────────────────────
|
||||
|
||||
|
||||
def test_save_episode_writes_parquet(tmp_path):
    """Saving an episode leaves at least one .parquet file under the data/ directory."""
    dataset_root = tmp_path / "ds"
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=dataset_root,
    )

    frames_to_record = 3
    for _ in range(frames_to_record):
        recording.add_frame(_make_frame(SIMPLE_FEATURES))
    recording.save_episode()

    # Search recursively, without assuming a directory layout below data/.
    data_parquets = list((dataset_root / "data").rglob("*.parquet"))
    assert len(data_parquets) > 0
|
||||
|
||||
|
||||
def test_save_episode_updates_counters(tmp_path):
    """One saved 5-frame episode is reflected in meta.total_episodes and meta.total_frames."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    frames_to_record = 5
    for _ in range(frames_to_record):
        recording.add_frame(_make_frame(SIMPLE_FEATURES))
    recording.save_episode()

    assert recording.meta.total_episodes == 1
    assert recording.meta.total_frames == 5
|
||||
|
||||
|
||||
def test_save_episode_resets_buffer(tmp_path):
    """Once an episode is saved, the writer's episode buffer reports size 0 again."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    frames_to_record = 3
    for _ in range(frames_to_record):
        recording.add_frame(_make_frame(SIMPLE_FEATURES))
    recording.save_episode()

    buffer_size = recording.writer.episode_buffer["size"]
    assert buffer_size == 0
|
||||
|
||||
|
||||
def test_save_multiple_episodes(tmp_path):
    """Three episodes of 2, 3 and 4 frames add up in the metadata totals."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Different lengths per episode, so a per-episode miscount would show up.
    episode_lengths = [2, 3, 4]
    for n_frames in episode_lengths:
        for _ in range(n_frames):
            recording.add_frame(_make_frame(SIMPLE_FEATURES))
        recording.save_episode()

    assert recording.meta.total_episodes == 3
    assert recording.meta.total_frames == sum(episode_lengths)
|
||||
|
||||
|
||||
# ── clear / lifecycle ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_clear_resets_buffer(tmp_path):
    """clear_episode_buffer() discards pending frames, bringing the buffer size back to 0."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Put one frame in the buffer so there is something to clear.
    recording.add_frame(_make_frame(SIMPLE_FEATURES))
    assert recording.writer.episode_buffer["size"] == 1

    recording.clear_episode_buffer()
    assert recording.writer.episode_buffer["size"] == 0
|
||||
|
||||
|
||||
def test_finalize_is_idempotent(tmp_path):
    """finalize() can be called twice on a dataset with a saved episode without raising."""
    recording = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    frames_to_record = 3
    for _ in range(frames_to_record):
        recording.add_frame(_make_frame(SIMPLE_FEATURES))
    recording.save_episode()

    # The second iteration is the one under test: it must be a harmless repeat.
    for _ in range(2):
        recording.finalize()
|
||||
|
||||
|
||||
def test_finalize_then_read_roundtrip(tmp_path):
    """Write known values, finalize, then read them back through the same dataset object.

    The dataset is not re-opened from disk here; the read goes through the
    instance that did the writing.
    """
    root = tmp_path / "roundtrip"
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=features, root=root)

    # Record known, per-frame distinct values so any reordering would be detected.
    known_states = [torch.tensor([float(i), float(i * 10)]) for i in range(5)]
    for state in known_states:
        dataset.add_frame({"task": "Test task", "state": state})
    dataset.save_episode()
    dataset.finalize()

    # Every recorded frame must be readable, in order, with matching values.
    assert len(dataset) == len(known_states)
    for i, expected_state in enumerate(known_states):
        item = dataset[i]
        assert torch.allclose(item["state"], expected_state, atol=1e-5)
|
||||
+52
-118
@@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
objects have the same sets of attributes defined.
|
||||
objects have the same sets of facade-level attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
@@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||
|
||||
# Facade-level attributes should match between __init__ and create()
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
@@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert dataset[0]["task"] == "Dummy task"
|
||||
@@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2])
|
||||
|
||||
@@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||
|
||||
@@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||
|
||||
@@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||
|
||||
@@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||
|
||||
@@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].ndim == 0
|
||||
|
||||
@@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["caption"] == "Dummy caption"
|
||||
|
||||
@@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset):
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset):
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_img.save_episode()
|
||||
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||
img_dir = ds_img.writer._get_image_file_dir(0, image_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
|
||||
|
||||
@@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
}
|
||||
|
||||
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||
ds_vid.batch_encoding_size = 1
|
||||
ds_vid.writer._batch_encoding_size = 1
|
||||
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_vid.save_episode()
|
||||
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||
vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key)
|
||||
assert not vid_img_dir.exists(), (
|
||||
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||
)
|
||||
@@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
}
|
||||
)
|
||||
ds_mixed.save_episode()
|
||||
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||
img_dir = ds_mixed.writer._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
assert vid_img_dir.exists(), (
|
||||
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||
@@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
# Test hf_dataset is None
|
||||
dataset.hf_dataset = None
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
dataset.reader.hf_dataset = None
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test hf_dataset is empty
|
||||
import datasets
|
||||
|
||||
empty_features = get_hf_features_from_features(dataset.features)
|
||||
dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
dataset.reader.hf_dataset = datasets.Dataset.from_dict(
|
||||
{key: [] for key in empty_features}, features=empty_features
|
||||
)
|
||||
dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Restore the original dataset for remaining tests
|
||||
dataset.hf_dataset = dataset.load_hf_dataset()
|
||||
dataset.reader.hf_dataset = dataset.reader._load_hf_dataset()
|
||||
|
||||
# Test all episodes requested (self.episodes = None) and all are available
|
||||
dataset.episodes = None
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
dataset.reader.episodes = None
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test specific episodes requested that are all available
|
||||
dataset.episodes = [0, 2, 4]
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
dataset.reader.episodes = [0, 2, 4]
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test request episodes that don't exist in the cached dataset
|
||||
# Create a dataset with only episodes 0, 1, 2
|
||||
@@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
# Request episodes that include non-existent ones
|
||||
limited_dataset.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset._check_cached_episodes_sufficient() is False
|
||||
limited_dataset.reader.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
|
||||
# First create the full dataset structure
|
||||
@@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
|
||||
filtered_data[key] = filtered_values
|
||||
|
||||
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict(
|
||||
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
|
||||
)
|
||||
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
# Test requesting all episodes when only some are cached
|
||||
sparse_dataset.episodes = None
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
sparse_dataset.reader.episodes = None
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test requesting only the available episodes
|
||||
sparse_dataset.episodes = [0, 2, 4]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is True
|
||||
sparse_dataset.reader.episodes = [0, 2, 4]
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test requesting a mix of available and unavailable episodes
|
||||
sparse_dataset.episodes = [0, 1, 2]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
sparse_dataset.reader.episodes = [0, 1, 2]
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
|
||||
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
|
||||
del dataset_verify
|
||||
|
||||
# Phase 3: Resume recording - add more episodes
|
||||
dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
|
||||
assert dataset_resumed.meta.total_episodes == initial_episodes
|
||||
assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
|
||||
assert dataset_resumed.latest_episode is None # Not recording yet
|
||||
assert dataset_resumed.writer is None
|
||||
assert dataset_resumed.meta.writer is None
|
||||
assert dataset_resumed.writer._latest_episode is None # Not recording yet
|
||||
assert dataset_resumed.writer._pq_writer is None
|
||||
assert dataset_resumed.meta._pq_writer is None
|
||||
|
||||
additional_episodes = 2
|
||||
for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
|
||||
@@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
|
||||
|
||||
assert dataset._current_file_start_frame is None
|
||||
assert dataset.writer._current_file_start_frame is None
|
||||
|
||||
frames_per_episode = 10
|
||||
for _ in range(frames_per_episode):
|
||||
@@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == frames_per_episode
|
||||
|
||||
@@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert dataset.meta.total_frames == 2 * frames_per_episode
|
||||
|
||||
ep1_chunk = dataset.latest_episode["data/chunk_index"]
|
||||
ep1_file = dataset.latest_episode["data/file_index"]
|
||||
ep1_chunk = dataset.writer._latest_episode["data/chunk_index"]
|
||||
ep1_file = dataset.writer._latest_episode["data/file_index"]
|
||||
assert ep1_chunk == 0
|
||||
assert ep1_file == 0
|
||||
|
||||
@@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 3
|
||||
assert dataset.meta.total_frames == 3 * frames_per_episode
|
||||
|
||||
ep2_chunk = dataset.latest_episode["data/chunk_index"]
|
||||
ep2_file = dataset.latest_episode["data/file_index"]
|
||||
ep2_chunk = dataset.writer._latest_episode["data/chunk_index"]
|
||||
ep2_file = dataset.writer._latest_episode["data/file_index"]
|
||||
assert ep2_chunk == 0
|
||||
assert ep2_file == 0
|
||||
|
||||
@@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
assert frame["episode_index"].item() == expected_ep
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
    """An explicit vcodec given to _encode_video_worker must reach encode_video_frames."""
    from unittest.mock import patch

    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH

    video_key = "observation.images.laptop"
    episode_index = 0
    frame_index = 0

    # Build the frame directory from the DEFAULT_IMAGE_PATH template.
    first_frame_path = Path(
        DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=frame_index)
    )
    img_dir = tmp_path / first_frame_path.parent
    img_dir.mkdir(parents=True, exist_ok=True)

    # Single placeholder frame on disk.
    Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")

    # Keyword arguments seen by the patched encoder end up here.
    captured_kwargs = {}

    # Parameter names are kept as-is: the patched call may pass them by keyword.
    def fake_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
        captured_kwargs.update(kwargs)
        # Create a dummy output file so the worker doesn't fail.
        output_file = Path(video_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        output_file.touch()

    patch_target = "lerobot.datasets.lerobot_dataset.encode_video_frames"
    with patch(patch_target, side_effect=fake_encode_video_frames):
        # Explicit h264 codec.
        _encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")

    assert "vcodec" in captured_kwargs
    assert captured_kwargs["vcodec"] == "h264"
|
||||
|
||||
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
    """Without an explicit vcodec, _encode_video_worker must pass "libsvtav1" to the encoder."""
    from unittest.mock import patch

    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH

    video_key = "observation.images.laptop"
    episode_index = 0
    frame_index = 0

    # Build the frame directory from the DEFAULT_IMAGE_PATH template.
    first_frame_path = Path(
        DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=frame_index)
    )
    img_dir = tmp_path / first_frame_path.parent
    img_dir.mkdir(parents=True, exist_ok=True)

    # Single placeholder frame on disk.
    Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")

    # Keyword arguments seen by the patched encoder end up here.
    captured_kwargs = {}

    # Parameter names are kept as-is: the patched call may pass them by keyword.
    def fake_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
        captured_kwargs.update(kwargs)
        # Create a dummy output file so the worker doesn't fail.
        output_file = Path(video_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        output_file.touch()

    patch_target = "lerobot.datasets.lerobot_dataset.encode_video_frames"
    with patch(patch_target, side_effect=fake_encode_video_frames):
        # No vcodec argument: the worker's default applies.
        _encode_video_worker(video_key, episode_index, tmp_path, fps=30)

    assert "vcodec" in captured_kwargs
    assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
|
||||
@@ -352,10 +352,14 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_safe_stop_image_writer_decorator():
|
||||
class MockDataset:
|
||||
class MockWriter:
|
||||
def __init__(self):
|
||||
self.image_writer = MagicMock(spec=AsyncImageWriter)
|
||||
|
||||
class MockDataset:
|
||||
def __init__(self):
|
||||
self.writer = MockWriter()
|
||||
|
||||
@safe_stop_image_writer
|
||||
def function_that_raises_exception(dataset=None):
|
||||
raise Exception("Test exception")
|
||||
@@ -366,7 +370,7 @@ def test_safe_stop_image_writer_decorator():
|
||||
function_that_raises_exception(dataset=dataset)
|
||||
|
||||
assert str(exc_info.value) == "Test exception"
|
||||
dataset.image_writer.stop.assert_called_once()
|
||||
dataset.writer.image_writer.stop.assert_called_once()
|
||||
|
||||
|
||||
def test_main_process_time(tmp_path, img_tensor_factory):
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contract tests for the LeRobotDataset facade.
|
||||
|
||||
Tests focus on mode contracts (read-only, write-only, resume), guards,
|
||||
property delegation, and the full create-record-finalize-read lifecycle.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.datasets.dataset_writer import DatasetWriter
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
|
||||
# Minimal feature spec: a single float32 "state" entry of shape (2,) with no names.
SIMPLE_FEATURES = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
|
||||
|
||||
def _make_frame(task: str = "Dummy task") -> dict:
|
||||
return {"task": task, "state": torch.randn(2)}
|
||||
|
||||
|
||||
# ── Read-only mode (via __init__) ────────────────────────────────────
|
||||
|
||||
|
||||
def test_init_creates_reader_no_writer(tmp_path, lerobot_dataset_factory):
    """A factory-built dataset exposes a DatasetReader and has ``writer`` set to None."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=10,
        use_videos=False,
    )

    reader = dataset.reader
    assert isinstance(reader, DatasetReader)

    writer = dataset.writer
    assert writer is None
|
||||
|
||||
|
||||
def test_init_loads_data(tmp_path, lerobot_dataset_factory):
    """A factory-built dataset reports a strictly positive length."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=10,
        use_videos=False,
    )

    num_items = len(dataset)
    assert num_items > 0
|
||||
|
||||
|
||||
def test_getitem_works_in_read_mode(tmp_path, lerobot_dataset_factory):
    """Indexing a factory-built dataset yields a dict containing "index" and "task"."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=10,
        use_videos=False,
    )

    first_item = dataset[0]
    assert isinstance(first_item, dict)
    for required_key in ("index", "task"):
        assert required_key in first_item
|
||||
|
||||
|
||||
def test_len_matches_num_frames(tmp_path, lerobot_dataset_factory):
    """``len(dataset)`` and ``dataset.num_frames`` report the same value."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=2,
        total_frames=30,
        use_videos=False,
    )

    reported_len = len(dataset)
    assert reported_len == dataset.num_frames
|
||||
|
||||
|
||||
# ── Write-only mode (via create()) ──────────────────────────────────
|
||||
|
||||
|
||||
def test_create_sets_writer_no_reader(tmp_path):
    """A dataset from create() carries a DatasetWriter and has ``reader`` set to None."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    writer = dataset.writer
    assert isinstance(writer, DatasetWriter)

    reader = dataset.reader
    assert reader is None
|
||||
|
||||
|
||||
def test_create_initial_counts_zero(tmp_path):
    """A freshly created dataset reports zero episodes and zero frames."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    episode_count = dataset.num_episodes
    assert episode_count == 0

    frame_count = dataset.num_frames
    assert frame_count == 0
|
||||
|
||||
|
||||
def test_add_frame_works_in_write_mode(tmp_path):
    """add_frame() on a dataset obtained from create() completes without raising."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # The call itself is the check: any exception fails the test.
    frame = _make_frame()
    dataset.add_frame(frame)
|
||||
|
||||
|
||||
# ── Resume mode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_resume_creates_writer(tmp_path):
    """resume() on a finalized dataset returns an object whose writer is a DatasetWriter."""
    root = tmp_path / "resume_ds"

    # First session: one three-frame episode, then finalize.
    first_session = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=root,
    )
    for _frame_idx in range(3):
        first_session.add_frame(_make_frame())
    first_session.save_episode()
    first_session.finalize()

    # Second session: reopen the same root through resume().
    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
    resumed_writer = resumed.writer
    assert isinstance(resumed_writer, DatasetWriter)
|
||||
|
||||
|
||||
def test_resume_preserves_episode_count(tmp_path):
    """The episode saved before finalize() is still counted after resume()."""
    root = tmp_path / "resume_ds"

    # First session: one three-frame episode, then finalize.
    first_session = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=root,
    )
    for _frame_idx in range(3):
        first_session.add_frame(_make_frame())
    first_session.save_episode()
    first_session.finalize()

    # Second session: the metadata must report the one saved episode.
    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
    episodes_after_resume = resumed.meta.total_episodes
    assert episodes_after_resume == 1
|
||||
|
||||
|
||||
def test_resume_can_add_more_episodes(tmp_path):
    """After resume(), a new episode can be recorded on top of the existing one.

    The resumed dataset is finalized in a ``finally`` clause so that the
    recording session opened by resume() is closed even when an assertion or
    a recording call fails.
    """
    root = tmp_path / "resume_ds"

    # First session: one three-frame episode, then finalize.
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS, features=SIMPLE_FEATURES, root=root
    )
    for _ in range(3):
        dataset.add_frame(_make_frame())
    dataset.save_episode()
    dataset.finalize()

    # Second session: record one additional two-frame episode.
    resumed = LeRobotDataset.resume(repo_id=DUMMY_REPO_ID, root=root)
    try:
        for _ in range(2):
            resumed.add_frame(_make_frame())
        resumed.save_episode()

        # One episode from the first session plus the one just saved.
        assert resumed.meta.total_episodes == 2
    finally:
        # Close the resumed recording session instead of leaving it open.
        resumed.finalize()
|
||||
|
||||
|
||||
# ── Writer guard ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_add_frame_raises_without_writer(tmp_path, lerobot_dataset_factory):
    """add_frame() on a factory-built dataset raises RuntimeError matching "read-only"."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=5,
        use_videos=False,
    )

    expected_error = pytest.raises(RuntimeError, match="read-only")
    with expected_error:
        dataset.add_frame(_make_frame())
|
||||
|
||||
|
||||
def test_save_episode_raises_without_writer(tmp_path, lerobot_dataset_factory):
    """save_episode() on a factory-built dataset raises RuntimeError matching "read-only"."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=5,
        use_videos=False,
    )

    expected_error = pytest.raises(RuntimeError, match="read-only")
    with expected_error:
        dataset.save_episode()
|
||||
|
||||
|
||||
def test_clear_episode_buffer_raises_without_writer(tmp_path, lerobot_dataset_factory):
    """clear_episode_buffer() on a factory-built dataset raises RuntimeError matching "read-only"."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=5,
        use_videos=False,
    )

    expected_error = pytest.raises(RuntimeError, match="read-only")
    with expected_error:
        dataset.clear_episode_buffer()
|
||||
|
||||
|
||||
# ── Reader guard ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_getitem_raises_before_finalize(tmp_path):
    """Indexing a dataset that was recorded but not finalized raises RuntimeError matching "finalize"."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Save one three-frame episode but deliberately skip finalize().
    for _frame_idx in range(3):
        dataset.add_frame(_make_frame())
    dataset.save_episode()

    with pytest.raises(RuntimeError, match="finalize"):
        # Bound to a throwaway name only so the lookup is not a bare expression.
        _ = dataset[0]
|
||||
|
||||
|
||||
def test_getitem_works_after_finalize(tmp_path):
    """Once finalize() has run, indexing returns an item with "state" and "task"."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Save one three-frame episode, then finalize.
    for _frame_idx in range(3):
        dataset.add_frame(_make_frame())
    dataset.save_episode()
    dataset.finalize()

    first_item = dataset[0]
    for required_key in ("state", "task"):
        assert required_key in first_item
|
||||
|
||||
|
||||
# ── Property delegation ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_fps_delegates_to_meta(tmp_path, lerobot_dataset_factory):
    """``dataset.fps`` equals ``dataset.meta.fps``."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=5,
        use_videos=False,
    )

    facade_fps = dataset.fps
    assert facade_fps == dataset.meta.fps
|
||||
|
||||
|
||||
def test_features_delegates_to_meta(tmp_path, lerobot_dataset_factory):
    """``dataset.features`` is the very same object as ``dataset.meta.features``."""
    ds_root = tmp_path / "ds"
    dataset = lerobot_dataset_factory(
        root=ds_root,
        total_episodes=1,
        total_frames=5,
        use_videos=False,
    )

    # Identity, not equality, is what is checked here.
    facade_features = dataset.features
    assert facade_features is dataset.meta.features
|
||||
|
||||
|
||||
def test_num_frames_uses_meta_in_write_mode(tmp_path):
    """With no reader attached, ``num_frames`` equals ``meta.total_frames``."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Precondition: create() left the dataset without a reader.
    assert dataset.reader is None

    frames_from_facade = dataset.num_frames
    assert frames_from_facade == dataset.meta.total_frames
|
||||
|
||||
|
||||
# ── Lifecycle ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_finalize_is_idempotent(tmp_path):
    """Two consecutive finalize() calls on a freshly created dataset do not raise."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # No frames are recorded; only the repeated call is exercised.
    for _call in range(2):
        dataset.finalize()
|
||||
|
||||
|
||||
def test_has_pending_frames_lifecycle(tmp_path):
    """has_pending_frames() is False when created, True after add_frame(), False after save_episode()."""
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=tmp_path / "ds",
    )

    # Nothing recorded yet.
    pending_at_start = dataset.has_pending_frames()
    assert pending_at_start is False

    # One frame added but not saved.
    dataset.add_frame(_make_frame())
    pending_after_add = dataset.has_pending_frames()
    assert pending_after_add is True

    # The episode is saved.
    dataset.save_episode()
    pending_after_save = dataset.has_pending_frames()
    assert pending_after_save is False
|
||||
|
||||
|
||||
def test_create_record_finalize_read_roundtrip(tmp_path):
    """Record two episodes with known states, finalize, reopen from disk and compare every frame."""
    root = tmp_path / "roundtrip"
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
        features=SIMPLE_FEATURES,
        root=root,
    )

    # Episode 0: three frames with known values.
    ep0_states = [torch.tensor([float(i), float(i * 2)]) for i in range(3)]
    for state in ep0_states:
        dataset.add_frame({"task": "Task A", "state": state})
    dataset.save_episode()

    # Episode 1: two frames.
    ep1_states = [torch.tensor([float(i + 100), float(i + 200)]) for i in range(2)]
    for state in ep1_states:
        dataset.add_frame({"task": "Task B", "state": state})
    dataset.save_episode()

    dataset.finalize()

    # Open the same root again through the plain constructor.
    reopened = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=root)
    assert len(reopened) == 5
    assert reopened.num_episodes == 2

    # Frames are checked in recording order: episode 0 at indices 0-2, episode 1 at 3-4.
    expected_frames = [(state, 0) for state in ep0_states] + [(state, 1) for state in ep1_states]
    for global_idx, (expected_state, expected_episode) in enumerate(expected_frames):
        item = reopened[global_idx]
        assert torch.allclose(item["state"], expected_state, atol=1e-5)
        assert item["episode_index"].item() == expected_episode
|
||||
@@ -534,7 +534,7 @@ class TestStreamingEncoderIntegration:
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is not None
|
||||
assert dataset.writer._streaming_encoder is not None
|
||||
|
||||
num_frames = 20
|
||||
for _ in range(num_frames):
|
||||
@@ -580,7 +580,7 @@ class TestStreamingEncoderIntegration:
|
||||
streaming_encoding=False,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is None
|
||||
assert dataset.writer._streaming_encoder is None
|
||||
|
||||
num_frames = 5
|
||||
for _ in range(num_frames):
|
||||
|
||||
Reference in New Issue
Block a user