This commit is contained in:
Pepijn
2025-08-31 15:40:00 +02:00
parent def71cc439
commit 1e1b010257
2 changed files with 140 additions and 427 deletions
@@ -60,59 +60,32 @@ class RLearNConfig(PreTrainedConfig):
learning_rate: float = 1e-3 learning_rate: float = 1e-3
weight_decay: float = 0.01 weight_decay: float = 0.01
# Backward compatibility field (not used in current implementation)
use_tanh_head: bool = False
# Performance optimizations # Performance optimizations
use_amp: bool = True # Mixed precision training for speed boost use_amp: bool = True
compile_model: bool = True # torch.compile for additional speedup compile_model: bool = True
# ReWiND-specific parameters # ReWiND augmentation
use_video_rewind: bool = True # Enable video rewinding augmentation rewind_prob: float = 0.5
rewind_prob: float = 0.5 # Reduced from 0.8 to avoid too many artifacts rewind_last3_prob: float = 0.3
rewind_last3_prob: float = 0.3 # Increased to favor smaller rewinds mismatch_prob: float = 0.2
use_mismatch_loss: bool = False # Enable mismatched language-video loss
mismatch_prob: float = (
0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
)
# NEW: Loss and head improvements # Logit regression (only supported mode)
use_logit_regression: bool = True # Use logit space regression instead of sigmoid+MSE logit_eps: float = 1e-6
logit_eps: float = 1e-6 # Clipping epsilon for logit transform: logit(clamp(target, eps, 1-eps)) head_lr_multiplier: float = 2.0
head_lr_multiplier: float = 2.0 # Increase head learning rate relative to base head_weight_init_std: float = 0.05
head_weight_init_std: float = 0.05 # Larger head weight initialization for faster positive logits
remove_head_bias_wd: bool = True # Remove weight decay from head bias
# Window sampling improvements
use_random_anchor_sampling: bool = True # Use explicit random anchor sampling during training
use_window_relative_progress: bool = True # Use window-relative progress (0-1 across window) instead of episode-relative
# Loss hyperparameters (simplified for ReWiND)
# The main loss is just MSE between predicted and target progress
# Normalization presets # Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: { default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD, "VISUAL": NormalizationMode.MEAN_STD,
# Language is tokenized at the encoder level; no numeric normalization here.
} }
) )
# Architectural knobs to better mirror ReWiND # Architecture
num_register_tokens: int = 4 # register / memory tokens, can't hurt num_register_tokens: int = 4
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head mlp_predictor_depth: int = 3
# Loss configuration - supports both sigmoid+MSE and logit regression
# Evaluation visualization parameters # Required path to episodes.jsonl for episode boundaries
enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training
eval_visualization_freq: int = 1000 # Steps between evaluation visualizations
eval_holdout_episodes: int = 9 # Number of episodes to hold out for evaluation
eval_max_frames: int = 128 # Maximum frames per episode for evaluation
eval_visualization_seed: int = 42 # Seed for reproducible episode selection
# Optional: path to episodes.jsonl to build full-episode indices automatically
# Default to common dataset layout: <dataset_root>/meta/episodes.jsonl
episodes_jsonl_path: str | None = "meta/episodes.jsonl" episodes_jsonl_path: str | None = "meta/episodes.jsonl"
def validate_features(self) -> None: def validate_features(self) -> None:
+126 -386
View File
@@ -165,15 +165,8 @@ class RLearNPolicy(PreTrainedPolicy):
# Stronger temporal positional encoding # Stronger temporal positional encoding
self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1) self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1)
# CRITICAL: Frame-specific MLPs prevent temporal over-smoothing # Single MLP processes all frames
# Problem: Transformer attention was making all 16 predictions identical (e.g. all 0.34) self.frame_mlp = nn.Linear(config.dim_model, config.dim_model)
# Solution: Each temporal position gets its own specialized MLP processing
# Frame 0 → MLP[0], Frame 1 → MLP[1], ..., Frame 15 → MLP[15]
# This creates distinct pathways for each frame while preserving attention context
self.frame_specific_mlp = nn.ModuleList([
nn.Linear(config.dim_model, config.dim_model)
for _ in range(config.max_seq_len) # 16 separate MLPs for 16 frame positions
])
# Register / memory / attention sink tokens # Register / memory / attention sink tokens
self.num_register_tokens = config.num_register_tokens self.num_register_tokens = config.num_register_tokens
@@ -190,21 +183,13 @@ class RLearNPolicy(PreTrainedPolicy):
# Layer normalization before reward head to stabilize MLP outputs # Layer normalization before reward head to stabilize MLP outputs
self.pre_reward_norm = nn.LayerNorm(config.dim_model) self.pre_reward_norm = nn.LayerNorm(config.dim_model)
# Regression head - supports both logit and sigmoid modes # Regression head (logit mode only)
self.reward_head = nn.Linear(config.dim_model, 1) self.reward_head = nn.Linear(config.dim_model, 1)
# Initialize head with improved settings # Initialize head for logit regression
with torch.no_grad(): with torch.no_grad():
if config.use_logit_regression: self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
# Logit regression: can use larger weights since no saturation issues self.reward_head.bias.fill_(0.0)
self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
self.reward_head.bias.fill_(0.0) # Neutral start in logit space
else:
# Sigmoid mode: moderate initialization
self.reward_head.weight.normal_(0.0, 0.02)
self.reward_head.bias.fill_(0.0)
self.sigmoid = nn.Sigmoid() if not config.use_logit_regression else None
# Simple frame dropout probability # Simple frame dropout probability
self.frame_dropout_p = config.frame_dropout_p self.frame_dropout_p = config.frame_dropout_p
@@ -230,54 +215,21 @@ class RLearNPolicy(PreTrainedPolicy):
# Continue without compilation # Continue without compilation
def get_optim_params(self) -> list: def get_optim_params(self) -> list:
"""Return parameter groups with custom LR and weight decay settings.""" """Return parameter groups with head LR boost."""
# Collect trainable parameters
base_params = [] base_params = []
head_weight_params = [] head_params = []
head_bias_params = []
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if not param.requires_grad: if param.requires_grad:
continue if "reward_head" in name:
head_params.append(param)
if "reward_head" in name:
if "bias" in name:
head_bias_params.append(param)
else: else:
head_weight_params.append(param) base_params.append(param)
else:
base_params.append(param)
# Create parameter groups with different settings return [
param_groups = [] {"params": base_params},
{"params": head_params, "lr": self.config.learning_rate * self.config.head_lr_multiplier}
# Base parameters (everything except head) ]
if base_params:
param_groups.append({
"params": base_params,
"name": "base"
})
# Head weight parameters (higher LR)
if head_weight_params:
param_groups.append({
"params": head_weight_params,
"lr": self.config.learning_rate * self.config.head_lr_multiplier,
"name": "head_weights"
})
# Head bias parameters (higher LR, optionally no weight decay)
if head_bias_params:
head_bias_group = {
"params": head_bias_params,
"lr": self.config.learning_rate * self.config.head_lr_multiplier,
"name": "head_bias"
}
if self.config.remove_head_bias_wd:
head_bias_group["weight_decay"] = 0.0
param_groups.append(head_bias_group)
return param_groups
def reset(self): def reset(self):
pass pass
@@ -350,28 +302,16 @@ class RLearNPolicy(PreTrainedPolicy):
# Unpack and get video token features # Unpack and get video token features
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
# Apply frame-specific processing to prevent over-smoothing # Process all frames with single MLP
frame_specific_embeds = [] frame_tokens = self.frame_mlp(attended_video_tokens) # (B, T, D)
T_video = attended_video_tokens.shape[1]
for t in range(T_video):
# Apply frame-specific MLP to each temporal position
frame_embed = self.frame_specific_mlp[t](attended_video_tokens[:, t])
frame_specific_embeds.append(frame_embed)
frame_specific_tokens = torch.stack(frame_specific_embeds, dim=1) # (B, T, D)
# MLP predictor # MLP predictor
video_frame_embeds = self.mlp_predictor(frame_specific_tokens) video_frame_embeds = self.mlp_predictor(frame_tokens)
# Get rewards via linear head # Get rewards via logit regression head
normalized_embeds = self.pre_reward_norm(video_frame_embeds) normalized_embeds = self.pre_reward_norm(video_frame_embeds)
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T) raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T)
return torch.sigmoid(raw_logits) # Apply sigmoid for final predictions
if self.config.use_logit_regression:
# In logit mode, apply sigmoid at inference
return torch.sigmoid(raw_logits)
else:
# In sigmoid mode, apply sigmoid as usual
return self.sigmoid(raw_logits)
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op; rely on upstream processors if any # Initial version: no-op; rely on upstream processors if any
@@ -444,25 +384,20 @@ class RLearNPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
# NEW: Explicit random anchor window sampling for training # Always use random anchor window sampling
if self.training: frames, anchor_stats = self._sample_random_anchor_windows(batch)
frames, anchor_stats = self._sample_random_anchor_windows(batch)
else:
# During inference, use the generic extractor
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
anchor_stats = None
B, T, C, H, W = frames.shape B, T, C, H, W = frames.shape
device = next(self.parameters()).device device = next(self.parameters()).device
frames = frames.to(device) frames = frames.to(device)
# Apply video rewinding augmentation during training (FIXED: no constant padding) # Apply video rewinding augmentation (always enabled during training)
augmented_target = None augmented_target = None
if self.training and self.config.use_video_rewind: if self.training:
frames, augmented_target = apply_video_rewind_fixed( frames, augmented_target = apply_video_rewind(
frames, frames,
rewind_prob=self.config.rewind_prob, rewind_prob=self.config.rewind_prob,
last3_prob=getattr(self.config, "rewind_last3_prob", None), last3_prob=self.config.rewind_last3_prob,
) )
# Apply stride and frame dropout # Apply stride and frame dropout
@@ -518,42 +453,29 @@ class RLearNPolicy(PreTrainedPolicy):
# Unpack and get video token features # Unpack and get video token features
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
# Apply frame-specific processing to prevent over-smoothing # Process all frames with single MLP
frame_specific_embeds = [] frame_tokens = self.frame_mlp(attended_video_tokens) # (B, T, D)
T_video = attended_video_tokens.shape[1]
for t in range(T_video):
# Apply frame-specific MLP to each temporal position
frame_embed = self.frame_specific_mlp[t](attended_video_tokens[:, t])
frame_specific_embeds.append(frame_embed)
frame_specific_tokens = torch.stack(frame_specific_embeds, dim=1) # (B, T, D)
# MLP predictor # MLP predictor
video_frame_embeds = self.mlp_predictor(frame_specific_tokens) video_frame_embeds = self.mlp_predictor(frame_tokens)
transformer_time = time.perf_counter() - transformer_start transformer_time = time.perf_counter() - transformer_start
# Generate progress labels on-the-fly (ReWiND approach) # Generate progress labels on-the-fly (ReWiND approach)
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
loss_dict: dict[str, float] = {} loss_dict: dict[str, float] = {}
# Check if video rewinding already set the target # Generate progress targets
if self.training and self.config.use_video_rewind and augmented_target is not None: if self.training and augmented_target is not None:
# Use the augmented target from video rewinding and align with temporal subsampling # Video rewind already generated targets
target = augmented_target[:, idx] target = augmented_target[:, idx]
elif self.training and anchor_stats is not None and not anchor_stats.get("fallback_used", False):
# NEW: Calculate progress using the known random anchors
target = self._calculate_anchor_based_progress(batch, anchor_stats, T_eff)
else: else:
# Fallback: Calculate episode progress the old way # Use anchor-based window-relative progress
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) if anchor_stats.get("fallback_used", False):
if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None:
target = self._calculate_episode_progress(batch, episode_indices, frame_indices, T_eff, idx)
else:
raise ValueError( raise ValueError(
"No episode information found to build full-episode progress. " "Anchor-based sampling failed. Ensure 'episode_index', 'frame_index' are in batch "
"Expected 'episode_index' and 'frame_index' in batch and a valid 'episode_data_index' on the policy. " "and 'episode_data_index' is loaded from episodes.jsonl"
"Please pass RLearNPolicy(episode_data_index=...) built from episodes.jsonl (per-episode lengths), "
"and ensure the dataset exposes 'episode_index' and 'frame_index' (shape (B,) or (B,1))."
) )
target = self._calculate_anchor_based_progress(T_eff)
# During inference, we might not want to compute loss # During inference, we might not want to compute loss
if not self.training and target is None: if not self.training and target is None:
@@ -562,121 +484,72 @@ class RLearNPolicy(PreTrainedPolicy):
rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
# Calculate loss using the configured mode (logit regression or sigmoid+MSE) # Calculate loss using logit regression
loss_start = time.perf_counter() loss_start = time.perf_counter()
assert target.dtype == torch.float, "Continuous rewards require float targets"
# Get model outputs # Get model outputs
normalized_embeds = self.pre_reward_norm(video_frame_embeds) normalized_embeds = self.pre_reward_norm(video_frame_embeds)
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff) raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff)
if self.config.use_logit_regression: # Logit regression: transform targets to logit space and compute MSE on logits
# Logit regression: transform targets to logit space and compute MSE on logits eps = self.config.logit_eps
eps = self.config.logit_eps target_expanded = target.expand(B, -1)[:, :T_eff] # Expand and trim to T_eff
target_clamped = torch.clamp(target[:, :T_eff], eps, 1 - eps) target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
target_logits = torch.logit(target_clamped) target_logits = torch.logit(target_clamped)
loss = F.mse_loss(raw_logits, target_logits, reduction='mean') loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
# For logging/debug, also compute sigmoid predictions
predicted_rewards = torch.sigmoid(raw_logits) # For logging, compute sigmoid predictions
else: predicted_rewards = torch.sigmoid(raw_logits)
# Sigmoid mode: apply sigmoid and compute MSE on probabilities
predicted_rewards = self.sigmoid(raw_logits)
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
# Optional: Mismatched video-language pairs loss # Mismatched video-language pairs loss (always logit regression)
L_mismatch = torch.zeros((), device=device) L_mismatch = torch.zeros((), device=device)
if self.training and self.config.use_mismatch_loss and B > 1: if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
if torch.rand(1, device=device).item() < getattr(self.config, "mismatch_prob", 0.2): # Shuffle language within batch
# Shuffle language within batch shuffled_indices = torch.randperm(B, device=device)
shuffled_indices = torch.randperm(B, device=device) shuffled_commands = [commands[i] for i in shuffled_indices]
shuffled_commands = [commands[i] for i in shuffled_indices]
# Re-encode with mismatched language
# Re-encode with mismatched language lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device) lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Pack and forward
# Pack and forward tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d') mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) attended_mm = self.decoder(tokens_mm, mask=mask_mm)
attended_mm = self.decoder(tokens_mm, mask=mask_mm) _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
# Process mismatch frames with single MLP
# Apply frame-specific processing to mismatch embeddings mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
mismatch_specific_embeds = []
T_video_mm = attended_video_mm.shape[1] mismatch_embeds = self.mlp_predictor(mismatch_tokens)
for t in range(T_video_mm):
frame_embed = self.frame_specific_mlp[t](attended_video_mm[:, t]) # Mismatched pairs should predict near-zero progress (logit mode)
mismatch_specific_embeds.append(frame_embed) normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
mismatch_specific_tokens = torch.stack(mismatch_specific_embeds, dim=1) mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
mismatch_embeds = self.mlp_predictor(mismatch_specific_tokens) # Target logit corresponding to sigmoid ≈ 0
eps = self.config.logit_eps
# Mismatched pairs should predict zero progress zeros_target_logits = torch.logit(torch.full_like(target_expanded[:, :T_eff], eps))
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
if self.config.use_logit_regression:
# In logit mode, target logit of ~0 corresponds to sigmoid(x)≈0
eps = self.config.logit_eps
zeros_target_logits = torch.logit(torch.full_like(target[:, :T_eff], eps))
L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
else:
# In sigmoid mode, target sigmoid output of 0
mismatch_predictions = self.sigmoid(mismatch_raw_logits)
zeros_target = torch.zeros_like(target[:, :T_eff])
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
# Total loss # Total loss
total_loss = loss + L_mismatch total_loss = loss + L_mismatch
loss_time = time.perf_counter() - loss_start loss_time = time.perf_counter() - loss_start
# DEBUG: Print targets and predictions occasionally during training # DEBUG: Clean logit regression monitoring
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print if self.training and torch.rand(1).item() < 0.03:
with torch.no_grad(): with torch.no_grad():
# Get raw MLP outputs, normalized outputs, and predictions
raw_outputs = video_frame_embeds
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
preds = self.sigmoid(raw_logits)
# Randomly sample a sequence from the batch for detailed analysis
sample_idx = torch.randint(0, B, (1,)).item() sample_idx = torch.randint(0, B, (1,)).item()
sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy()
sample_preds = predicted_rewards[sample_idx].cpu().numpy()
print(f"\n=== DEBUG TRAINING ===") print(f"\n=== LOGIT REGRESSION DEBUG ===")
# Target statistics print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
print(f"Target min: {target.min():.6f}") print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
print(f"Target max: {target.max():.6f}") print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
print(f"Target mean: {target.mean():.6f}") print(f"Sample {sample_idx}: targets={sample_targets[:8]} preds={sample_preds[:8]}")
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]") print(f"Loss: {loss:.6f}")
# Model output statistics print("=" * 40)
print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
print(f"Normalized MLP range: [{normalized_embeds.min():.6f}, {normalized_embeds.max():.6f}]")
print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
print(f"Raw logits mean: {raw_logits.mean():.6f}")
print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
print(f"Sigmoid pred mean: {preds.mean():.3f}")
print(f"Loss: {loss:.4f}")
# Show randomly sampled sequence for comparison
print(f"Sample {sample_idx} targets (all 16):", target[sample_idx].cpu().numpy())
print(f"Sample {sample_idx} preds (all 16): ", preds[sample_idx].cpu().numpy())
# TARGET FIX VERIFICATION: Check if we still have flat/stuck patterns
sample_targets = target[sample_idx].cpu().numpy()
# Count consecutive identical values (should be minimal after fix)
consecutive_same = 0
max_consecutive = 0
for i in range(1, len(sample_targets)):
if abs(sample_targets[i] - sample_targets[i-1]) < 1e-6:
consecutive_same += 1
max_consecutive = max(max_consecutive, consecutive_same + 1)
else:
consecutive_same = 0
if max_consecutive >= 3:
print(f"⚠️ STILL STUCK: {max_consecutive} consecutive identical targets!")
else:
print(f"✅ TARGET FIXED: Max consecutive identical = {max_consecutive}")
print("="*25)
total_forward_time = time.perf_counter() - forward_start total_forward_time = time.perf_counter() - forward_start
@@ -698,16 +571,11 @@ class RLearNPolicy(PreTrainedPolicy):
# Raw logits statistics (useful for monitoring head behavior) # Raw logits statistics (useful for monitoring head behavior)
"raw_logits_mean": float(raw_logits.mean().item()), "raw_logits_mean": float(raw_logits.mean().item()),
"raw_logits_std": float(raw_logits.std().item()), "raw_logits_std": float(raw_logits.std().item()),
# NEW: Anchor sampling statistics if available # Anchor sampling statistics
**({ "anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
"anchor_mean": float(anchor_stats['anchor_mean']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, "anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
"anchor_std": float(anchor_stats['anchor_std']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, "oob_fraction": float(anchor_stats.get('oob_fraction', 0.0)),
"oob_fraction": float(anchor_stats['oob_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0, "padded_fraction": float(anchor_stats.get('padded_fraction', 0.0)),
"padded_fraction": float(anchor_stats['padded_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
"use_random_anchors": not (anchor_stats and anchor_stats.get('fallback_used', False)) if anchor_stats else False,
}),
# Loss mode indicator
"logit_regression": bool(self.config.use_logit_regression),
# Timing information # Timing information
"timing_vision_ms": float(vision_time * 1000), "timing_vision_ms": float(vision_time * 1000),
"timing_language_ms": float(lang_time * 1000), "timing_language_ms": float(lang_time * 1000),
@@ -868,209 +736,84 @@ class RLearNPolicy(PreTrainedPolicy):
return ep, fr return ep, fr
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Sample random anchor windows for training to avoid sampling bias. """Sample random anchor windows for training."""
# Extract episode and frame indices - required for proper anchor sampling
Returns:
frames: (B, T, C, H, W) tensor with T = max_seq_len
anchor_stats: dict with sampling statistics for logging
"""
# Extract episode and frame indices
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
if episode_indices is None or frame_indices is None or self.episode_data_index is None: if episode_indices is None or frame_indices is None or self.episode_data_index is None:
# Fallback to generic extractor if we don't have episode info raise ValueError(
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) "Random anchor sampling requires 'episode_index', 'frame_index' in batch "
return frames, {"fallback_used": True} "and loaded 'episode_data_index'. Ensure episodes.jsonl is available."
)
device = next(self.parameters()).device device = next(self.parameters()).device
B = len(episode_indices) B = len(episode_indices)
T = self.config.max_seq_len T = self.config.max_seq_len
delta_indices = self.config.observation_delta_indices # [-15, -14, ..., 0]
# Get raw image data - assume it's already a temporal sequence from dataset # Get raw image data
raw_frames = extract_visual_sequence(batch, target_seq_len=None) # Don't force padding raw_frames = extract_visual_sequence(batch, target_seq_len=None)
available_T = raw_frames.shape[1] available_T = raw_frames.shape[1]
# For each sample, choose a random anchor and build the window # Sample random anchors and build windows
sampled_frames = [] sampled_frames = []
anchor_positions = [] anchor_positions = []
oob_count = 0 oob_count = 0
padded_count = 0
resampled_count = 0
for b_idx in range(B): for b_idx in range(B):
ep_idx = episode_indices[b_idx].item() ep_idx = episode_indices[b_idx].item()
frame_idx = frame_indices[b_idx].item() # Current frame position in episode
# Get episode boundaries # Get episode boundaries
ep_start = self.episode_data_index["from"][ep_idx].item() ep_start = self.episode_data_index["from"][ep_idx].item()
ep_end = self.episode_data_index["to"][ep_idx].item() ep_end = self.episode_data_index["to"][ep_idx].item()
ep_length = ep_end - ep_start ep_length = ep_end - ep_start
# Choose random anchor within episode bounds such that we can get a full window # Choose random anchor - need at least T-1 frames before for [-15..0] window
# The anchor is the "current" frame (delta=0), so we need at least T-1 frames before it min_anchor = T - 1
min_anchor = T - 1 # Need 15 frames before for [-15..0] window max_anchor = max(min_anchor, ep_length - 1)
max_anchor = ep_length - 1 # Episode frame indices are 0-based anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
if min_anchor > max_anchor:
# Episode too short for full window - use available frames with padding
anchor = max_anchor
padded_count += 1
else:
# Sample uniformly from valid range
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
anchor_positions.append(anchor) anchor_positions.append(anchor)
# Build window indices relative to episode start # Build window indices with reflection padding
window_indices = [anchor + delta for delta in delta_indices] window_indices = []
# Handle out-of-bounds with reflection or clamping
valid_indices = []
had_oob = False had_oob = False
for idx in window_indices: for delta in range(-(T-1), 1): # [-15, -14, ..., 0] for T=16
idx = anchor + delta
if idx < 0: if idx < 0:
# Reflect at episode boundary idx = -idx # Reflect at start
valid_indices.append(-idx)
had_oob = True had_oob = True
elif idx >= ep_length: elif idx >= ep_length:
# Reflect at episode end idx = 2 * (ep_length - 1) - idx # Reflect at end
valid_indices.append(2 * (ep_length - 1) - idx)
had_oob = True had_oob = True
else: window_indices.append(min(idx, available_T - 1))
valid_indices.append(idx)
if had_oob: if had_oob:
oob_count += 1 oob_count += 1
# Extract frames at these indices from the raw temporal sequence # Extract frames
# Map episode-relative indices to sequence indices frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
frame_tensors = [] sampled_frames.append(torch.stack(frame_tensors))
for ep_rel_idx in valid_indices:
if ep_rel_idx < available_T:
frame_tensors.append(raw_frames[b_idx, ep_rel_idx])
else:
# Fallback: repeat last available frame
frame_tensors.append(raw_frames[b_idx, -1])
padded_count += 1
sampled_frames.append(torch.stack(frame_tensors)) # (T, C, H, W)
frames = torch.stack(sampled_frames, dim=0) # (B, T, C, H, W) frames = torch.stack(sampled_frames, dim=0)
anchor_stats = { anchor_stats = {
"anchor_mean": float(torch.tensor(anchor_positions).float().mean()), "anchor_mean": float(torch.tensor(anchor_positions).float().mean()),
"anchor_std": float(torch.tensor(anchor_positions).float().std()), "anchor_std": float(torch.tensor(anchor_positions).float().std()),
"oob_fraction": float(oob_count) / B, "oob_fraction": float(oob_count) / B,
"padded_fraction": float(padded_count) / B, "padded_fraction": 0.0, # No padding with reflection approach
"resampled_count": resampled_count,
"fallback_used": False "fallback_used": False
} }
return frames, anchor_stats return frames, anchor_stats
def _calculate_anchor_based_progress(self, batch: dict[str, Tensor], anchor_stats: dict, T_eff: int) -> Tensor: def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
"""Calculate progress labels based on known random anchors (more efficient).""" """Generate window-relative progress (0 to 1 across window)."""
episode_indices, _ = self._extract_episode_and_frame_indices(batch)
if episode_indices is None:
raise ValueError("Need episode_indices for anchor-based progress calculation")
device = next(self.parameters()).device device = next(self.parameters()).device
B = len(episode_indices) # Simple window-relative progress: 0 to 1 across the temporal window
delta_indices = self.config.observation_delta_indices # This centers the mean around 0.5 and is stable regardless of episode length
progress = torch.linspace(0, 1, T_eff, device=device)
# Build progress for each anchor position in the batch return progress.unsqueeze(0) # (1, T_eff) - will broadcast to (B, T_eff)
all_progress = []
for i, delta in enumerate(delta_indices[:T_eff]): # Only compute for frames we'll actually use
frame_progress = []
for b_idx in range(B):
ep_idx = episode_indices[b_idx].item()
# Get episode length
ep_start = self.episode_data_index["from"][ep_idx].item()
ep_end = self.episode_data_index["to"][ep_idx].item()
ep_length = ep_end - ep_start
# The anchor was chosen during window sampling
# For anchor-based progress, we use window-relative progress to center around 0.5
# This is more stable and matches ReWiND's simple approach
window_position = i # Position in window [0, T_eff-1]
progress = window_position / max(1, T_eff - 1) # 0 to 1 across window
frame_progress.append(progress)
all_progress.append(
torch.tensor(frame_progress, device=device, dtype=torch.float32)
)
return torch.stack(all_progress, dim=1) # (B, T_eff)
def _calculate_episode_progress(self, batch: dict[str, Tensor], episode_indices: Tensor,
                                frame_indices: Tensor, T_eff: int, idx: Tensor) -> Tensor:
    """Calculate progress labels using episode-relative positions (legacy fallback).

    For every sample and every offset in ``config.observation_delta_indices``
    the label is ``(frame_index + delta) / max(1, episode_length - 1)``,
    clamped to ``[0, 1]``. Episode length is ``to - from`` taken from
    ``self.episode_data_index``.

    Args:
        batch: Unused; kept so the signature stays compatible with callers.
        episode_indices: One episode id per sample.
        frame_indices: One anchor frame index per sample.
            NOTE(review): the value is divided by the episode length directly,
            so it is presumably episode-relative rather than a global dataset
            index - confirm against the caller.
        T_eff: Only reported in the debug log; it does not affect the result.
        idx: Index applied along the time axis of the full label matrix to
            select the frames that were actually kept.

    Returns:
        A float32 tensor equal to ``labels[:, idx]`` where ``labels`` has shape
        ``(B, len(observation_delta_indices))`` and lives on the device of the
        model parameters.
    """
    import logging

    logger = logging.getLogger(__name__)
    out_device = next(self.parameters()).device

    # All arithmetic is done on CPU in float64: this reproduces the Python
    # float arithmetic of a per-sample loop exactly, costs one transfer
    # instead of one host sync per (sample, delta) pair, and avoids float64
    # on backends that do not support it.
    episodes = episode_indices.reshape(-1).to(device="cpu", dtype=torch.long)
    anchors = frame_indices.reshape(-1).to(device="cpu", dtype=torch.float64)
    deltas = torch.as_tensor(
        self.config.observation_delta_indices, dtype=torch.float64, device="cpu"
    )

    # Diagnostics are gated on the logging level rather than on a random
    # draw, so enabling or disabling them never consumes the global torch
    # RNG stream that training augmentations depend on.
    if logger.isEnabledFor(logging.DEBUG):
        unique_frames = torch.unique(anchors).numel()
        unique_episodes = torch.unique(episodes).numel()
        logger.debug(
            "episode progress: delta_indices=%s batch_size=%d T_eff=%d "
            "unique_frame_indices=%d unique_episode_indices=%d",
            deltas.tolist(), episodes.numel(), T_eff, unique_frames, unique_episodes,
        )
        if unique_frames == 1:
            logger.debug("all samples in the batch have an identical frame index")

    ep_start = self.episode_data_index["from"].to("cpu")[episodes]
    ep_end = self.episode_data_index["to"].to("cpu")[episodes]
    # max(1, length - 1) guards single-frame episodes against division by zero.
    denom = (ep_end - ep_start - 1).clamp(min=1).to(torch.float64)

    target_frames = anchors.unsqueeze(1) + deltas.unsqueeze(0)  # (B, num_deltas)
    # Offsets that fall before the first or after the last frame of the
    # episode are clamped onto the [0, 1] range.
    progress = (target_frames / denom.unsqueeze(1)).clamp(0.0, 1.0)
    progress = progress.to(device=out_device, dtype=torch.float32)

    # Apply stride/dropout indexing to match the processed frames.
    return progress[:, idx]
def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]: def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
import json import json
@@ -1097,10 +840,7 @@ class RLearNPolicy(PreTrainedPolicy):
"to": torch.tensor(ends, device=device, dtype=torch.long), "to": torch.tensor(ends, device=device, dtype=torch.long),
} }
# Helper functions for ReWiND architecture
def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor: def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
"""Extract visual sequence from batch and ensure it has the expected temporal length. """Extract visual sequence from batch and ensure it has the expected temporal length.
@@ -1182,8 +922,8 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
return frames return frames
def apply_video_rewind_fixed(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]: def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
"""Apply video rewinding augmentation WITHOUT constant-value padding (FIXED version). """Apply video rewinding augmentation without constant-value padding.
This version ensures the rewound sequence is exactly T frames without flat plateaus This version ensures the rewound sequence is exactly T frames without flat plateaus
that drag down the target mean. that drag down the target mean.