diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py
index c73da7d0a..5af24b740 100644
--- a/src/lerobot/datasets/sampler.py
+++ b/src/lerobot/datasets/sampler.py
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
 
 _MASK_64 = (1 << 64) - 1
 _FEISTEL_ROUNDS = 4
+# Maximum number of cycle-walking iterations _permute attempts before raising RuntimeError.
+# Convergence is expected in <4 steps on the chosen domain, so this is purely a safety net.
+_MAX_CYCLE_WALK_STEPS = 100
 
 
 def _mix64(x: int) -> int:
@@ -165,13 +168,17 @@ class EpisodeAwareSampler:
     def _permute(self, index: int, keys: list[int]) -> int:
         # Feistel network with cycle-walking: a bijection on [0, num_frames).
         half_bits, half_mask = self._half_bits, self._half_mask
-        while True:
+        for _ in range(_MAX_CYCLE_WALK_STEPS):
             left, right = index >> half_bits, index & half_mask
             for key in keys:
                 left, right = right, left ^ (_mix64(right ^ key) & half_mask)
             index = (left << half_bits) | right
             if index < self._num_frames:
                 return index
+        raise RuntimeError(
+            f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
+            "this should never happen for a valid domain."
+        )
 
     def _frame_index(self, position: int) -> int:
         episode = int(np.searchsorted(self._cum_lengths, position, side="right"))