debug frames

This commit is contained in:
Pepijn
2025-08-31 19:18:52 +02:00
parent 10e36f2453
commit e3306951c0
2 changed files with 64 additions and 71 deletions
@@ -71,15 +71,15 @@ class RLearNConfig(PreTrainedConfig):
rewind_last3_prob: float = 0.3 rewind_last3_prob: float = 0.3
mismatch_prob: float = 0.2 mismatch_prob: float = 0.2
# Logit regression (only supported mode) # Logit regression (only supported mode) - FIXED: Larger eps to prevent extreme targets
logit_eps: float = 1e-6 logit_eps: float = 0.02 # Was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9)
head_lr_multiplier: float = 2.0 head_lr_multiplier: float = 2.0
head_weight_init_std: float = 0.05 head_weight_init_std: float = 0.05
# Reward head architecture # Reward head architecture - FIXED: Simpler architecture to prevent flat basins
head_hidden_dim: int = 1024 # Hidden dimension for reward head head_hidden_dim: int = 1024 # Hidden dimension for reward head
head_num_layers: int = 4 # Number of layers in reward head head_num_layers: int = 2 # REDUCED: 2 layers instead of 4 to prevent over-regularization
head_dropout: float = 0.1 # Dropout in reward head head_dropout: float = 0.05 # REDUCED: Less dropout to prevent conservatism
# Normalization presets # Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
+58 -65
View File
@@ -184,38 +184,23 @@ class RLearNPolicy(PreTrainedPolicy):
depth=config.mlp_predictor_depth depth=config.mlp_predictor_depth
) )
# Layer normalization before reward head to stabilize MLP outputs # FIXED: Simpler head architecture to prevent constant output pathology
self.pre_reward_norm = nn.LayerNorm(config.dim_model) # Remove LayerNorm (causes flat basin), reduce depth, larger init, less dropout
# Temporal-aware regression head with increased capacity # Simple 2-layer MLP with larger initialization to encourage exploration
# Build a deeper MLP for better visual-progress learning self.reward_head = nn.Sequential(
head_layers = [] nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
# Input layer: embedding + temporal position -> hidden
head_layers.extend([
nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
nn.ReLU(), nn.ReLU(),
nn.Dropout(config.head_dropout) nn.Dropout(0.05), # Reduced dropout to prevent noise-induced conservatism
]) nn.Linear(config.head_hidden_dim, 1)
)
# Hidden layers: multiple layers for complex visual-progress mapping # FIXED: Larger weight initialization to escape flat basins
for _ in range(config.head_num_layers - 2): # -2 for input and output layers
head_layers.extend([
nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
nn.ReLU(),
nn.Dropout(config.head_dropout)
])
# Output layer: hidden -> logit
head_layers.append(nn.Linear(config.head_hidden_dim, 1))
self.reward_head = nn.Sequential(*head_layers)
# Initialize the deeper temporal-aware head for logit regression
with torch.no_grad(): with torch.no_grad():
for module in self.reward_head: for module in self.reward_head:
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, 0.0, config.head_weight_init_std) # Use Xavier/Glorot initialization for better gradient flow
nn.init.xavier_uniform_(module.weight, gain=1.0)
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
# Simple frame dropout probability # Simple frame dropout probability
@@ -335,16 +320,15 @@ class RLearNPolicy(PreTrainedPolicy):
# MLP predictor # MLP predictor
video_frame_embeds = self.mlp_predictor(frame_tokens) video_frame_embeds = self.mlp_predictor(frame_tokens)
# Get rewards via temporal-aware logit regression head # Get rewards via temporal-aware logit regression head (no pre-normalization)
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
# Add temporal position information # Add temporal position information
B, T_pred = normalized_embeds.shape[:2] B, T_pred = video_frame_embeds.shape[:2]
temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device) temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device)
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1) temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1)
# Concatenate embeddings with temporal position # Concatenate embeddings with temporal position
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T, D+1) temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T, D+1)
# Forward through temporal-aware head # Forward through temporal-aware head
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T) raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T)
@@ -414,10 +398,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Check frame-to-frame differences in raw input # Check frame-to-frame differences in raw input
if T > 1: if T > 1:
raw_frame_diffs = torch.norm( # FIXED: Use proper tensor operations for difference calculation
frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :], frame_diffs = (frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :]).pow(2).sum(dim=(2, 3, 4)).sqrt()
dim=(2, 3, 4) # Across C, H, W raw_frame_diffs = frame_diffs.mean()
).mean()
print(f"Raw input frame differences: {raw_frame_diffs:.6f}") print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
if raw_frame_diffs < 0.001: if raw_frame_diffs < 0.001:
@@ -428,10 +411,9 @@ class RLearNPolicy(PreTrainedPolicy):
# Check processed pixel values # Check processed pixel values
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
if T > 1: if T > 1:
pixel_diffs = torch.norm( # FIXED: Use proper tensor operations
first_sample_pixels[1:] - first_sample_pixels[:-1], pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
dim=(1, 2, 3) # Across C, H, W pixel_diffs = pixel_frame_diffs.mean()
).mean()
print(f"Processed pixel_values differences: {pixel_diffs:.6f}") print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
if pixel_diffs < 0.001: if pixel_diffs < 0.001:
@@ -441,16 +423,17 @@ class RLearNPolicy(PreTrainedPolicy):
# Check if all samples in batch have same first frame # Check if all samples in batch have same first frame
if B > 1: if B > 1:
batch_first_frame_diff = torch.norm( # FIXED: Use proper tensor operations
inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0), batch_first_frames = inputs['pixel_values'][::T] # Every T-th frame (first frame of each sample)
dim=(1, 2, 3) if len(batch_first_frames) > 1:
).mean() first_frame_diffs = (batch_first_frames[1:] - batch_first_frames[0].unsqueeze(0)).pow(2).sum(dim=(1, 2, 3)).sqrt()
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}") batch_first_frame_diff = first_frame_diffs.mean()
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
if batch_first_frame_diff < 0.001:
print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}") if batch_first_frame_diff < 0.001:
else: print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}") else:
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
# Check feature statistics # Check feature statistics
feature_mean = vision_features.mean().item() feature_mean = vision_features.mean().item()
@@ -622,15 +605,14 @@ class RLearNPolicy(PreTrainedPolicy):
# During inference, we might not want to compute loss # During inference, we might not want to compute loss
if not self.training and target is None: if not self.training and target is None:
# Return predictions without loss using temporal-aware head # Return predictions without loss using temporal-aware head
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
# Add temporal position information # Add temporal position information
B_inf, T_inf = normalized_embeds.shape[:2] B_inf, T_inf = video_frame_embeds.shape[:2]
temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device) temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device)
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1) temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
# Concatenate and forward through temporal-aware head # Concatenate and forward through temporal-aware head
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
raw_logits = self.reward_head(temporal_input).squeeze(-1) raw_logits = self.reward_head(temporal_input).squeeze(-1)
rewards = torch.sigmoid(raw_logits) rewards = torch.sigmoid(raw_logits)
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
@@ -639,24 +621,30 @@ class RLearNPolicy(PreTrainedPolicy):
loss_start = time.perf_counter() loss_start = time.perf_counter()
# Get model outputs with temporal-aware head # Get model outputs with temporal-aware head
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
# Add temporal position information # Add temporal position information
temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device) temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device)
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1) temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1)
# Concatenate embeddings with temporal position # Concatenate embeddings with temporal position
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1) temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1)
# Forward through temporal-aware head # Forward through temporal-aware head
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff) raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff)
# Logit regression: transform targets to logit space and compute MSE on logits # FIXED: More robust logit regression with gradient protection
eps = self.config.logit_eps eps = self.config.logit_eps
target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff
target_clamped = torch.clamp(target_expanded, eps, 1 - eps) target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
target_logits = torch.logit(target_clamped) target_logits = torch.logit(target_clamped)
loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
# Use Smooth L1 loss instead of MSE for better gradient stability
loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0)
# Clip gradients specifically for the reward head during backward pass
# This prevents extreme gradients from corrupting AdamW momentum
if self.training:
raw_logits.register_hook(lambda grad: torch.clamp(grad, -10.0, 10.0))
# For logging, compute sigmoid predictions # For logging, compute sigmoid predictions
predicted_rewards = torch.sigmoid(raw_logits) predicted_rewards = torch.sigmoid(raw_logits)
@@ -698,15 +686,14 @@ class RLearNPolicy(PreTrainedPolicy):
mismatch_embeds = self.mlp_predictor(mismatch_tokens) mismatch_embeds = self.mlp_predictor(mismatch_tokens)
# Predict near-zero progress for mismatched pairs with temporal awareness # Predict near-zero progress for mismatched pairs with temporal awareness
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
# Add temporal position information for mismatch computation # Add temporal position information for mismatch computation
T_mismatch = normalized_mismatch_embeds.shape[1] T_mismatch = mismatch_embeds.shape[1]
temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device) temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device)
temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1) temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
# Concatenate mismatch embeddings with temporal position # Concatenate mismatch embeddings with temporal position
temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1) temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1)
# Forward through temporal-aware head # Forward through temporal-aware head
mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1) mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
@@ -977,7 +964,7 @@ class RLearNPolicy(PreTrainedPolicy):
B = len(episode_indices) B = len(episode_indices)
T = self.config.max_seq_len T = self.config.max_seq_len
# Get raw image data # Get raw image data - this contains the window of frames provided by the dataset
raw_frames = extract_visual_sequence(batch, target_seq_len=None) raw_frames = extract_visual_sequence(batch, target_seq_len=None)
available_T = raw_frames.shape[1] available_T = raw_frames.shape[1]
@@ -1044,13 +1031,19 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Unique window indices: {unique_indices} out of {len(window_indices)}") print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
if unique_indices == 1: if unique_indices == 1:
print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}") print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
elif unique_indices < T // 2:
print(f" ⚠️ TOO FEW UNIQUE INDICES! Only {unique_indices} unique frames")
else:
print(f" ✓ Good frame diversity: {unique_indices} unique frames")
# Check frame tensor differences # Check frame tensor differences
if len(frame_tensors) > 1: if len(frame_tensors) > 1:
frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item() frame_diff = (frame_tensors[1] - frame_tensors[0]).pow(2).sum().sqrt().item()
print(f"First frame difference: {frame_diff:.6f}") print(f"First vs second frame difference: {frame_diff:.6f}")
if frame_diff < 0.001: if frame_diff < 0.001:
print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!") print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
else:
print(f" ✓ Frames are different")
print("-" * 50) print("-" * 50)
frames = torch.stack(sampled_frames, dim=0) frames = torch.stack(sampled_frames, dim=0)