debug sampling

This commit is contained in:
Pepijn
2025-08-31 01:48:35 +02:00
parent 852713dc84
commit 9767120eb4
2 changed files with 58 additions and 30 deletions
File diff suppressed because one or more lines are too long
+23 -3
View File
@@ -506,7 +506,20 @@ class RLearNPolicy(PreTrainedPolicy):
# Calculate progress for each frame in the temporal window
all_progress = []
for delta in delta_indices:
# DEBUG: Log first sample's target calculation
debug_first_sample = True
if debug_first_sample and torch.rand(1).item() < 0.05: # 5% chance
ep_idx_debug = episode_indices[0].item()
frame_idx_debug = frame_indices[0].item()
ep_start_debug = self.episode_data_index["from"][ep_idx_debug].item()
ep_end_debug = self.episode_data_index["to"][ep_idx_debug].item()
ep_length_debug = ep_end_debug - ep_start_debug
print(f"\n=== TARGET DEBUG ===")
print(f"Episode {ep_idx_debug}: length={ep_length_debug}, current_frame={frame_idx_debug}")
print(f"Delta indices: {delta_indices}")
for i, delta in enumerate(delta_indices):
# For each sample, calculate the progress of the frame at delta offset
frame_progress = []
for b_idx in range(B):
@@ -522,16 +535,23 @@ class RLearNPolicy(PreTrainedPolicy):
ep_length = ep_end - ep_start
# Clamp to episode boundaries (frame_index is relative to episode)
target_frame_idx = max(0, min(ep_length - 1, target_frame_idx))
target_frame_idx_clamped = max(0, min(ep_length - 1, target_frame_idx))
# Calculate progress for this frame
prog = target_frame_idx / max(1, ep_length - 1)
prog = target_frame_idx_clamped / max(1, ep_length - 1)
frame_progress.append(prog)
# DEBUG: Log first sample calculation
if debug_first_sample and b_idx == 0 and torch.rand(1).item() < 0.05:
print(f"Frame {i:2d} (delta={delta:3d}): target_idx={target_frame_idx:3d} → clamped={target_frame_idx_clamped:3d} → progress={prog:.6f}")
all_progress.append(
torch.tensor(frame_progress, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype)
)
if debug_first_sample and torch.rand(1).item() < 0.05:
print("=" * 20)
# Stack to get (B, T) tensor where T is the temporal sequence length
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)