diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 7614c7dd8..cfe2c5eaf 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -246,7 +246,10 @@ def test_deterministic_sampler_constant_memory(): sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0) assert len(sampler) == num_frames sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3}) - tail = list(sampler) + # Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__ + # (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The + # iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains. + tail = list(iter(sampler)) assert len(tail) == 3 assert all(0 <= idx < num_frames for idx in tail)