diff --git a/scripts/bench_episode_byte_cache.py b/scripts/bench_episode_byte_cache.py index d07ec7722..453f38139 100644 --- a/scripts/bench_episode_byte_cache.py +++ b/scripts/bench_episode_byte_cache.py @@ -165,25 +165,12 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int: return total -def _episode_timestamp_window(manifest: EpisodeVideoManifest, episode_index: int) -> tuple[float, float]: - spans = [manifest.lookup(episode_index, camera_key) for camera_key in manifest.video_keys] - lo = max(span.first_pts for span in spans) - hi = min(span.last_pts for span in spans) - if hi >= lo: - return lo, hi - first = spans[0] - return first.first_pts, max(first.last_pts, first.first_pts) - - -def _random_training_samples( - manifest: EpisodeVideoManifest, episodes: Sequence[int], count: int, seed: int -) -> list[tuple[int, float]]: +def _random_training_samples(episodes: Sequence[int], count: int, seed: int) -> list[tuple[int, float]]: rng = random.Random(seed) out = [] for _ in range(count): ep = rng.choice(episodes) - lo, hi = _episode_timestamp_window(manifest, ep) - out.append((ep, rng.uniform(lo, max(hi, lo)))) + out.append((ep, rng.random())) return out @@ -258,10 +245,12 @@ def _open_resident_decoders( def _decode_training_sample( cache: EpisodeByteCache, episode_index: int, - timestamp: float, + relative_t: float, locks: dict[tuple[int, str], threading.Lock], ) -> None: for camera_key in cache.manifest.video_keys: + span = cache.manifest.lookup(episode_index, camera_key) + timestamp = span.first_pts + relative_t * max(span.last_pts - span.first_pts, 0.0) with locks[(episode_index, camera_key)]: cache.get_frames(episode_index, camera_key, [timestamp]) @@ -275,7 +264,7 @@ def run_pool_random_decode( decode_workers: int, seed: int, ) -> dict[str, float]: - samples = _random_training_samples(cache.manifest, episodes, sample_count, seed) + samples = _random_training_samples(episodes, sample_count, seed) touched_episodes = sorted({ep for ep, _ts in samples}) decoder_open_s, decoder_count = _open_resident_decoders( cache, touched_episodes, decode_workers=decode_workers @@ -371,10 +360,9 @@ def run_pool_stream_simulation( schedule_one() ep = rng.choice(resident) - lo, hi = _episode_timestamp_window(cache.manifest, ep) - ts = rng.uniform(lo, max(hi, lo)) - _decode_training_sample(cache, ep, ts, locks) - decoded_samples.append((ep, ts)) + relative_t = rng.random() + _decode_training_sample(cache, ep, relative_t, locks) + decoded_samples.append((ep, relative_t)) if sample_period > 0: now = time.perf_counter()