mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)
* refactor(dataset): split reader and writer * chore(dataset): remove proxys * refactor(dataset): better reader & writer encapsulation * refactor(datasets): clean API + reduce leaky implementations * refactor(dataset): API cleaning for writer, reader and meta * refactor(dataset): expose writer & reader + other minor improvements * refactor(dataset): improve teardown routine * refactor(dataset): add hf_dataset property at the facade level * chore(dataset): add init for datasset module * docs(dataset): add docstrings for public API of the dataset classes * tests(dataset): add tests for new classes * fix(dataset): remove circular dependecy
This commit is contained in:
+52
-118
@@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
objects have the same sets of attributes defined.
|
||||
objects have the same sets of facade-level attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
@@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||
|
||||
# Facade-level attributes should match between __init__ and create()
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
@@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert dataset[0]["task"] == "Dummy task"
|
||||
@@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2])
|
||||
|
||||
@@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||
|
||||
@@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||
|
||||
@@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||
|
||||
@@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||
|
||||
@@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["state"].ndim == 0
|
||||
|
||||
@@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["caption"] == "Dummy caption"
|
||||
|
||||
@@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset):
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset):
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
@@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_img.save_episode()
|
||||
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||
img_dir = ds_img.writer._get_image_file_dir(0, image_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
|
||||
|
||||
@@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
}
|
||||
|
||||
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||
ds_vid.batch_encoding_size = 1
|
||||
ds_vid.writer._batch_encoding_size = 1
|
||||
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_vid.save_episode()
|
||||
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||
vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key)
|
||||
assert not vid_img_dir.exists(), (
|
||||
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||
)
|
||||
@@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
}
|
||||
)
|
||||
ds_mixed.save_episode()
|
||||
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||
img_dir = ds_mixed.writer._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
assert vid_img_dir.exists(), (
|
||||
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||
@@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
# Test hf_dataset is None
|
||||
dataset.hf_dataset = None
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
dataset.reader.hf_dataset = None
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test hf_dataset is empty
|
||||
import datasets
|
||||
|
||||
empty_features = get_hf_features_from_features(dataset.features)
|
||||
dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
dataset.reader.hf_dataset = datasets.Dataset.from_dict(
|
||||
{key: [] for key in empty_features}, features=empty_features
|
||||
)
|
||||
dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Restore the original dataset for remaining tests
|
||||
dataset.hf_dataset = dataset.load_hf_dataset()
|
||||
dataset.reader.hf_dataset = dataset.reader._load_hf_dataset()
|
||||
|
||||
# Test all episodes requested (self.episodes = None) and all are available
|
||||
dataset.episodes = None
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
dataset.reader.episodes = None
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test specific episodes requested that are all available
|
||||
dataset.episodes = [0, 2, 4]
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
dataset.reader.episodes = [0, 2, 4]
|
||||
assert dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test request episodes that don't exist in the cached dataset
|
||||
# Create a dataset with only episodes 0, 1, 2
|
||||
@@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
# Request episodes that include non-existent ones
|
||||
limited_dataset.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset._check_cached_episodes_sufficient() is False
|
||||
limited_dataset.reader.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
|
||||
# First create the full dataset structure
|
||||
@@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
|
||||
filtered_data[key] = filtered_values
|
||||
|
||||
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict(
|
||||
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
|
||||
)
|
||||
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
# Test requesting all episodes when only some are cached
|
||||
sparse_dataset.episodes = None
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
sparse_dataset.reader.episodes = None
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test requesting only the available episodes
|
||||
sparse_dataset.episodes = [0, 2, 4]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is True
|
||||
sparse_dataset.reader.episodes = [0, 2, 4]
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test requesting a mix of available and unavailable episodes
|
||||
sparse_dataset.episodes = [0, 1, 2]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
sparse_dataset.reader.episodes = [0, 1, 2]
|
||||
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
|
||||
|
||||
|
||||
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
|
||||
del dataset_verify
|
||||
|
||||
# Phase 3: Resume recording - add more episodes
|
||||
dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
|
||||
assert dataset_resumed.meta.total_episodes == initial_episodes
|
||||
assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
|
||||
assert dataset_resumed.latest_episode is None # Not recording yet
|
||||
assert dataset_resumed.writer is None
|
||||
assert dataset_resumed.meta.writer is None
|
||||
assert dataset_resumed.writer._latest_episode is None # Not recording yet
|
||||
assert dataset_resumed.writer._pq_writer is None
|
||||
assert dataset_resumed.meta._pq_writer is None
|
||||
|
||||
additional_episodes = 2
|
||||
for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
|
||||
@@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
|
||||
|
||||
assert dataset._current_file_start_frame is None
|
||||
assert dataset.writer._current_file_start_frame is None
|
||||
|
||||
frames_per_episode = 10
|
||||
for _ in range(frames_per_episode):
|
||||
@@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == frames_per_episode
|
||||
|
||||
@@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert dataset.meta.total_frames == 2 * frames_per_episode
|
||||
|
||||
ep1_chunk = dataset.latest_episode["data/chunk_index"]
|
||||
ep1_file = dataset.latest_episode["data/file_index"]
|
||||
ep1_chunk = dataset.writer._latest_episode["data/chunk_index"]
|
||||
ep1_file = dataset.writer._latest_episode["data/file_index"]
|
||||
assert ep1_chunk == 0
|
||||
assert ep1_file == 0
|
||||
|
||||
@@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset._current_file_start_frame == 0
|
||||
assert dataset.writer._current_file_start_frame == 0
|
||||
assert dataset.meta.total_episodes == 3
|
||||
assert dataset.meta.total_frames == 3 * frames_per_episode
|
||||
|
||||
ep2_chunk = dataset.latest_episode["data/chunk_index"]
|
||||
ep2_file = dataset.latest_episode["data/file_index"]
|
||||
ep2_chunk = dataset.writer._latest_episode["data/chunk_index"]
|
||||
ep2_file = dataset.writer._latest_episode["data/file_index"]
|
||||
assert ep2_chunk == 0
|
||||
assert ep2_file == 0
|
||||
|
||||
@@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
assert frame["episode_index"].item() == expected_ep
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with h264 codec
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "h264"
|
||||
|
||||
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
|
||||
# Create the expected directory structure
|
||||
video_key = "observation.images.laptop"
|
||||
episode_index = 0
|
||||
frame_index = 0
|
||||
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=video_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a dummy image file
|
||||
dummy_img = Image.new("RGB", (64, 64), color="red")
|
||||
dummy_img.save(img_dir / "frame-000000.png")
|
||||
|
||||
# Track what vcodec was passed to encode_video_frames
|
||||
captured_kwargs = {}
|
||||
|
||||
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
# Create a dummy output file so the worker doesn't fail
|
||||
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
|
||||
# Test with default codec (no vcodec specified)
|
||||
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
|
||||
|
||||
assert "vcodec" in captured_kwargs
|
||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
|
||||
Reference in New Issue
Block a user