mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 06:07:40 +00:00
debug sampling
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -506,7 +506,20 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Calculate progress for each frame in the temporal window
|
||||
all_progress = []
|
||||
for delta in delta_indices:
|
||||
|
||||
# DEBUG: Log first sample's target calculation
|
||||
debug_first_sample = True
|
||||
if debug_first_sample and torch.rand(1).item() < 0.05: # 5% chance
|
||||
ep_idx_debug = episode_indices[0].item()
|
||||
frame_idx_debug = frame_indices[0].item()
|
||||
ep_start_debug = self.episode_data_index["from"][ep_idx_debug].item()
|
||||
ep_end_debug = self.episode_data_index["to"][ep_idx_debug].item()
|
||||
ep_length_debug = ep_end_debug - ep_start_debug
|
||||
print(f"\n=== TARGET DEBUG ===")
|
||||
print(f"Episode {ep_idx_debug}: length={ep_length_debug}, current_frame={frame_idx_debug}")
|
||||
print(f"Delta indices: {delta_indices}")
|
||||
|
||||
for i, delta in enumerate(delta_indices):
|
||||
# For each sample, calculate the progress of the frame at delta offset
|
||||
frame_progress = []
|
||||
for b_idx in range(B):
|
||||
@@ -522,16 +535,23 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Clamp to episode boundaries (frame_index is relative to episode)
|
||||
target_frame_idx = max(0, min(ep_length - 1, target_frame_idx))
|
||||
target_frame_idx_clamped = max(0, min(ep_length - 1, target_frame_idx))
|
||||
|
||||
# Calculate progress for this frame
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
prog = target_frame_idx_clamped / max(1, ep_length - 1)
|
||||
frame_progress.append(prog)
|
||||
|
||||
# DEBUG: Log first sample calculation
|
||||
if debug_first_sample and b_idx == 0 and torch.rand(1).item() < 0.05:
|
||||
print(f"Frame {i:2d} (delta={delta:3d}): target_idx={target_frame_idx:3d} → clamped={target_frame_idx_clamped:3d} → progress={prog:.6f}")
|
||||
|
||||
all_progress.append(
|
||||
torch.tensor(frame_progress, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype)
|
||||
)
|
||||
|
||||
if debug_first_sample and torch.rand(1).item() < 0.05:
|
||||
print("=" * 20)
|
||||
|
||||
# Stack to get (B, T) tensor where T is the temporal sequence length
|
||||
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user