From e3306951c011de6ddede8158668361781b5f2eb9 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Sun, 31 Aug 2025 19:18:52 +0200
Subject: [PATCH] debug frames

---
 .../policies/rlearn/configuration_rlearn.py |  12 +-
 .../policies/rlearn/modeling_rlearn.py      | 123 +++++++++---------
 2 files changed, 64 insertions(+), 71 deletions(-)

diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index 32239bb40..c60eb5b4a 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -71,15 +71,15 @@ class RLearNConfig(PreTrainedConfig):
     rewind_last3_prob: float = 0.3
     mismatch_prob: float = 0.2
 
-    # Logit regression (only supported mode)
-    logit_eps: float = 1e-6
+    # Logit regression (only supported mode) - FIXED: larger eps to prevent extreme targets
+    logit_eps: float = 0.02  # was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9)
     head_lr_multiplier: float = 2.0
     head_weight_init_std: float = 0.05
 
-    # Reward head architecture
-    head_hidden_dim: int = 1024  # Hidden dimension for reward head
-    head_num_layers: int = 4  # Number of layers in reward head
-    head_dropout: float = 0.1  # Dropout in reward head
+    # Reward head architecture - FIXED: simpler architecture to prevent flat basins
+    head_hidden_dim: int = 1024  # Hidden dimension for reward head
+    head_num_layers: int = 2  # REDUCED: 2 layers instead of 4 to prevent over-regularization
+    head_dropout: float = 0.05  # REDUCED: less dropout to prevent overly conservative outputs
 
     # Normalization presets
     normalization_mapping: dict[str, NormalizationMode] = field(
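For intuition on the logit_eps change: the logit values quoted in the comment come straight from torch.logit, and the old eps put clamped 0/1 targets far into the saturated tail. A quick standalone check (not part of the patch):

    import torch

    # Clamped targets under the old and new eps, before the logit-space regression
    print(torch.logit(torch.tensor([1e-6, 1 - 1e-6])))  # ≈ tensor([-13.8, 13.8])
    print(torch.logit(torch.tensor([0.02, 0.98])))      # ≈ tensor([-3.9, 3.9])

Regressing raw logits onto targets at ±13.8 forces very large head outputs for near-0/near-1 progress labels; ±3.9 keeps targets in a range a freshly initialized head can actually reach.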
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index 1c1bdfb66..dfd0b961f 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -184,38 +184,23 @@ class RLearNPolicy(PreTrainedPolicy):
             depth=config.mlp_predictor_depth
         )
 
-        # Layer normalization before reward head to stabilize MLP outputs
-        self.pre_reward_norm = nn.LayerNorm(config.dim_model)
-
-        # Temporal-aware regression head with increased capacity
-        # Build a deeper MLP for better visual-progress learning
-        head_layers = []
-
-        # Input layer: embedding + temporal position -> hidden
-        head_layers.extend([
-            nn.Linear(config.dim_model + 1, config.head_hidden_dim),  # +1 for temporal position
-            nn.ReLU(),
-            nn.Dropout(config.head_dropout)
-        ])
-
-        # Hidden layers: multiple layers for complex visual-progress mapping
-        for _ in range(config.head_num_layers - 2):  # -2 for input and output layers
-            head_layers.extend([
-                nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
-                nn.ReLU(),
-                nn.Dropout(config.head_dropout)
-            ])
-
-        # Output layer: hidden -> logit
-        head_layers.append(nn.Linear(config.head_hidden_dim, 1))
-
-        self.reward_head = nn.Sequential(*head_layers)
-
-        # Initialize the deeper temporal-aware head for logit regression
+        # FIXED: Simpler head architecture to prevent the constant-output pathology:
+        # drop the pre-head LayerNorm (it encourages a flat basin), reduce depth,
+        # and lower dropout. Simple 2-layer MLP; temporal position is appended to the embedding.
+        self.reward_head = nn.Sequential(
+            nn.Linear(config.dim_model + 1, config.head_hidden_dim),  # +1 for temporal position
+            nn.ReLU(),
+            nn.Dropout(config.head_dropout),  # reduced dropout to prevent noise-induced conservatism
+            nn.Linear(config.head_hidden_dim, 1)
+        )
+
+        # FIXED: Re-initialize the head for logit regression
         with torch.no_grad():
             for module in self.reward_head:
                 if isinstance(module, nn.Linear):
-                    nn.init.normal_(module.weight, 0.0, config.head_weight_init_std)
+                    # Xavier/Glorot initialization for better gradient flow
+                    nn.init.xavier_uniform_(module.weight, gain=1.0)
                     nn.init.zeros_(module.bias)
 
         # Simple frame dropout probability
@@ -335,16 +320,15 @@ class RLearNPolicy(PreTrainedPolicy):
         # MLP predictor
         video_frame_embeds = self.mlp_predictor(frame_tokens)
 
-        # Get rewards via temporal-aware logit regression head
-        normalized_embeds = self.pre_reward_norm(video_frame_embeds)
+        # Get rewards via temporal-aware logit regression head (no pre-normalization)
 
         # Add temporal position information
-        B, T_pred = normalized_embeds.shape[:2]
-        temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device)
+        B, T_pred = video_frame_embeds.shape[:2]
+        temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device)
         temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1)  # (B, T, 1)
 
         # Concatenate embeddings with temporal position
-        temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)  # (B, T, D+1)
+        temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)  # (B, T, D+1)
 
         # Forward through temporal-aware head
         raw_logits = self.reward_head(temporal_input).squeeze(-1)  # (B, T)
@@ -414,10 +398,9 @@ class RLearNPolicy(PreTrainedPolicy):
         # Check frame-to-frame differences in raw input
         if T > 1:
-            raw_frame_diffs = torch.norm(
-                frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :],
-                dim=(2, 3, 4)  # Across C, H, W
-            ).mean()
+            # FIXED: Compute per-pair L2 distances explicitly across C, H, W
+            frame_diffs = (frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :]).pow(2).sum(dim=(2, 3, 4)).sqrt()
+            raw_frame_diffs = frame_diffs.mean()
             print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
 
             if raw_frame_diffs < 0.001:
@@ -428,10 +411,9 @@ class RLearNPolicy(PreTrainedPolicy):
         # Check processed pixel values
         first_sample_pixels = inputs['pixel_values'][:T]  # First sample's pixels
         if T > 1:
-            pixel_diffs = torch.norm(
-                first_sample_pixels[1:] - first_sample_pixels[:-1],
-                dim=(1, 2, 3)  # Across C, H, W
-            ).mean()
+            # FIXED: Compute per-pair L2 distances explicitly across C, H, W
+            pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
+            pixel_diffs = pixel_frame_diffs.mean()
             print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
 
             if pixel_diffs < 0.001:
@@ -441,16 +423,17 @@ class RLearNPolicy(PreTrainedPolicy):
         # Check if all samples in batch have same first frame
         if B > 1:
-            batch_first_frame_diff = torch.norm(
-                inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0),
-                dim=(1, 2, 3)
-            ).mean()
-            print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
-
-            if batch_first_frame_diff < 0.001:
-                print(f"  ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
-            else:
-                print(f"  ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
+            # FIXED: Exclude the zero self-difference that the old mean included
+            batch_first_frames = inputs['pixel_values'][::T]  # Every T-th frame (first frame of each sample)
+            if len(batch_first_frames) > 1:
+                first_frame_diffs = (batch_first_frames[1:] - batch_first_frames[0].unsqueeze(0)).pow(2).sum(dim=(1, 2, 3)).sqrt()
+                batch_first_frame_diff = first_frame_diffs.mean()
+                print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
+
+                if batch_first_frame_diff < 0.001:
+                    print(f"  ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
+                else:
+                    print(f"  ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
 
         # Check feature statistics
         feature_mean = vision_features.mean().item()
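The consecutive-frame L2 distance above is the core of all three debug checks; as a standalone sketch (the helper name mean_consecutive_frame_distance is hypothetical, not something the patch adds):

    import torch

    def mean_consecutive_frame_distance(frames: torch.Tensor) -> float:
        """Mean L2 distance between consecutive frames of a (B, T, C, H, W) tensor."""
        diffs = frames[:, 1:] - frames[:, :-1]           # (B, T-1, C, H, W)
        dists = diffs.pow(2).sum(dim=(2, 3, 4)).sqrt()   # (B, T-1), one distance per frame pair
        return dists.mean().item()

    frames = torch.rand(2, 8, 3, 224, 224)
    print(mean_consecutive_frame_distance(frames))  # near zero would mean static or duplicated frames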
@@ -622,15 +605,14 @@ class RLearNPolicy(PreTrainedPolicy):
         # During inference, we might not want to compute loss
         if not self.training and target is None:
             # Return predictions without loss using temporal-aware head
-            normalized_embeds = self.pre_reward_norm(video_frame_embeds)
 
             # Add temporal position information
-            B_inf, T_inf = normalized_embeds.shape[:2]
-            temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device)
+            B_inf, T_inf = video_frame_embeds.shape[:2]
+            temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device)
             temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
 
             # Concatenate and forward through temporal-aware head
-            temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)
+            temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
             raw_logits = self.reward_head(temporal_input).squeeze(-1)
             rewards = torch.sigmoid(raw_logits)
             return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
@@ -639,24 +621,30 @@ class RLearNPolicy(PreTrainedPolicy):
         loss_start = time.perf_counter()
 
         # Get model outputs with temporal-aware head
-        normalized_embeds = self.pre_reward_norm(video_frame_embeds)
 
         # Add temporal position information
-        temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device)
+        temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device)
         temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1)  # (B, T_eff, 1)
 
         # Concatenate embeddings with temporal position
-        temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)  # (B, T_eff, D+1)
+        temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)  # (B, T_eff, D+1)
 
         # Forward through temporal-aware head
         raw_logits = self.reward_head(temporal_input).squeeze(-1)  # (B, T_eff)
 
-        # Logit regression: transform targets to logit space and compute MSE on logits
+        # FIXED: More robust logit regression with gradient protection
         eps = self.config.logit_eps
         target_expanded = target.expand(B, -1)[:, :T_eff]  # Expand and trim to T_eff
         target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
         target_logits = torch.logit(target_clamped)
-        loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
+
+        # Use Smooth L1 loss instead of MSE for better gradient stability
+        loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0)
+
+        # Clamp the gradient w.r.t. the logits during backward; this caps extreme
+        # gradients before they reach AdamW's moment estimates
+        if self.training:
+            raw_logits.register_hook(lambda grad: torch.clamp(grad, -10.0, 10.0))
 
         # For logging, compute sigmoid predictions
         predicted_rewards = torch.sigmoid(raw_logits)
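The register_hook call clamps the gradient with respect to raw_logits at backward time, so everything upstream of the head output sees a bounded gradient. A toy demonstration of the mechanism (standalone, not the patch's code):

    import torch

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    y = x * 1000.0                                          # stand-in for raw_logits
    y.register_hook(lambda g: torch.clamp(g, -10.0, 10.0))  # cap dL/dy during backward
    loss = (y ** 2).sum()
    loss.backward()
    print(x.grad)  # tensor([10000., 10000.]): dL/dy = [2000., 4000.] was clamped to 10 first

Note this bounds only the gradient flowing through the logits; optimizer-level clipping (e.g. torch.nn.utils.clip_grad_norm_) would still be the usual belt-and-braces on top.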
@@ -698,15 +686,14 @@ class RLearNPolicy(PreTrainedPolicy):
             mismatch_embeds = self.mlp_predictor(mismatch_tokens)
 
             # Predict near-zero progress for mismatched pairs with temporal awareness
-            normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
 
             # Add temporal position information for mismatch computation
-            T_mismatch = normalized_mismatch_embeds.shape[1]
-            temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device)
+            T_mismatch = mismatch_embeds.shape[1]
+            temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device)
             temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
 
             # Concatenate mismatch embeddings with temporal position
-            temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1)
+            temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1)
 
             # Forward through temporal-aware head
             mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
@@ -977,7 +964,7 @@ class RLearNPolicy(PreTrainedPolicy):
         B = len(episode_indices)
         T = self.config.max_seq_len
 
-        # Get raw image data
+        # Get raw image data - this contains the window of frames provided by the dataset
        raw_frames = extract_visual_sequence(batch, target_seq_len=None)
         available_T = raw_frames.shape[1]
@@ -1044,13 +1031,19 @@ class RLearNPolicy(PreTrainedPolicy):
             print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
             if unique_indices == 1:
                 print(f"  ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
+            elif unique_indices < T // 2:
+                print(f"  ⚠️ TOO FEW UNIQUE INDICES! Only {unique_indices} unique frames")
+            else:
+                print(f"  ✓ Good frame diversity: {unique_indices} unique frames")
 
             # Check frame tensor differences
             if len(frame_tensors) > 1:
-                frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item()
-                print(f"First frame difference: {frame_diff:.6f}")
+                frame_diff = (frame_tensors[1] - frame_tensors[0]).pow(2).sum().sqrt().item()
+                print(f"First vs second frame difference: {frame_diff:.6f}")
                 if frame_diff < 0.001:
                     print(f"  ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
+                else:
+                    print(f"  ✓ Frames are different")
             print("-" * 50)
 
         frames = torch.stack(sampled_frames, dim=0)
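End to end, the reworked prediction path is small: append a normalized time index to each frame embedding, run the 2-layer head, and sigmoid the logits. A minimal standalone sketch under assumed sizes (dim_model=512, head_hidden_dim=1024; it mirrors the patch but is not the repo's module):

    import torch
    import torch.nn as nn

    dim_model, hidden = 512, 1024  # assumed config values
    head = nn.Sequential(
        nn.Linear(dim_model + 1, hidden),  # +1 for the temporal position
        nn.ReLU(),
        nn.Dropout(0.05),
        nn.Linear(hidden, 1),
    )

    embeds = torch.randn(4, 16, dim_model)                     # (B, T, D) frame embeddings
    B, T = embeds.shape[:2]
    t = torch.linspace(0, 1, T).view(1, T, 1).expand(B, T, 1)  # (B, T, 1) normalized time
    logits = head(torch.cat([embeds, t], dim=-1)).squeeze(-1)  # (B, T) raw logits
    rewards = torch.sigmoid(logits)                            # per-frame progress in (0, 1)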