debug frames

Pepijn, 2025-08-31 19:18:52 +02:00
parent 10e36f2453
commit e3306951c0
2 changed files with 64 additions and 71 deletions
@@ -71,15 +71,15 @@ class RLearNConfig(PreTrainedConfig):
 rewind_last3_prob: float = 0.3
 mismatch_prob: float = 0.2
-# Logit regression (only supported mode)
-logit_eps: float = 1e-6
+# Logit regression (only supported mode) - FIXED: Larger eps to prevent extreme targets
+logit_eps: float = 0.02  # Was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9)
 head_lr_multiplier: float = 2.0
 head_weight_init_std: float = 0.05
-# Reward head architecture
-head_hidden_dim: int = 1024  # Hidden dimension for reward head
-head_num_layers: int = 4  # Number of layers in reward head
-head_dropout: float = 0.1  # Dropout in reward head
+# Reward head architecture - FIXED: Simpler architecture to prevent flat basins
+head_hidden_dim: int = 1024  # Hidden dimension for reward head
+head_num_layers: int = 2  # REDUCED: 2 layers instead of 4 to prevent over-regularization
+head_dropout: float = 0.05  # REDUCED: Less dropout to prevent conservatism
 # Normalization presets
 normalization_mapping: dict[str, NormalizationMode] = field(
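For reference, a quick check (illustrative only, not part of the diff) of what each eps value implies for the clamped logit-regression targets:

import torch

# Illustrative: the logit range implied by each clamp value in the config above.
for eps in (1e-6, 0.02):
    hi = torch.logit(torch.tensor(1.0 - eps, dtype=torch.float64))
    print(f"eps={eps:g}: targets clamped to [{eps}, {1 - eps}] -> logits within about ±{hi:.2f}")
# prints roughly ±13.82 for eps=1e-6 and ±3.89 for eps=0.02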
Second changed file: +58 -65
@@ -184,38 +184,23 @@ class RLearNPolicy(PreTrainedPolicy):
     depth=config.mlp_predictor_depth
 )
-# Layer normalization before reward head to stabilize MLP outputs
-self.pre_reward_norm = nn.LayerNorm(config.dim_model)
-# Temporal-aware regression head with increased capacity
-# Build a deeper MLP for better visual-progress learning
-head_layers = []
-# Input layer: embedding + temporal position -> hidden
-head_layers.extend([
-    nn.Linear(config.dim_model + 1, config.head_hidden_dim),  # +1 for temporal position
+# FIXED: Simpler head architecture to prevent constant output pathology
+# Remove LayerNorm (causes flat basin), reduce depth, larger init, less dropout
+# Simple 2-layer MLP with larger initialization to encourage exploration
+self.reward_head = nn.Sequential(
+    nn.Linear(config.dim_model + 1, config.head_hidden_dim),  # +1 for temporal position
     nn.ReLU(),
-    nn.Dropout(config.head_dropout)
-])
-# Hidden layers: multiple layers for complex visual-progress mapping
-for _ in range(config.head_num_layers - 2):  # -2 for input and output layers
-    head_layers.extend([
-        nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
-        nn.ReLU(),
-        nn.Dropout(config.head_dropout)
-    ])
-# Output layer: hidden -> logit
-head_layers.append(nn.Linear(config.head_hidden_dim, 1))
-self.reward_head = nn.Sequential(*head_layers)
-# Initialize the deeper temporal-aware head for logit regression
+    nn.Dropout(0.05),  # Reduced dropout to prevent noise-induced conservatism
+    nn.Linear(config.head_hidden_dim, 1)
+)
+# FIXED: Larger weight initialization to escape flat basins
 with torch.no_grad():
     for module in self.reward_head:
         if isinstance(module, nn.Linear):
-            nn.init.normal_(module.weight, 0.0, config.head_weight_init_std)
+            # Use Xavier/Glorot initialization for better gradient flow
+            nn.init.xavier_uniform_(module.weight, gain=1.0)
             nn.init.zeros_(module.bias)
 # Simple frame dropout probability
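A minimal standalone sketch of the simplified head (dim_model and head_hidden_dim are assumed values, not taken from the repo), useful for checking that the Xavier-initialized two-layer MLP does not collapse to a constant output:

import torch
import torch.nn as nn

dim_model, head_hidden_dim = 512, 1024  # assumed sizes for illustration
head = nn.Sequential(
    nn.Linear(dim_model + 1, head_hidden_dim),  # +1 for the temporal position feature
    nn.ReLU(),
    nn.Dropout(0.05),
    nn.Linear(head_hidden_dim, 1),
)
for m in head:
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=1.0)
        nn.init.zeros_(m.bias)

head.eval()  # disable dropout for a deterministic check
B, T = 2, 16
embeds = torch.randn(B, T, dim_model)
t_pos = torch.linspace(0, 1, T).view(1, T, 1).expand(B, T, 1)
logits = head(torch.cat([embeds, t_pos], dim=-1)).squeeze(-1)  # (B, T)
print(logits.std().item())  # clearly non-zero; a head stuck in a flat basin would print ~0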
@@ -335,16 +320,15 @@ class RLearNPolicy(PreTrainedPolicy):
 # MLP predictor
 video_frame_embeds = self.mlp_predictor(frame_tokens)
-# Get rewards via temporal-aware logit regression head
-normalized_embeds = self.pre_reward_norm(video_frame_embeds)
+# Get rewards via temporal-aware logit regression head (no pre-normalization)
 # Add temporal position information
-B, T_pred = normalized_embeds.shape[:2]
-temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device)
+B, T_pred = video_frame_embeds.shape[:2]
+temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device)
 temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1)  # (B, T, 1)
 # Concatenate embeddings with temporal position
-temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)  # (B, T, D+1)
+temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)  # (B, T, D+1)
 # Forward through temporal-aware head
 raw_logits = self.reward_head(temporal_input).squeeze(-1)  # (B, T)
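As a shape sanity check with dummy sizes (not the repo's actual dimensions), the temporal-position construction above works out as follows:

import torch

B, T_pred, D = 2, 5, 8  # dummy sizes for illustration
video_frame_embeds = torch.randn(B, T_pred, D)
temporal_pos = torch.linspace(0, 1, T_pred)              # (T,)
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1)   # (1, T, 1)
temporal_pos = temporal_pos.expand(B, T_pred, 1)         # (B, T, 1)
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
print(temporal_input.shape)  # torch.Size([2, 5, 9]), i.e. (B, T, D+1)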
@@ -414,10 +398,9 @@ class RLearNPolicy(PreTrainedPolicy):
 # Check frame-to-frame differences in raw input
 if T > 1:
-    raw_frame_diffs = torch.norm(
-        frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :],
-        dim=(2, 3, 4)  # Across C, H, W
-    ).mean()
+    # FIXED: Use proper tensor operations for difference calculation
+    frame_diffs = (frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :]).pow(2).sum(dim=(2, 3, 4)).sqrt()
+    raw_frame_diffs = frame_diffs.mean()
     print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
     if raw_frame_diffs < 0.001:
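For context, the hand-rolled pow/sum/sqrt above is simply the per-frame L2 norm; a quick check on dummy data (not repo code) that it matches torch.linalg.vector_norm:

import torch

frames = torch.randn(2, 4, 3, 8, 8)  # dummy (B, T, C, H, W) tensor
diff = frames[:, 1:] - frames[:, :-1]
manual = diff.pow(2).sum(dim=(2, 3, 4)).sqrt()
builtin = torch.linalg.vector_norm(diff, dim=(2, 3, 4))
print(torch.allclose(manual, builtin, atol=1e-5))  # True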
@@ -428,10 +411,9 @@ class RLearNPolicy(PreTrainedPolicy):
 # Check processed pixel values
 first_sample_pixels = inputs['pixel_values'][:T]  # First sample's pixels
 if T > 1:
-    pixel_diffs = torch.norm(
-        first_sample_pixels[1:] - first_sample_pixels[:-1],
-        dim=(1, 2, 3)  # Across C, H, W
-    ).mean()
+    # FIXED: Use proper tensor operations
+    pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
+    pixel_diffs = pixel_frame_diffs.mean()
     print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
     if pixel_diffs < 0.001:
@@ -441,16 +423,17 @@ class RLearNPolicy(PreTrainedPolicy):
 # Check if all samples in batch have same first frame
 if B > 1:
-    batch_first_frame_diff = torch.norm(
-        inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0),
-        dim=(1, 2, 3)
-    ).mean()
-    print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
-    if batch_first_frame_diff < 0.001:
-        print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
-    else:
-        print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
+    # FIXED: Use proper tensor operations
+    batch_first_frames = inputs['pixel_values'][::T]  # Every T-th frame (first frame of each sample)
+    if len(batch_first_frames) > 1:
+        first_frame_diffs = (batch_first_frames[1:] - batch_first_frames[0].unsqueeze(0)).pow(2).sum(dim=(1, 2, 3)).sqrt()
+        batch_first_frame_diff = first_frame_diffs.mean()
+        print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
+        if batch_first_frame_diff < 0.001:
+            print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
+        else:
+            print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
 # Check feature statistics
 feature_mean = vision_features.mean().item()
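A small illustration (dummy tensor, not repo code) of why the [::T] slice picks out each sample's first frame, assuming pixel_values is flattened to (B*T, C, H, W) as the comment above implies:

import torch

B, T, C, H, W = 3, 4, 3, 8, 8
# Frame i of sample b is filled with the value b*T + i, so frames are easy to identify.
pixel_values = torch.arange(B * T).view(B * T, 1, 1, 1).expand(B * T, C, H, W).float()
first_frames = pixel_values[::T]    # shape (B, C, H, W)
print(first_frames[:, 0, 0, 0])     # tensor([0., 4., 8.]), i.e. frames 0, T, 2*T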
@@ -622,15 +605,14 @@ class RLearNPolicy(PreTrainedPolicy):
 # During inference, we might not want to compute loss
 if not self.training and target is None:
     # Return predictions without loss using temporal-aware head
-    normalized_embeds = self.pre_reward_norm(video_frame_embeds)
     # Add temporal position information
-    B_inf, T_inf = normalized_embeds.shape[:2]
-    temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device)
+    B_inf, T_inf = video_frame_embeds.shape[:2]
+    temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device)
     temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
     # Concatenate and forward through temporal-aware head
-    temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)
+    temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
     raw_logits = self.reward_head(temporal_input).squeeze(-1)
     rewards = torch.sigmoid(raw_logits)
     return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
@@ -639,24 +621,30 @@ class RLearNPolicy(PreTrainedPolicy):
 loss_start = time.perf_counter()
 # Get model outputs with temporal-aware head
-normalized_embeds = self.pre_reward_norm(video_frame_embeds)
 # Add temporal position information
-temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device)
+temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device)
 temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1)  # (B, T_eff, 1)
 # Concatenate embeddings with temporal position
-temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)  # (B, T_eff, D+1)
+temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)  # (B, T_eff, D+1)
 # Forward through temporal-aware head
 raw_logits = self.reward_head(temporal_input).squeeze(-1)  # (B, T_eff)
-# Logit regression: transform targets to logit space and compute MSE on logits
+# FIXED: More robust logit regression with gradient protection
 eps = self.config.logit_eps
 target_expanded = target.expand(B, -1)[:, :T_eff]  # Expand and trim to T_eff
 target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
 target_logits = torch.logit(target_clamped)
-loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
+# Use Smooth L1 loss instead of MSE for better gradient stability
+loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0)
+# Clip gradients specifically for the reward head during backward pass
+# This prevents extreme gradients from corrupting AdamW momentum
+if self.training:
+    raw_logits.register_hook(lambda grad: torch.clamp(grad, -10.0, 10.0))
 # For logging, compute sigmoid predictions
 predicted_rewards = torch.sigmoid(raw_logits)
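Two quick illustrations with dummy values (not repo code): for errors larger than beta, Smooth L1 caps the per-element gradient magnitude at 1 where MSE's gradient keeps growing, and register_hook clamps the gradient flowing back into the head's output during backward:

import torch
import torch.nn.functional as F

pred = torch.tensor([0.0, 0.0], requires_grad=True)
target = torch.tensor([0.5, 8.0])   # the second entry plays the role of an extreme logit target
F.mse_loss(pred, target, reduction='mean').backward()
print(pred.grad)                    # tensor([-0.5000, -8.0000]) -> outlier dominates
pred.grad = None
F.smooth_l1_loss(pred, target, reduction='mean', beta=1.0).backward()
print(pred.grad)                    # tensor([-0.2500, -0.5000]) -> outlier gradient is bounded

# The register_hook pattern: clamp the gradient at an intermediate tensor.
x = torch.tensor([3.0], requires_grad=True)
y = x * 100.0
y.register_hook(lambda g: torch.clamp(g, -10.0, 10.0))
(y ** 2).sum().backward()           # dL/dy = 2*y = 600, clamped to 10, so dL/dx = 10 * 100
print(x.grad)                       # tensor([1000.])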
@@ -698,15 +686,14 @@ class RLearNPolicy(PreTrainedPolicy):
 mismatch_embeds = self.mlp_predictor(mismatch_tokens)
 # Predict near-zero progress for mismatched pairs with temporal awareness
-normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
 # Add temporal position information for mismatch computation
-T_mismatch = normalized_mismatch_embeds.shape[1]
-temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device)
+T_mismatch = mismatch_embeds.shape[1]
+temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device)
 temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
 # Concatenate mismatch embeddings with temporal position
-temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1)
+temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1)
 # Forward through temporal-aware head
 mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
@@ -977,7 +964,7 @@ class RLearNPolicy(PreTrainedPolicy):
 B = len(episode_indices)
 T = self.config.max_seq_len
-# Get raw image data
+# Get raw image data - this contains the window of frames provided by the dataset
 raw_frames = extract_visual_sequence(batch, target_seq_len=None)
 available_T = raw_frames.shape[1]
@@ -1044,13 +1031,19 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
if unique_indices == 1:
print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
elif unique_indices < T // 2:
print(f" ⚠️ TOO FEW UNIQUE INDICES! Only {unique_indices} unique frames")
else:
print(f" ✓ Good frame diversity: {unique_indices} unique frames")
# Check frame tensor differences
if len(frame_tensors) > 1:
frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item()
print(f"First frame difference: {frame_diff:.6f}")
frame_diff = (frame_tensors[1] - frame_tensors[0]).pow(2).sum().sqrt().item()
print(f"First vs second frame difference: {frame_diff:.6f}")
if frame_diff < 0.001:
print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
else:
print(f" ✓ Frames are different")
print("-" * 50)
frames = torch.stack(sampled_frames, dim=0)
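Finally, a minimal sketch of what the closing stack produces, assuming sampled_frames is a list of per-sample (T, C, H, W) tensors, which is consistent with how frames is indexed earlier in this file:

import torch

T, C, H, W = 4, 3, 8, 8
sampled_frames = [torch.randn(T, C, H, W) for _ in range(2)]  # one entry per batch sample (assumption)
frames = torch.stack(sampled_frames, dim=0)
print(frames.shape)  # torch.Size([2, 4, 3, 8, 8]) -> (B, T, C, H, W)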