Fix pool sampling camera timestamps

This commit is contained in:
Pepijn
2026-06-22 16:44:37 +02:00
parent ef47c35178
commit 9202fcea96
+9 -21
View File
@@ -165,25 +165,12 @@ def _bytes_for(manifest: EpisodeVideoManifest, episodes: Sequence[int]) -> int:
return total
def _episode_timestamp_window(manifest: EpisodeVideoManifest, episode_index: int) -> tuple[float, float]:
    """Return a ``(lo, hi)`` timestamp window for one episode.

    The window is the overlap of the spans that ``manifest.lookup`` returns for
    every key in ``manifest.video_keys``: it starts at the largest ``first_pts``
    and ends at the smallest ``last_pts``. When the spans do not overlap
    (``hi < lo``), the span of the first video key is returned instead, with
    its upper bound clamped so that it is never below its lower bound.

    Args:
        manifest: Object exposing ``video_keys`` and ``lookup(episode, key)``.
        episode_index: Episode whose spans are inspected.

    Returns:
        A ``(lo, hi)`` tuple with ``hi >= lo``.

    Raises:
        ValueError: If ``manifest.video_keys`` is empty.
    """
    spans = [manifest.lookup(episode_index, camera_key) for camera_key in manifest.video_keys]
    if not spans:
        # Without this guard max() fails with an opaque "empty sequence" error.
        raise ValueError("manifest has no video keys; cannot compute a timestamp window")

    lo = max(span.first_pts for span in spans)
    hi = min(span.last_pts for span in spans)
    if hi >= lo:
        return lo, hi

    # No common overlap between the spans: fall back to the first key's span.
    first = spans[0]
    return first.first_pts, max(first.last_pts, first.first_pts)
def _random_training_samples(episodes: Sequence[int], count: int, seed: int) -> list[tuple[int, float]]:
    """Draw ``count`` reproducible ``(episode, relative_t)`` samples.

    Each sample pairs an episode picked uniformly from ``episodes`` with a
    relative position drawn from ``random.Random.random()``, i.e. a float in
    ``[0.0, 1.0)``. The position is deliberately not an absolute timestamp:
    the consumer maps it onto a concrete span when it decodes.

    Args:
        episodes: Non-empty sequence of candidate episode indices.
        count: Number of samples to draw; ``0`` yields an empty list.
        seed: Seed for the private ``random.Random`` instance, so the same
            arguments always produce the same samples.

    Returns:
        A list of ``(episode_index, relative_t)`` tuples of length ``count``.

    Raises:
        IndexError: If ``episodes`` is empty and ``count`` is positive
            (raised by ``random.Random.choice``).
    """
    rng = random.Random(seed)
    # Tuple elements are evaluated left to right, so each sample consumes the
    # generator as "choice, then random" — the order matters for determinism.
    return [(rng.choice(episodes), rng.random()) for _ in range(count)]
@@ -258,10 +245,12 @@ def _open_resident_decoders(
def _decode_training_sample(
    cache: EpisodeByteCache,
    episode_index: int,
    relative_t: float,
    locks: dict[tuple[int, str], threading.Lock],
) -> None:
    """Decode one frame per camera at a relative position within the episode.

    For every key in ``cache.manifest.video_keys`` the relative position is
    mapped onto that camera's own span:
    ``first_pts + relative_t * max(last_pts - first_pts, 0.0)``. The timestamp
    is therefore computed per camera rather than shared between cameras. The
    frame is then requested through ``cache.get_frames`` while holding the
    lock stored under ``(episode_index, camera_key)``.

    Args:
        cache: Object exposing ``manifest`` and ``get_frames``.
        episode_index: Episode to decode from.
        relative_t: Position within the span; ``0.0`` maps to ``first_pts``
            and ``1.0`` to ``last_pts``.
        locks: Mapping from ``(episode_index, camera_key)`` to the lock that
            guards the corresponding ``get_frames`` call.

    Raises:
        KeyError: If ``locks`` has no entry for an
            ``(episode_index, camera_key)`` pair.
    """
    for camera_key in cache.manifest.video_keys:
        span = cache.manifest.lookup(episode_index, camera_key)
        # A span whose last_pts precedes first_pts collapses to first_pts.
        duration = max(span.last_pts - span.first_pts, 0.0)
        timestamp = span.first_pts + relative_t * duration
        with locks[(episode_index, camera_key)]:
            cache.get_frames(episode_index, camera_key, [timestamp])
@@ -275,7 +264,7 @@ def run_pool_random_decode(
decode_workers: int,
seed: int,
) -> dict[str, float]:
samples = _random_training_samples(cache.manifest, episodes, sample_count, seed)
samples = _random_training_samples(episodes, sample_count, seed)
touched_episodes = sorted({ep for ep, _ts in samples})
decoder_open_s, decoder_count = _open_resident_decoders(
cache, touched_episodes, decode_workers=decode_workers
@@ -371,10 +360,9 @@ def run_pool_stream_simulation(
schedule_one()
ep = rng.choice(resident)
lo, hi = _episode_timestamp_window(cache.manifest, ep)
ts = rng.uniform(lo, max(hi, lo))
_decode_training_sample(cache, ep, ts, locks)
decoded_samples.append((ep, ts))
relative_t = rng.random()
_decode_training_sample(cache, ep, relative_t, locks)
decoded_samples.append((ep, relative_t))
if sample_period > 0:
now = time.perf_counter()