Add uniform sampling and transition smoothing

This commit is contained in:
Pepijn
2025-11-28 17:15:57 +01:00
parent 6e3b972534
commit 112eb70a65
4 changed files with 237 additions and 90 deletions
+16 -6
View File
@@ -286,10 +286,17 @@ def run_inference(
state_slices = []
for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
# Compute frame indices: [initial_frame (0), t-(7*gap), t-(6*gap), ..., t-gap, t]
# The first delta is -100000 which clamps to episode start
# Compute frame indices using symmetric bidirectional pattern:
# [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
# Boundary handling: clamp to [0, last_valid]
deltas = model.config.observation_delta_indices
frame_indices = [max(0, min(current_frame + delta, len(video_embeddings) - 1)) for delta in deltas]
last_valid = len(video_embeddings) - 1
frame_indices = []
for delta in deltas:
idx = current_frame + delta
idx = max(0, min(idx, last_valid)) # Clamp to valid range
frame_indices.append(idx)
video_slice = video_embeddings[frame_indices]
video_slices.append(video_slice)
@@ -324,9 +331,12 @@ def run_inference(
batch_video, batch_text, batch_states
)
# Extract last frame predictions (the "current" frame)
batch_progress = progress_preds[:, -1, 0].cpu().numpy()
batch_stages = stage_probs[:, -1, :].cpu().numpy()
# Extract predictions at the "current frame" position
# With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
# the current frame is at position 5 (0-indexed)
current_frame_idx = 5
batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()
all_progress.extend(batch_progress)
all_stages.extend(batch_stages)