mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
chore(dataset): basic house-keeping (#3170)
This commit is contained in:
@@ -13,6 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -106,3 +109,28 @@ def test_shuffle():
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
|
||||
def test_negative_drop_last_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
||||
|
||||
|
||||
def test_all_episodes_dropped_raises():
|
||||
# All episodes have 1 frame, drop_n_first_frames=1 removes all
|
||||
with pytest.raises(ValueError, match="No valid frames remain"):
|
||||
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
|
||||
|
||||
|
||||
def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
|
||||
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
|
||||
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user