From 81f0ca9ce44e4d2424f4aac9528fa2d3ac713194 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 11 Jun 2026 10:39:13 +0000 Subject: [PATCH] test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and preallocates that many slots before iterating, OOMing even though the resumed epoch only yields 3 frames. Collect through the iterator (no length hint) so the test exercises the real O(1) seek/drain instead of CPython's length-hint preallocation in list(). --- tests/datasets/test_sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 7614c7dd8..cfe2c5eaf 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -246,7 +246,10 @@ def test_deterministic_sampler_constant_memory(): sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0) assert len(sampler) == num_frames sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3}) - tail = list(sampler) + # Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__ + # (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The + # iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains. + tail = list(iter(sampler)) assert len(tail) == 3 assert all(0 <= idx < num_frames for idx in tail)