chore(dataset): basic house-keeping

This commit is contained in:
Steven Palma
2026-03-15 21:26:58 -07:00
parent 7c2ec31793
commit c7458c67cd
9 changed files with 153 additions and 41 deletions
+28
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pytest
import torch
from datasets import Dataset
@@ -106,3 +109,28 @@ def test_shuffle():
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
def test_negative_drop_last_frames_raises():
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
def test_all_episodes_dropped_raises():
# All episodes have 1 frame, drop_n_first_frames=1 removes all
with pytest.raises(ValueError, match="No valid frames remain"):
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
def test_partial_episode_drop_warns(caplog):
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
assert sampler.indices == [2, 3, 4, 5]
assert "Episode 0" in caplog.text