chore(dataset): basic house-keeping (#3170)

This commit is contained in:
Steven Palma
2026-03-15 22:12:09 -07:00
committed by GitHub
parent 7c2ec31793
commit 9d3b62aa61
9 changed files with 153 additions and 41 deletions
+38
View File
@@ -0,0 +1,38 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.configs.default import DatasetConfig
def test_dataset_config_valid():
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2])
def test_dataset_config_negative_episodes():
with pytest.raises(ValueError, match="non-negative"):
DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2])
def test_dataset_config_duplicate_episodes():
with pytest.raises(ValueError, match="duplicates"):
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2])
def test_dataset_config_none_episodes_ok():
DatasetConfig(repo_id="user/repo", episodes=None)
def test_dataset_config_empty_episodes_ok():
DatasetConfig(repo_id="user/repo", episodes=[])
+4 -4
View File
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
write_image(image_array, fpath)
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
finally:
writer.stop()
+28
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pytest
import torch
from datasets import Dataset
@@ -106,3 +109,28 @@ def test_shuffle():
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
def test_negative_drop_last_frames_raises():
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
def test_all_episodes_dropped_raises():
# All episodes have 1 frame, drop_n_first_frames=1 removes all
with pytest.raises(ValueError, match="No valid frames remain"):
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_partial_episode_drop_warns(caplog):
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text