Compare commits

...

2 Commits

Author SHA1 Message Date
Claude 7a62235bac fix(datasets): guard Feistel cycle-walking loop against non-convergence
Replace the unbounded while True in EpisodeAwareSampler._permute with a
bounded for loop capped at _MAX_CYCLE_WALK_STEPS (100) and raise
RuntimeError if the cycle-walk fails to land in [0, num_frames). The
loop is expected to converge in <4 steps on the chosen power-of-two
domain, so the bound is a safety net that should never trip in practice
but prevents a pathological infinite loop.

https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22
2026-06-11 13:20:31 +00:00
Pepijn 81f0ca9ce4 test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc
list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and
preallocates that many slots before iterating, OOMing even though the resumed epoch only
yields 3 frames. Collect through the iterator (no length hint) so the test exercises the
real O(1) seek/drain instead of CPython's list growth heuristic.
2026-06-11 10:39:13 +00:00
2 changed files with 12 additions and 2 deletions
+8 -1
View File
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
_MASK_64 = (1 << 64) - 1
_FEISTEL_ROUNDS = 4
# Cycle-walking converges in <4 expected steps on the chosen domain; this bound is a generous
# safety net that should never be hit in practice.
_MAX_CYCLE_WALK_STEPS = 100
def _mix64(x: int) -> int:
@@ -165,13 +168,17 @@ class EpisodeAwareSampler:
def _permute(self, index: int, keys: list[int]) -> int:
# Feistel network with cycle-walking: a bijection on [0, num_frames).
half_bits, half_mask = self._half_bits, self._half_mask
while True:
for _ in range(_MAX_CYCLE_WALK_STEPS):
left, right = index >> half_bits, index & half_mask
for key in keys:
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
index = (left << half_bits) | right
if index < self._num_frames:
return index
raise RuntimeError(
f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
"this should never happen for a valid domain."
)
def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
+4 -1
View File
@@ -246,7 +246,10 @@ def test_deterministic_sampler_constant_memory():
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
assert len(sampler) == num_frames
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
tail = list(sampler)
# Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__
# (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The
# iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains.
tail = list(iter(sampler))
assert len(tail) == 3
assert all(0 <= idx < num_frames for idx in tail)