mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
chore(dataset): basic house-keeping (#3170)
This commit is contained in:
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
|
||||
|
||||
def test_dataset_config_valid():
    """A list of distinct, non-negative episode indices is accepted without raising."""
    episode_ids = [0, 1, 2]
    DatasetConfig(repo_id="user/repo", episodes=episode_ids)
|
||||
|
||||
|
||||
def test_dataset_config_negative_episodes():
    """A negative episode index is rejected with a ValueError mentioning "non-negative"."""
    episode_ids = [0, -1, 2]
    expected_failure = pytest.raises(ValueError, match="non-negative")
    with expected_failure:
        DatasetConfig(repo_id="user/repo", episodes=episode_ids)
|
||||
|
||||
|
||||
def test_dataset_config_duplicate_episodes():
    """A repeated episode index is rejected with a ValueError mentioning "duplicates"."""
    episode_ids = [0, 1, 1, 2]  # index 1 appears twice
    expected_failure = pytest.raises(ValueError, match="duplicates")
    with expected_failure:
        DatasetConfig(repo_id="user/repo", episodes=episode_ids)
|
||||
|
||||
|
||||
def test_dataset_config_none_episodes_ok():
    """Passing ``episodes=None`` is accepted without raising."""
    episode_ids = None
    DatasetConfig(repo_id="user/repo", episodes=episode_ids)
|
||||
|
||||
|
||||
def test_dataset_config_empty_episodes_ok():
    """Passing an empty episode list is accepted without raising."""
    episode_ids = []
    DatasetConfig(repo_id="user/repo", episodes=episode_ids)
|
||||
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
write_image(image_array, fpath)
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -106,3 +109,28 @@ def test_shuffle():
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
    """A negative ``drop_n_first_frames`` is rejected with a ValueError."""
    expected_failure = pytest.raises(ValueError, match="drop_n_first_frames must be >= 0")
    with expected_failure:
        EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
|
||||
def test_negative_drop_last_frames_raises():
    """A negative ``drop_n_last_frames`` is rejected with a ValueError."""
    expected_failure = pytest.raises(ValueError, match="drop_n_last_frames must be >= 0")
    with expected_failure:
        EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
||||
|
||||
|
||||
def test_all_episodes_dropped_raises():
    """Dropping the only frame of every episode leaves nothing to sample and raises."""
    # Three one-frame episodes: [0, 1), [1, 2), [2, 3). Dropping the first
    # frame of each one removes every frame.
    episode_starts = [0, 1, 2]
    episode_ends = [1, 2, 3]
    expected_failure = pytest.raises(ValueError, match="No valid frames remain")
    with expected_failure:
        EpisodeAwareSampler(episode_starts, episode_ends, drop_n_first_frames=1)
|
||||
|
||||
|
||||
def test_partial_episode_drop_warns(caplog):
    """An episode emptied by frame dropping is skipped and reported in the log.

    Episode 0 spans [0, 1) (one frame) and episode 1 spans [1, 6) (five frames).
    With ``drop_n_first_frames=1`` episode 0 has no frames left, while episode 1
    keeps frames 2 through 5.
    """
    episode_starts = [0, 1]
    episode_ends = [1, 6]

    # Capture WARNING-level records emitted by the sampler module's logger.
    with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
        sampler = EpisodeAwareSampler(episode_starts, episode_ends, drop_n_first_frames=1)

    # Only episode 1 contributes indices; its first frame (index 1) is dropped.
    remaining_indices = [2, 3, 4, 5]
    assert sampler.indices == remaining_indices
    # The skipped episode must be named in the captured log output.
    assert "Episode 0" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user