chore(dataset): basic house-keeping (#3170)

This commit is contained in:
Steven Palma
2026-03-15 22:12:09 -07:00
committed by GitHub
parent 7c2ec31793
commit 9d3b62aa61
9 changed files with 153 additions and 41 deletions
+4 -4
View File
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
write_image(image_array, fpath)
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
finally:
writer.stop()
+28
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pytest
import torch
from datasets import Dataset
@@ -106,3 +109,28 @@ def test_shuffle():
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
def test_negative_drop_last_frames_raises():
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
def test_all_episodes_dropped_raises():
# All episodes have 1 frame, drop_n_first_frames=1 removes all
with pytest.raises(ValueError, match="No valid frames remain"):
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_partial_episode_drop_warns(caplog):
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text