mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(datasets): guard Feistel cycle-walking loop against non-convergence
Replace the unbounded `while True` in `EpisodeAwareSampler._permute` with a bounded `for` loop capped at `_MAX_CYCLE_WALK_STEPS` (100), and raise `RuntimeError` if the cycle-walk fails to land in [0, num_frames). On the chosen power-of-two domain the loop converges in fewer than 4 steps in expectation, so the bound is a safety net that should never trip in practice but prevents a pathological infinite loop. https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22
This commit is contained in:
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_MASK_64 = (1 << 64) - 1
|
||||
_FEISTEL_ROUNDS = 4
|
||||
# Cycle-walking converges in <4 expected steps on the chosen domain; this bound is a generous
|
||||
# safety net that should never be hit in practice.
|
||||
_MAX_CYCLE_WALK_STEPS = 100
|
||||
|
||||
|
||||
def _mix64(x: int) -> int:
|
||||
@@ -165,13 +168,17 @@ class EpisodeAwareSampler:
|
||||
def _permute(self, index: int, keys: list[int]) -> int:
|
||||
# Feistel network with cycle-walking: a bijection on [0, num_frames).
|
||||
half_bits, half_mask = self._half_bits, self._half_mask
|
||||
while True:
|
||||
for _ in range(_MAX_CYCLE_WALK_STEPS):
|
||||
left, right = index >> half_bits, index & half_mask
|
||||
for key in keys:
|
||||
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
|
||||
index = (left << half_bits) | right
|
||||
if index < self._num_frames:
|
||||
return index
|
||||
raise RuntimeError(
|
||||
f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
|
||||
"this should never happen for a valid domain."
|
||||
)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
|
||||
Reference in New Issue
Block a user