mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
debug frames
This commit is contained in:
@@ -71,15 +71,15 @@ class RLearNConfig(PreTrainedConfig):
|
||||
rewind_last3_prob: float = 0.3
|
||||
mismatch_prob: float = 0.2
|
||||
|
||||
# Logit regression (only supported mode)
|
||||
logit_eps: float = 1e-6
|
||||
# Logit regression (only supported mode) - FIXED: Larger eps to prevent extreme targets
|
||||
logit_eps: float = 0.02 # Was 1e-6 → logit(±13.8), now 0.02 → logit(±3.9)
|
||||
head_lr_multiplier: float = 2.0
|
||||
head_weight_init_std: float = 0.05
|
||||
|
||||
# Reward head architecture
|
||||
head_hidden_dim: int = 1024 # Hidden dimension for reward head
|
||||
head_num_layers: int = 4 # Number of layers in reward head
|
||||
head_dropout: float = 0.1 # Dropout in reward head
|
||||
# Reward head architecture - FIXED: Simpler architecture to prevent flat basins
|
||||
head_hidden_dim: int = 1024 # Hidden dimension for reward head
|
||||
head_num_layers: int = 2 # REDUCED: 2 layers instead of 4 to prevent over-regularization
|
||||
head_dropout: float = 0.05 # REDUCED: Less dropout to prevent conservatism
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
|
||||
@@ -184,38 +184,23 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
depth=config.mlp_predictor_depth
|
||||
)
|
||||
|
||||
# Layer normalization before reward head to stabilize MLP outputs
|
||||
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
|
||||
# FIXED: Simpler head architecture to prevent constant output pathology
|
||||
# Remove LayerNorm (causes flat basin), reduce depth, larger init, less dropout
|
||||
|
||||
# Temporal-aware regression head with increased capacity
|
||||
# Build a deeper MLP for better visual-progress learning
|
||||
head_layers = []
|
||||
|
||||
# Input layer: embedding + temporal position -> hidden
|
||||
head_layers.extend([
|
||||
nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
|
||||
# Simple 2-layer MLP with larger initialization to encourage exploration
|
||||
self.reward_head = nn.Sequential(
|
||||
nn.Linear(config.dim_model + 1, config.head_hidden_dim), # +1 for temporal position
|
||||
nn.ReLU(),
|
||||
nn.Dropout(config.head_dropout)
|
||||
])
|
||||
nn.Dropout(0.05), # Reduced dropout to prevent noise-induced conservatism
|
||||
nn.Linear(config.head_hidden_dim, 1)
|
||||
)
|
||||
|
||||
# Hidden layers: multiple layers for complex visual-progress mapping
|
||||
for _ in range(config.head_num_layers - 2): # -2 for input and output layers
|
||||
head_layers.extend([
|
||||
nn.Linear(config.head_hidden_dim, config.head_hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(config.head_dropout)
|
||||
])
|
||||
|
||||
# Output layer: hidden -> logit
|
||||
head_layers.append(nn.Linear(config.head_hidden_dim, 1))
|
||||
|
||||
self.reward_head = nn.Sequential(*head_layers)
|
||||
|
||||
# Initialize the deeper temporal-aware head for logit regression
|
||||
# FIXED: Larger weight initialization to escape flat basins
|
||||
with torch.no_grad():
|
||||
for module in self.reward_head:
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, 0.0, config.head_weight_init_std)
|
||||
# Use Xavier/Glorot initialization for better gradient flow
|
||||
nn.init.xavier_uniform_(module.weight, gain=1.0)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
# Simple frame dropout probability
|
||||
@@ -335,16 +320,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(frame_tokens)
|
||||
|
||||
# Get rewards via temporal-aware logit regression head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
# Get rewards via temporal-aware logit regression head (no pre-normalization)
|
||||
|
||||
# Add temporal position information
|
||||
B, T_pred = normalized_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device)
|
||||
B, T_pred = video_frame_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_pred, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T, D+1)
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T)
|
||||
@@ -414,10 +398,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Check frame-to-frame differences in raw input
|
||||
if T > 1:
|
||||
raw_frame_diffs = torch.norm(
|
||||
frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :],
|
||||
dim=(2, 3, 4) # Across C, H, W
|
||||
).mean()
|
||||
# FIXED: Use proper tensor operations for difference calculation
|
||||
frame_diffs = (frames[:, 1:, :, :, :] - frames[:, :-1, :, :, :]).pow(2).sum(dim=(2, 3, 4)).sqrt()
|
||||
raw_frame_diffs = frame_diffs.mean()
|
||||
print(f"Raw input frame differences: {raw_frame_diffs:.6f}")
|
||||
|
||||
if raw_frame_diffs < 0.001:
|
||||
@@ -428,10 +411,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Check processed pixel values
|
||||
first_sample_pixels = inputs['pixel_values'][:T] # First sample's pixels
|
||||
if T > 1:
|
||||
pixel_diffs = torch.norm(
|
||||
first_sample_pixels[1:] - first_sample_pixels[:-1],
|
||||
dim=(1, 2, 3) # Across C, H, W
|
||||
).mean()
|
||||
# FIXED: Use proper tensor operations
|
||||
pixel_frame_diffs = (first_sample_pixels[1:] - first_sample_pixels[:-1]).pow(2).sum(dim=(1, 2, 3)).sqrt()
|
||||
pixel_diffs = pixel_frame_diffs.mean()
|
||||
print(f"Processed pixel_values differences: {pixel_diffs:.6f}")
|
||||
|
||||
if pixel_diffs < 0.001:
|
||||
@@ -441,16 +423,17 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Check if all samples in batch have same first frame
|
||||
if B > 1:
|
||||
batch_first_frame_diff = torch.norm(
|
||||
inputs['pixel_values'][::T] - inputs['pixel_values'][0].unsqueeze(0),
|
||||
dim=(1, 2, 3)
|
||||
).mean()
|
||||
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
|
||||
|
||||
if batch_first_frame_diff < 0.001:
|
||||
print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
|
||||
else:
|
||||
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
|
||||
# FIXED: Use proper tensor operations
|
||||
batch_first_frames = inputs['pixel_values'][::T] # Every T-th frame (first frame of each sample)
|
||||
if len(batch_first_frames) > 1:
|
||||
first_frame_diffs = (batch_first_frames[1:] - batch_first_frames[0].unsqueeze(0)).pow(2).sum(dim=(1, 2, 3)).sqrt()
|
||||
batch_first_frame_diff = first_frame_diffs.mean()
|
||||
print(f"Batch first-frame differences: {batch_first_frame_diff:.6f}")
|
||||
|
||||
if batch_first_frame_diff < 0.001:
|
||||
print(f" ⚠️ ALL BATCH SAMPLES HAVE SAME FIRST FRAME! Diff: {batch_first_frame_diff:.8f}")
|
||||
else:
|
||||
print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
|
||||
|
||||
# Check feature statistics
|
||||
feature_mean = vision_features.mean().item()
|
||||
@@ -622,15 +605,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
# Return predictions without loss using temporal-aware head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
|
||||
# Add temporal position information
|
||||
B_inf, T_inf = normalized_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device)
|
||||
B_inf, T_inf = video_frame_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_inf, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
|
||||
|
||||
# Concatenate and forward through temporal-aware head
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1)
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1)
|
||||
rewards = torch.sigmoid(raw_logits)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
@@ -639,24 +621,30 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
loss_start = time.perf_counter()
|
||||
|
||||
# Get model outputs with temporal-aware head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
|
||||
# Add temporal position information
|
||||
temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device)
|
||||
temporal_pos = torch.linspace(0, 1, T_eff, device=video_frame_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1)
|
||||
temporal_input = torch.cat([video_frame_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# Logit regression: transform targets to logit space and compute MSE on logits
|
||||
# FIXED: More robust logit regression with gradient protection
|
||||
eps = self.config.logit_eps
|
||||
target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff
|
||||
target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
|
||||
target_logits = torch.logit(target_clamped)
|
||||
loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
|
||||
|
||||
# Use Smooth L1 loss instead of MSE for better gradient stability
|
||||
loss = F.smooth_l1_loss(raw_logits, target_logits, reduction='mean', beta=1.0)
|
||||
|
||||
# Clip gradients specifically for the reward head during backward pass
|
||||
# This prevents extreme gradients from corrupting AdamW momentum
|
||||
if self.training:
|
||||
raw_logits.register_hook(lambda grad: torch.clamp(grad, -10.0, 10.0))
|
||||
|
||||
# For logging, compute sigmoid predictions
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
@@ -698,15 +686,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
|
||||
|
||||
# Predict near-zero progress for mismatched pairs with temporal awareness
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
|
||||
# Add temporal position information for mismatch computation
|
||||
T_mismatch = normalized_mismatch_embeds.shape[1]
|
||||
temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device)
|
||||
T_mismatch = mismatch_embeds.shape[1]
|
||||
temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=mismatch_embeds.device)
|
||||
temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
|
||||
|
||||
# Concatenate mismatch embeddings with temporal position
|
||||
temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1)
|
||||
temporal_input_mm = torch.cat([mismatch_embeds, temporal_pos_mm], dim=-1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
|
||||
@@ -977,7 +964,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
B = len(episode_indices)
|
||||
T = self.config.max_seq_len
|
||||
|
||||
# Get raw image data
|
||||
# Get raw image data - this contains the window of frames provided by the dataset
|
||||
raw_frames = extract_visual_sequence(batch, target_seq_len=None)
|
||||
available_T = raw_frames.shape[1]
|
||||
|
||||
@@ -1044,13 +1031,19 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"Unique window indices: {unique_indices} out of {len(window_indices)}")
|
||||
if unique_indices == 1:
|
||||
print(f" ⚠️ ALL WINDOW INDICES ARE THE SAME! Index: {window_indices[0]}")
|
||||
elif unique_indices < T // 2:
|
||||
print(f" ⚠️ TOO FEW UNIQUE INDICES! Only {unique_indices} unique frames")
|
||||
else:
|
||||
print(f" ✓ Good frame diversity: {unique_indices} unique frames")
|
||||
|
||||
# Check frame tensor differences
|
||||
if len(frame_tensors) > 1:
|
||||
frame_diff = torch.norm(frame_tensors[1] - frame_tensors[0]).item()
|
||||
print(f"First frame difference: {frame_diff:.6f}")
|
||||
frame_diff = (frame_tensors[1] - frame_tensors[0]).pow(2).sum().sqrt().item()
|
||||
print(f"First vs second frame difference: {frame_diff:.6f}")
|
||||
if frame_diff < 0.001:
|
||||
print(f" ⚠️ CONSECUTIVE SAMPLED FRAMES ARE NEARLY IDENTICAL!")
|
||||
else:
|
||||
print(f" ✓ Frames are different")
|
||||
print("-" * 50)
|
||||
|
||||
frames = torch.stack(sampled_frames, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user