mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 03:37:12 +00:00
Fix pool sampling camera timestamps
This commit is contained in:
@@ -165,25 +165,12 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
|
||||
return total
|
||||
|
||||
|
||||
def _episode_timestamp_window(manifest: EpisodeVideoManifest, episode_index: int) -> tuple[float, float]:
|
||||
spans = [manifest.lookup(episode_index, camera_key) for camera_key in manifest.video_keys]
|
||||
lo = max(span.first_pts for span in spans)
|
||||
hi = min(span.last_pts for span in spans)
|
||||
if hi >= lo:
|
||||
return lo, hi
|
||||
first = spans[0]
|
||||
return first.first_pts, max(first.last_pts, first.first_pts)
|
||||
|
||||
|
||||
def _random_training_samples(
|
||||
manifest: EpisodeVideoManifest, episodes: Sequence[int], count: int, seed: int
|
||||
) -> list[tuple[int, float]]:
|
||||
def _random_training_samples(episodes: Sequence[int], count: int, seed: int) -> list[tuple[int, float]]:
|
||||
rng = random.Random(seed)
|
||||
out = []
|
||||
for _ in range(count):
|
||||
ep = rng.choice(episodes)
|
||||
lo, hi = _episode_timestamp_window(manifest, ep)
|
||||
out.append((ep, rng.uniform(lo, max(hi, lo))))
|
||||
out.append((ep, rng.random()))
|
||||
return out
|
||||
|
||||
|
||||
@@ -258,10 +245,12 @@ def _open_resident_decoders(
|
||||
def _decode_training_sample(
|
||||
cache: EpisodeByteCache,
|
||||
episode_index: int,
|
||||
timestamp: float,
|
||||
relative_t: float,
|
||||
locks: dict[tuple[int, str], threading.Lock],
|
||||
) -> None:
|
||||
for camera_key in cache.manifest.video_keys:
|
||||
span = cache.manifest.lookup(episode_index, camera_key)
|
||||
timestamp = span.first_pts + relative_t * max(span.last_pts - span.first_pts, 0.0)
|
||||
with locks[(episode_index, camera_key)]:
|
||||
cache.get_frames(episode_index, camera_key, [timestamp])
|
||||
|
||||
@@ -275,7 +264,7 @@ def run_pool_random_decode(
|
||||
decode_workers: int,
|
||||
seed: int,
|
||||
) -> dict[str, float]:
|
||||
samples = _random_training_samples(cache.manifest, episodes, sample_count, seed)
|
||||
samples = _random_training_samples(episodes, sample_count, seed)
|
||||
touched_episodes = sorted({ep for ep, _ts in samples})
|
||||
decoder_open_s, decoder_count = _open_resident_decoders(
|
||||
cache, touched_episodes, decode_workers=decode_workers
|
||||
@@ -371,10 +360,9 @@ def run_pool_stream_simulation(
|
||||
schedule_one()
|
||||
|
||||
ep = rng.choice(resident)
|
||||
lo, hi = _episode_timestamp_window(cache.manifest, ep)
|
||||
ts = rng.uniform(lo, max(hi, lo))
|
||||
_decode_training_sample(cache, ep, ts, locks)
|
||||
decoded_samples.append((ep, ts))
|
||||
relative_t = rng.random()
|
||||
_decode_training_sample(cache, ep, relative_t, locks)
|
||||
decoded_samples.append((ep, relative_t))
|
||||
|
||||
if sample_period > 0:
|
||||
now = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user