refactor(dataset): split LeRobotDataset into DatasetReader & DatasetWriter (+ API cleanup) (#3180)

* refactor(dataset): split reader and writer

* chore(dataset): remove proxies

* refactor(dataset): better reader & writer encapsulation

* refactor(datasets): clean API + reduce leaky implementations

* refactor(dataset): API cleaning for writer, reader and meta

* refactor(dataset): expose writer & reader + other minor improvements

* refactor(dataset): improve teardown routine

* refactor(dataset): add hf_dataset property at the facade level

* chore(dataset): add init for dataset module

* docs(dataset): add docstrings for public API of the dataset classes

* tests(dataset): add tests for new classes

* fix(dataset): remove circular dependency
This commit is contained in:
Steven Palma
2026-03-26 19:09:25 +01:00
committed by GitHub
parent 017ff73fbf
commit 123495250b
28 changed files with 2742 additions and 1158 deletions
+52 -118
View File
@@ -32,10 +32,7 @@ from lerobot.datasets.factory import make_dataset
from lerobot.datasets.feature_utils import get_hf_features_from_features, hw_to_dataset_features
from lerobot.datasets.image_writer import image_array_to_pil_image
from lerobot.datasets.io_utils import hf_transform_to_torch
from lerobot.datasets.lerobot_dataset import (
LeRobotDataset,
_encode_video_worker,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.multi_dataset import MultiLeRobotDataset
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
@@ -72,7 +69,7 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
objects have the same sets of facade-level attributes defined.
"""
# Instantiate both ways
robot = make_robot_from_config(MockRobotConfig())
@@ -87,6 +84,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
# Facade-level attributes should match between __init__ and create()
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
@@ -214,6 +212,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
@@ -226,6 +225,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -235,6 +235,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -244,6 +245,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -253,6 +255,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -262,6 +265,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -271,6 +275,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["state"].ndim == 0
@@ -280,6 +285,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["caption"] == "Dummy caption"
@@ -315,6 +321,7 @@ def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -323,6 +330,7 @@ def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -332,6 +340,7 @@ def test_add_frame_image_uint8(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -341,6 +350,7 @@ def test_add_frame_image_pil(image_dataset):
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode()
dataset.finalize()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -361,7 +371,7 @@ def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_img.save_episode()
img_dir = ds_img._get_image_file_dir(0, image_key)
img_dir = ds_img.writer._get_image_file_dir(0, image_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
@@ -374,10 +384,10 @@ def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
}
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
ds_vid.batch_encoding_size = 1
ds_vid.writer._batch_encoding_size = 1
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
ds_vid.save_episode()
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
vid_img_dir = ds_vid.writer._get_image_file_dir(0, vid_key)
assert not vid_img_dir.exists(), (
"Temporary image directory should be removed when batch_encoding_size == 1"
)
@@ -402,8 +412,8 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
}
)
ds_mixed.save_episode()
img_dir = ds_mixed._get_image_file_dir(0, image_key)
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
img_dir = ds_mixed.writer._get_image_file_dir(0, image_key)
vid_img_dir = ds_mixed.writer._get_image_file_dir(0, vid_key)
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
assert vid_img_dir.exists(), (
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
@@ -631,29 +641,29 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
)
# Test hf_dataset is None
dataset.hf_dataset = None
assert dataset._check_cached_episodes_sufficient() is False
dataset.reader.hf_dataset = None
assert dataset.reader._check_cached_episodes_sufficient() is False
# Test hf_dataset is empty
import datasets
empty_features = get_hf_features_from_features(dataset.features)
dataset.hf_dataset = datasets.Dataset.from_dict(
dataset.reader.hf_dataset = datasets.Dataset.from_dict(
{key: [] for key in empty_features}, features=empty_features
)
dataset.hf_dataset.set_transform(hf_transform_to_torch)
assert dataset._check_cached_episodes_sufficient() is False
dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
assert dataset.reader._check_cached_episodes_sufficient() is False
# Restore the original dataset for remaining tests
dataset.hf_dataset = dataset.load_hf_dataset()
dataset.reader.hf_dataset = dataset.reader._load_hf_dataset()
# Test all episodes requested (self.episodes = None) and all are available
dataset.episodes = None
assert dataset._check_cached_episodes_sufficient() is True
dataset.reader.episodes = None
assert dataset.reader._check_cached_episodes_sufficient() is True
# Test specific episodes requested that are all available
dataset.episodes = [0, 2, 4]
assert dataset._check_cached_episodes_sufficient() is True
dataset.reader.episodes = [0, 2, 4]
assert dataset.reader._check_cached_episodes_sufficient() is True
# Test request episodes that don't exist in the cached dataset
# Create a dataset with only episodes 0, 1, 2
@@ -665,8 +675,8 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
)
# Request episodes that include non-existent ones
limited_dataset.episodes = [0, 1, 2, 3, 4]
assert limited_dataset._check_cached_episodes_sufficient() is False
limited_dataset.reader.episodes = [0, 1, 2, 3, 4]
assert limited_dataset.reader._check_cached_episodes_sufficient() is False
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
# First create the full dataset structure
@@ -702,22 +712,22 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
filtered_data[key] = filtered_values
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
sparse_dataset.reader.hf_dataset = datasets.Dataset.from_dict(
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
)
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
sparse_dataset.reader.hf_dataset.set_transform(hf_transform_to_torch)
# Test requesting all episodes when only some are cached
sparse_dataset.episodes = None
assert sparse_dataset._check_cached_episodes_sufficient() is False
sparse_dataset.reader.episodes = None
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
# Test requesting only the available episodes
sparse_dataset.episodes = [0, 2, 4]
assert sparse_dataset._check_cached_episodes_sufficient() is True
sparse_dataset.reader.episodes = [0, 2, 4]
assert sparse_dataset.reader._check_cached_episodes_sufficient() is True
# Test requesting a mix of available and unavailable episodes
sparse_dataset.episodes = [0, 1, 2]
assert sparse_dataset._check_cached_episodes_sufficient() is False
sparse_dataset.reader.episodes = [0, 1, 2]
assert sparse_dataset.reader._check_cached_episodes_sufficient() is False
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
@@ -1189,13 +1199,13 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
del dataset_verify
# Phase 3: Resume recording - add more episodes
dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
dataset_resumed = LeRobotDataset.resume(initial_repo_id, root=initial_root, revision="v3.0")
assert dataset_resumed.meta.total_episodes == initial_episodes
assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
assert dataset_resumed.latest_episode is None # Not recording yet
assert dataset_resumed.writer is None
assert dataset_resumed.meta.writer is None
assert dataset_resumed.writer._latest_episode is None # Not recording yet
assert dataset_resumed.writer._pq_writer is None
assert dataset_resumed.meta._pq_writer is None
additional_episodes = 2
for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
@@ -1271,7 +1281,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
assert dataset._current_file_start_frame is None
assert dataset.writer._current_file_start_frame is None
frames_per_episode = 10
for _ in range(frames_per_episode):
@@ -1284,7 +1294,7 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 1
assert dataset.meta.total_frames == frames_per_episode
@@ -1298,12 +1308,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 2
assert dataset.meta.total_frames == 2 * frames_per_episode
ep1_chunk = dataset.latest_episode["data/chunk_index"]
ep1_file = dataset.latest_episode["data/file_index"]
ep1_chunk = dataset.writer._latest_episode["data/chunk_index"]
ep1_file = dataset.writer._latest_episode["data/file_index"]
assert ep1_chunk == 0
assert ep1_file == 0
@@ -1317,12 +1327,12 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.writer._current_file_start_frame == 0
assert dataset.meta.total_episodes == 3
assert dataset.meta.total_frames == 3 * frames_per_episode
ep2_chunk = dataset.latest_episode["data/chunk_index"]
ep2_file = dataset.latest_episode["data/file_index"]
ep2_chunk = dataset.writer._latest_episode["data/chunk_index"]
ep2_file = dataset.writer._latest_episode["data/file_index"]
assert ep2_chunk == 0
assert ep2_file == 0
@@ -1354,82 +1364,6 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
assert frame["episode_index"].item() == expected_ep
def test_encode_video_worker_forwards_vcodec(tmp_path):
"""Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with h264 codec
_encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264")
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "h264"
def test_encode_video_worker_default_vcodec(tmp_path):
"""Test that _encode_video_worker uses libsvtav1 as the default codec."""
from unittest.mock import patch
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
# Create the expected directory structure
video_key = "observation.images.laptop"
episode_index = 0
frame_index = 0
fpath = DEFAULT_IMAGE_PATH.format(
image_key=video_key, episode_index=episode_index, frame_index=frame_index
)
img_dir = tmp_path / Path(fpath).parent
img_dir.mkdir(parents=True, exist_ok=True)
# Create a dummy image file
dummy_img = Image.new("RGB", (64, 64), color="red")
dummy_img.save(img_dir / "frame-000000.png")
# Track what vcodec was passed to encode_video_frames
captured_kwargs = {}
def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
captured_kwargs.update(kwargs)
# Create a dummy output file so the worker doesn't fail
Path(video_path).parent.mkdir(parents=True, exist_ok=True)
Path(video_path).touch()
with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
# Test with default codec (no vcodec specified)
_encode_video_worker(video_key, episode_index, tmp_path, fps=30)
assert "vcodec" in captured_kwargs
assert captured_kwargs["vcodec"] == "libsvtav1"
def test_lerobot_dataset_vcodec_validation():
"""Test that LeRobotDataset validates the vcodec parameter."""
# Test that invalid vcodec raises ValueError