mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-11 13:49:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a62235bac | |||
| 81f0ca9ce4 |
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_MASK_64 = (1 << 64) - 1
|
||||
_FEISTEL_ROUNDS = 4
|
||||
# Cycle-walking converges in <4 expected steps on the chosen domain; this bound is a generous
|
||||
# safety net that should never be hit in practice.
|
||||
_MAX_CYCLE_WALK_STEPS = 100
|
||||
|
||||
|
||||
def _mix64(x: int) -> int:
|
||||
@@ -165,13 +168,17 @@ class EpisodeAwareSampler:
|
||||
def _permute(self, index: int, keys: list[int]) -> int:
|
||||
# Feistel network with cycle-walking: a bijection on [0, num_frames).
|
||||
half_bits, half_mask = self._half_bits, self._half_mask
|
||||
while True:
|
||||
for _ in range(_MAX_CYCLE_WALK_STEPS):
|
||||
left, right = index >> half_bits, index & half_mask
|
||||
for key in keys:
|
||||
left, right = right, left ^ (_mix64(right ^ key) & half_mask)
|
||||
index = (left << half_bits) | right
|
||||
if index < self._num_frames:
|
||||
return index
|
||||
raise RuntimeError(
|
||||
f"Feistel cycle-walking did not converge within {_MAX_CYCLE_WALK_STEPS} steps; "
|
||||
"this should never happen for a valid domain."
|
||||
)
|
||||
|
||||
def _frame_index(self, position: int) -> int:
|
||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||
|
||||
@@ -246,7 +246,10 @@ def test_deterministic_sampler_constant_memory():
|
||||
sampler = deterministic_sampler([0], [num_frames], shuffle=True, seed=0)
|
||||
assert len(sampler) == num_frames
|
||||
sampler.load_state_dict({"epoch": 3, "start_index": num_frames - 3})
|
||||
tail = list(sampler)
|
||||
# Collect via the iterator: list(sampler) would call PyObject_LengthHint -> sampler.__len__
|
||||
# (the full epoch length, here 10**12) and pre-allocate that many slots before iterating. The
|
||||
# iterator itself exposes no length hint, so this stays O(1) like the resumed epoch it drains.
|
||||
tail = list(iter(sampler))
|
||||
assert len(tail) == 3
|
||||
assert all(0 <= idx < num_frames for idx in tail)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user