From 1e1b01025721b32b806bee3656fa337b885808f8 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Sun, 31 Aug 2025 15:40:00 +0200
Subject: [PATCH] cleanup

---
 .../policies/rlearn/configuration_rlearn.py |  55 +-
 .../policies/rlearn/modeling_rlearn.py      | 512 +++++------------
 2 files changed, 140 insertions(+), 427 deletions(-)

diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index d35262201..a3eff5049 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -60,59 +60,32 @@ class RLearNConfig(PreTrainedConfig):
     learning_rate: float = 1e-3
     weight_decay: float = 0.01

-    # Backward compatibility field (not used in current implementation)
-    use_tanh_head: bool = False
-
     # Performance optimizations
-    use_amp: bool = True  # Mixed precision training for speed boost
-    compile_model: bool = True  # torch.compile for additional speedup
+    use_amp: bool = True
+    compile_model: bool = True

-    # ReWiND-specific parameters
-    use_video_rewind: bool = True  # Enable video rewinding augmentation
-    rewind_prob: float = 0.5  # Reduced from 0.8 to avoid too many artifacts
-    rewind_last3_prob: float = 0.3  # Increased to favor smaller rewinds
-    use_mismatch_loss: bool = False  # Enable mismatched language-video loss
-    mismatch_prob: float = (
-        0.2  # Probability to include a mismatched video-language forward pass (paper: ~20%)
-    )
+    # ReWiND augmentation
+    rewind_prob: float = 0.5
+    rewind_last3_prob: float = 0.3
+    mismatch_prob: float = 0.2

-    # NEW: Loss and head improvements
-    use_logit_regression: bool = True  # Use logit space regression instead of sigmoid+MSE
-    logit_eps: float = 1e-6  # Clipping epsilon for logit transform: logit(clamp(target, eps, 1-eps))
-    head_lr_multiplier: float = 2.0  # Increase head learning rate relative to base
-    head_weight_init_std: float = 0.05  # Larger head weight initialization for faster positive logits
-    remove_head_bias_wd: bool = True  # Remove weight decay from head bias
-
-    # Window sampling improvements
-    use_random_anchor_sampling: bool = True  # Use explicit random anchor sampling during training
-    use_window_relative_progress: bool = True  # Use window-relative progress (0-1 across window) instead of episode-relative
-
-    # Loss hyperparameters (simplified for ReWiND)
-    # The main loss is just MSE between predicted and target progress
+    # Logit regression (only supported mode)
+    logit_eps: float = 1e-6
+    head_lr_multiplier: float = 2.0
+    head_weight_init_std: float = 0.05
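Note on the fields above: a minimal standalone sketch (not part of the patch) of the transform that logit_eps guards, matching the loss code later in this diff; the sample target values are assumed for illustration:

    import torch

    eps = 1e-6  # mirrors logit_eps; keeps torch.logit finite at exactly 0 and 1
    target = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    target_logits = torch.logit(torch.clamp(target, eps, 1 - eps))
    # -> tensor([-13.8155, -1.0986, 0.0000, 1.0986, 13.8155])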

     # Normalization presets
     normalization_mapping: dict[str, NormalizationMode] = field(
         default_factory=lambda: {
             "VISUAL": NormalizationMode.MEAN_STD,
-            # Language is tokenized at the encoder level; no numeric normalization here.
         }
     )

-    # Architectural knobs to better mirror ReWiND
-    num_register_tokens: int = 4  # register / memory tokens, can't hurt
-    mlp_predictor_depth: int = 3  # depth of the per-frame MLP head
-
-    # Loss configuration - supports both sigmoid+MSE and logit regression
+    # Architecture
+    num_register_tokens: int = 4
+    mlp_predictor_depth: int = 3

-    # Evaluation visualization parameters
-    enable_eval_visualizations: bool = False  # Enable reward evaluation visualizations during training
-    eval_visualization_freq: int = 1000  # Steps between evaluation visualizations
-    eval_holdout_episodes: int = 9  # Number of episodes to hold out for evaluation
-    eval_max_frames: int = 128  # Maximum frames per episode for evaluation
-    eval_visualization_seed: int = 42  # Seed for reproducible episode selection
-
-    # Optional: path to episodes.jsonl to build full-episode indices automatically
-    # Default to common dataset layout: /meta/episodes.jsonl
+    # Required path to episodes.jsonl for episode boundaries
     episodes_jsonl_path: str | None = "meta/episodes.jsonl"

     def validate_features(self) -> None:
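Note for reviewers: a hedged usage sketch (assumed wiring, not part of this patch) of how the two parameter groups returned by get_optim_params below are consumed; groups without an explicit "lr" fall back to the optimizer default, so only the head group carries the multiplied rate. The names policy and cfg are placeholders:

    import torch

    optimizer = torch.optim.AdamW(
        policy.get_optim_params(),   # [{"params": base}, {"params": head, "lr": base_lr * head_lr_multiplier}]
        lr=cfg.learning_rate,        # default lr, used by the base group
        weight_decay=cfg.weight_decay,
    )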
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index fb9c1ea6b..16d543ae7 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -165,15 +165,8 @@ class RLearNPolicy(PreTrainedPolicy):
         # Stronger temporal positional encoding
         self.temporal_pos_embedding = nn.Parameter(torch.randn(config.max_seq_len, config.dim_model) * 0.1)

-        # CRITICAL: Frame-specific MLPs prevent temporal over-smoothing
-        # Problem: Transformer attention was making all 16 predictions identical (e.g. all 0.34)
-        # Solution: Each temporal position gets its own specialized MLP processing
-        # Frame 0 → MLP[0], Frame 1 → MLP[1], ..., Frame 15 → MLP[15]
-        # This creates distinct pathways for each frame while preserving attention context
-        self.frame_specific_mlp = nn.ModuleList([
-            nn.Linear(config.dim_model, config.dim_model)
-            for _ in range(config.max_seq_len)  # 16 separate MLPs for 16 frame positions
-        ])
+        # Single MLP processes all frames
+        self.frame_mlp = nn.Linear(config.dim_model, config.dim_model)

         # Register / memory / attention sink tokens
         self.num_register_tokens = config.num_register_tokens
@@ -190,21 +183,13 @@ class RLearNPolicy(PreTrainedPolicy):
         # Layer normalization before reward head to stabilize MLP outputs
         self.pre_reward_norm = nn.LayerNorm(config.dim_model)

-        # Regression head - supports both logit and sigmoid modes
+        # Regression head (logit mode only)
         self.reward_head = nn.Linear(config.dim_model, 1)

-        # Initialize head with improved settings
+        # Initialize head for logit regression
         with torch.no_grad():
-            if config.use_logit_regression:
-                # Logit regression: can use larger weights since no saturation issues
-                self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
-                self.reward_head.bias.fill_(0.0)  # Neutral start in logit space
-            else:
-                # Sigmoid mode: moderate initialization
-                self.reward_head.weight.normal_(0.0, 0.02)
-                self.reward_head.bias.fill_(0.0)
-
-        self.sigmoid = nn.Sigmoid() if not config.use_logit_regression else None
+            self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
+            self.reward_head.bias.fill_(0.0)

         # Simple frame dropout probability
         self.frame_dropout_p = config.frame_dropout_p
@@ -230,54 +215,21 @@ class RLearNPolicy(PreTrainedPolicy):
             # Continue without compilation

     def get_optim_params(self) -> list:
-        """Return parameter groups with custom LR and weight decay settings."""
-        # Collect trainable parameters
+        """Return parameter groups with head LR boost."""
         base_params = []
-        head_weight_params = []
-        head_bias_params = []
+        head_params = []

         for name, param in self.named_parameters():
-            if not param.requires_grad:
-                continue
-
-            if "reward_head" in name:
-                if "bias" in name:
-                    head_bias_params.append(param)
+            if param.requires_grad:
+                if "reward_head" in name:
+                    head_params.append(param)
                 else:
-                    head_weight_params.append(param)
-            else:
-                base_params.append(param)
+                    base_params.append(param)

-        # Create parameter groups with different settings
-        param_groups = []
-
-        # Base parameters (everything except head)
-        if base_params:
-            param_groups.append({
-                "params": base_params,
-                "name": "base"
-            })
-
-        # Head weight parameters (higher LR)
-        if head_weight_params:
-            param_groups.append({
-                "params": head_weight_params,
-                "lr": self.config.learning_rate * self.config.head_lr_multiplier,
-                "name": "head_weights"
-            })
-
-        # Head bias parameters (higher LR, optionally no weight decay)
-        if head_bias_params:
-            head_bias_group = {
-                "params": head_bias_params,
-                "lr": self.config.learning_rate * self.config.head_lr_multiplier,
-                "name": "head_bias"
-            }
-            if self.config.remove_head_bias_wd:
-                head_bias_group["weight_decay"] = 0.0
-            param_groups.append(head_bias_group)
-
-        return param_groups
+        return [
+            {"params": base_params},
+            {"params": head_params, "lr": self.config.learning_rate * self.config.head_lr_multiplier},
+        ]

     def reset(self):
         pass
@@ -350,28 +302,16 @@ class RLearNPolicy(PreTrainedPolicy):
         # Unpack and get video token features
         _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')

-        # Apply frame-specific processing to prevent over-smoothing
-        frame_specific_embeds = []
-        T_video = attended_video_tokens.shape[1]
-        for t in range(T_video):
-            # Apply frame-specific MLP to each temporal position
-            frame_embed = self.frame_specific_mlp[t](attended_video_tokens[:, t])
-            frame_specific_embeds.append(frame_embed)
-        frame_specific_tokens = torch.stack(frame_specific_embeds, dim=1)  # (B, T, D)
+        # Process all frames with single MLP
+        frame_tokens = self.frame_mlp(attended_video_tokens)  # (B, T, D)

         # MLP predictor
-        video_frame_embeds = self.mlp_predictor(frame_specific_tokens)
+        video_frame_embeds = self.mlp_predictor(frame_tokens)

-        # Get rewards via linear head
+        # Get rewards via logit regression head
         normalized_embeds = self.pre_reward_norm(video_frame_embeds)
         raw_logits = self.reward_head(normalized_embeds).squeeze(-1)  # (B, T)
-
-        if self.config.use_logit_regression:
-            # In logit mode, apply sigmoid at inference
-            return torch.sigmoid(raw_logits)
-        else:
-            # In sigmoid mode, apply sigmoid as usual
-            return self.sigmoid(raw_logits)
+        return torch.sigmoid(raw_logits)  # Apply sigmoid for final predictions

     def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         # Initial version: no-op; rely on upstream processors if any
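The single-MLP replacement above works because nn.Linear applies to the last dimension of arbitrary-rank input, so one weight matrix is shared across all T positions; a quick shape check with assumed dimensions:

    import torch
    import torch.nn as nn

    B, T, D = 2, 16, 256  # illustrative sizes
    frame_mlp = nn.Linear(D, D)
    out = frame_mlp(torch.randn(B, T, D))  # same weights applied at every frame position
    assert out.shape == (B, T, D)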
@@ -444,25 +384,20 @@ class RLearNPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)

-        # NEW: Explicit random anchor window sampling for training
-        if self.training:
-            frames, anchor_stats = self._sample_random_anchor_windows(batch)
-        else:
-            # During inference, use the generic extractor
-            frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
-            anchor_stats = None
+        # Always use random anchor window sampling
+        frames, anchor_stats = self._sample_random_anchor_windows(batch)

         B, T, C, H, W = frames.shape
         device = next(self.parameters()).device
         frames = frames.to(device)

-        # Apply video rewinding augmentation during training (FIXED: no constant padding)
+        # Apply video rewinding augmentation (always enabled during training)
         augmented_target = None
-        if self.training and self.config.use_video_rewind:
-            frames, augmented_target = apply_video_rewind_fixed(
+        if self.training:
+            frames, augmented_target = apply_video_rewind(
                 frames,
                 rewind_prob=self.config.rewind_prob,
-                last3_prob=getattr(self.config, "rewind_last3_prob", None),
+                last3_prob=self.config.rewind_last3_prob,
             )

         # Apply stride and frame dropout
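Shape contract assumed by the call above (the apply_video_rewind body is only partially shown at the end of this patch): frames keep their (B, T, C, H, W) shape and the returned target is per-frame progress of shape (B, T). A hedged sketch:

    import torch
    from lerobot.policies.rlearn.modeling_rlearn import apply_video_rewind

    B, T, C, H, W = 2, 16, 3, 224, 224  # assumed dims
    frames = torch.randn(B, T, C, H, W)
    frames, augmented_target = apply_video_rewind(frames, rewind_prob=0.5, last3_prob=0.3)
    assert frames.shape == (B, T, C, H, W)
    assert augmented_target.shape == (B, T)  # progress in [0, 1], reflecting the rewind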
@@ -518,42 +453,29 @@ class RLearNPolicy(PreTrainedPolicy):
         # Unpack and get video token features
         _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')

-        # Apply frame-specific processing to prevent over-smoothing
-        frame_specific_embeds = []
-        T_video = attended_video_tokens.shape[1]
-        for t in range(T_video):
-            # Apply frame-specific MLP to each temporal position
-            frame_embed = self.frame_specific_mlp[t](attended_video_tokens[:, t])
-            frame_specific_embeds.append(frame_embed)
-        frame_specific_tokens = torch.stack(frame_specific_embeds, dim=1)  # (B, T, D)
+        # Process all frames with single MLP
+        frame_tokens = self.frame_mlp(attended_video_tokens)  # (B, T, D)

         # MLP predictor
-        video_frame_embeds = self.mlp_predictor(frame_specific_tokens)
+        video_frame_embeds = self.mlp_predictor(frame_tokens)

         transformer_time = time.perf_counter() - transformer_start

         # Generate progress labels on-the-fly (ReWiND approach)
-        # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
+        # Progress targets are window-relative: 0-1 across the sampled window
         loss_dict: dict[str, float] = {}

-        # Check if video rewinding already set the target
-        if self.training and self.config.use_video_rewind and augmented_target is not None:
-            # Use the augmented target from video rewinding and align with temporal subsampling
+        # Generate progress targets
+        if self.training and augmented_target is not None:
+            # Video rewind already generated targets
             target = augmented_target[:, idx]
-        elif self.training and anchor_stats is not None and not anchor_stats.get("fallback_used", False):
-            # NEW: Calculate progress using the known random anchors
-            target = self._calculate_anchor_based_progress(batch, anchor_stats, T_eff)
         else:
-            # Fallback: Calculate episode progress the old way
-            episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
-            if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None:
-                target = self._calculate_episode_progress(batch, episode_indices, frame_indices, T_eff, idx)
-            else:
+            # Use anchor-based window-relative progress
+            if anchor_stats.get("fallback_used", False):
                 raise ValueError(
-                    "No episode information found to build full-episode progress. "
-                    "Expected 'episode_index' and 'frame_index' in batch and a valid 'episode_data_index' on the policy. "
-                    "Please pass RLearNPolicy(episode_data_index=...) built from episodes.jsonl (per-episode lengths), "
-                    "and ensure the dataset exposes 'episode_index' and 'frame_index' (shape (B,) or (B,1))."
+                    "Anchor-based sampling failed. Ensure 'episode_index' and 'frame_index' are in the batch "
+                    "and that 'episode_data_index' is loaded from episodes.jsonl."
                 )
+            target = self._calculate_anchor_based_progress(T_eff)

         # During inference, we might not want to compute loss
         if not self.training and target is None:
@@ -562,121 +484,72 @@ class RLearNPolicy(PreTrainedPolicy):
-            rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
+            rewards = torch.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
             return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}

-        # Calculate loss using the configured mode (logit regression or sigmoid+MSE)
+        # Calculate loss using logit regression
         loss_start = time.perf_counter()

-        assert target.dtype == torch.float, "Continuous rewards require float targets"
-
         # Get model outputs
         normalized_embeds = self.pre_reward_norm(video_frame_embeds)
         raw_logits = self.reward_head(normalized_embeds).squeeze(-1)  # (B, T_eff)

-        if self.config.use_logit_regression:
-            # Logit regression: transform targets to logit space and compute MSE on logits
-            eps = self.config.logit_eps
-            target_clamped = torch.clamp(target[:, :T_eff], eps, 1 - eps)
-            target_logits = torch.logit(target_clamped)
-            loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
-            # For logging/debug, also compute sigmoid predictions
-            predicted_rewards = torch.sigmoid(raw_logits)
-        else:
-            # Sigmoid mode: apply sigmoid and compute MSE on probabilities
-            predicted_rewards = self.sigmoid(raw_logits)
-            loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
+        # Logit regression: transform targets to logit space and compute MSE on logits
+        eps = self.config.logit_eps
+        target_expanded = target.expand(B, -1)[:, :T_eff]  # Expand and trim to T_eff
+        target_clamped = torch.clamp(target_expanded, eps, 1 - eps)
+        target_logits = torch.logit(target_clamped)
+        loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
+
+        # For logging, compute sigmoid predictions
+        predicted_rewards = torch.sigmoid(raw_logits)
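Why target.expand(B, -1) handles both target sources above: rewind targets arrive as (B, T) and anchor-based targets as (1, T_eff); expand broadcasts the latter and is a no-op on the former. A small check with assumed sizes:

    import torch

    B, T_eff = 4, 16
    anchor_target = torch.linspace(0, 1, T_eff).unsqueeze(0)  # (1, T_eff)
    rewind_target = torch.rand(B, T_eff)                      # already (B, T_eff)
    assert anchor_target.expand(B, -1)[:, :T_eff].shape == (B, T_eff)
    assert rewind_target.expand(B, -1)[:, :T_eff].shape == (B, T_eff)  # no-op expand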
-        # Optional: Mismatched video-language pairs loss
+        # Mismatched video-language pairs loss (always logit regression)
         L_mismatch = torch.zeros((), device=device)
-        if self.training and self.config.use_mismatch_loss and B > 1:
-            if torch.rand(1, device=device).item() < getattr(self.config, "mismatch_prob", 0.2):
-                # Shuffle language within batch
-                shuffled_indices = torch.randperm(B, device=device)
-                shuffled_commands = [commands[i] for i in shuffled_indices]
-
-                # Re-encode with mismatched language
-                lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
-                lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
-
-                # Pack and forward
-                tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
-                mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
-                attended_mm = self.decoder(tokens_mm, mask=mask_mm)
-                _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
-
-                # Apply frame-specific processing to mismatch embeddings
-                mismatch_specific_embeds = []
-                T_video_mm = attended_video_mm.shape[1]
-                for t in range(T_video_mm):
-                    frame_embed = self.frame_specific_mlp[t](attended_video_mm[:, t])
-                    mismatch_specific_embeds.append(frame_embed)
-                mismatch_specific_tokens = torch.stack(mismatch_specific_embeds, dim=1)
-
-                mismatch_embeds = self.mlp_predictor(mismatch_specific_tokens)
-
-                # Mismatched pairs should predict zero progress
-                normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
-                mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
-
-                if self.config.use_logit_regression:
-                    # In logit mode, target logit of ~0 corresponds to sigmoid(x)≈0
-                    eps = self.config.logit_eps
-                    zeros_target_logits = torch.logit(torch.full_like(target[:, :T_eff], eps))
-                    L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
-                else:
-                    # In sigmoid mode, target sigmoid output of 0
-                    mismatch_predictions = self.sigmoid(mismatch_raw_logits)
-                    zeros_target = torch.zeros_like(target[:, :T_eff])
-                    L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
+        if self.training and B > 1 and torch.rand(1, device=device).item() < self.config.mismatch_prob:
+            # Shuffle language within batch
+            shuffled_indices = torch.randperm(B, device=device)
+            shuffled_commands = [commands[i] for i in shuffled_indices]
+
+            # Re-encode with mismatched language
+            lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device)
+            lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
+
+            # Pack and forward
+            tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
+            mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
+            attended_mm = self.decoder(tokens_mm, mask=mask_mm)
+            _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
+
+            # Process mismatch frames with single MLP
+            mismatch_tokens = self.frame_mlp(attended_video_mm)  # (B, T, D)
+
+            mismatch_embeds = self.mlp_predictor(mismatch_tokens)
+
+            # Mismatched pairs should predict near-zero progress (logit mode)
+            normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
+            mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
+
+            # Target logit corresponding to sigmoid ≈ 0 (reuses eps from the main loss)
+            zeros_target_logits = torch.logit(torch.full_like(target_expanded, eps))
+            L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')

         # Total loss
         total_loss = loss + L_mismatch
         loss_time = time.perf_counter() - loss_start
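For intuition on the mismatch target above: with logit_eps = 1e-6 the target is a single large negative logit, so the MSE pushes mismatched predictions toward sigmoid ≈ 0:

    import torch

    eps = 1e-6
    print(torch.logit(torch.tensor(eps)))         # ≈ -13.8155
    print(torch.sigmoid(torch.tensor(-13.8155)))  # ≈ 1e-6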
-        # DEBUG: Print targets and predictions occasionally during training
-        if self.training and torch.rand(1).item() < 0.02:  # ~2% chance to debug print
+        # DEBUG: occasional logit regression monitoring
+        if self.training and torch.rand(1).item() < 0.03:
             with torch.no_grad():
-                # Get raw MLP outputs, normalized outputs, and predictions
-                raw_outputs = video_frame_embeds
-                normalized_embeds = self.pre_reward_norm(video_frame_embeds)
-                raw_logits = self.reward_head(normalized_embeds).squeeze(-1)
-                preds = self.sigmoid(raw_logits)
-
-                # Randomly sample a sequence from the batch for detailed analysis
                 sample_idx = torch.randint(0, B, (1,)).item()
+                sample_targets = target_expanded[sample_idx, :T_eff].cpu().numpy()
+                sample_preds = predicted_rewards[sample_idx].cpu().numpy()

-                print(f"\n=== DEBUG TRAINING ===")
-                # Target statistics
-                print(f"Target min: {target.min():.6f}")
-                print(f"Target max: {target.max():.6f}")
-                print(f"Target mean: {target.mean():.6f}")
-                print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
-                # Model output statistics
-                print(f"Raw MLP range: [{raw_outputs.min():.3f}, {raw_outputs.max():.3f}]")
-                print(f"Normalized MLP range: [{normalized_embeds.min():.6f}, {normalized_embeds.max():.6f}]")
-                print(f"Raw logits range: [{raw_logits.min():.6f}, {raw_logits.max():.6f}]")
-                print(f"Raw logits mean: {raw_logits.mean():.6f}")
-                print(f"Sigmoid pred range: [{preds.min():.3f}, {preds.max():.3f}]")
-                print(f"Sigmoid pred mean: {preds.mean():.3f}")
-                print(f"Loss: {loss:.4f}")
-                # Show randomly sampled sequence for comparison
-                print(f"Sample {sample_idx} targets (all 16):", target[sample_idx].cpu().numpy())
-                print(f"Sample {sample_idx} preds (all 16): ", preds[sample_idx].cpu().numpy())
-
-                # TARGET FIX VERIFICATION: Check if we still have flat/stuck patterns
-                sample_targets = target[sample_idx].cpu().numpy()
-                # Count consecutive identical values (should be minimal after fix)
-                consecutive_same = 0
-                max_consecutive = 0
-                for i in range(1, len(sample_targets)):
-                    if abs(sample_targets[i] - sample_targets[i-1]) < 1e-6:
-                        consecutive_same += 1
-                        max_consecutive = max(max_consecutive, consecutive_same + 1)
-                    else:
-                        consecutive_same = 0
-
-                if max_consecutive >= 3:
-                    print(f"⚠️ STILL STUCK: {max_consecutive} consecutive identical targets!")
-                else:
-                    print(f"✅ TARGET FIXED: Max consecutive identical = {max_consecutive}")
-                print("="*25)
+                print("\n=== LOGIT REGRESSION DEBUG ===")
+                print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
+                print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
+                print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
+                print(f"Sample {sample_idx}: targets={sample_targets[:8]} preds={sample_preds[:8]}")
+                print(f"Loss: {loss:.6f}")
+                print("=" * 40)

         total_forward_time = time.perf_counter() - forward_start

@@ -698,16 +571,11 @@ class RLearNPolicy(PreTrainedPolicy):
             # Raw logits statistics (useful for monitoring head behavior)
             "raw_logits_mean": float(raw_logits.mean().item()),
             "raw_logits_std": float(raw_logits.std().item()),
-            # NEW: Anchor sampling statistics if available
-            **({
-                "anchor_mean": float(anchor_stats['anchor_mean']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
-                "anchor_std": float(anchor_stats['anchor_std']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
-                "oob_fraction": float(anchor_stats['oob_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
-                "padded_fraction": float(anchor_stats['padded_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
-                "use_random_anchors": not (anchor_stats and anchor_stats.get('fallback_used', False)) if anchor_stats else False,
-            }),
-            # Loss mode indicator
-            "logit_regression": bool(self.config.use_logit_regression),
+            # Anchor sampling statistics
+            "anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
+            "anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
+            "oob_fraction": float(anchor_stats.get('oob_fraction', 0.0)),
+            "padded_fraction": float(anchor_stats.get('padded_fraction', 0.0)),
             # Timing information
             "timing_vision_ms": float(vision_time * 1000),
             "timing_language_ms": float(lang_time * 1000),
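The sampler in the next hunk consumes self.episode_data_index as two long tensors of episode boundaries under the keys "from" and "to" (those keys are fixed by the code in this patch). A hedged sketch of building that structure from per-episode lengths; the episodes.jsonl field name "length" is an assumption:

    import json
    import torch

    def build_episode_data_index(path: str) -> dict[str, torch.Tensor]:
        # Assumed schema: one {"episode_index": ..., "length": ...} object per line.
        lengths = [json.loads(line)["length"] for line in open(path)]
        ends = torch.tensor(lengths, dtype=torch.long).cumsum(0)
        starts = torch.cat([torch.zeros(1, dtype=torch.long), ends[:-1]])
        return {"from": starts, "to": ends}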
@@ -868,209 +736,84 @@ class RLearNPolicy(PreTrainedPolicy):
         return ep, fr

     def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
-        """Sample random anchor windows for training to avoid sampling bias.
-
-        Returns:
-            frames: (B, T, C, H, W) tensor with T = max_seq_len
-            anchor_stats: dict with sampling statistics for logging
-        """
-        # Extract episode and frame indices
+        """Sample random anchor windows for training."""
+        # Extract episode and frame indices - required for proper anchor sampling
        episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
         if episode_indices is None or frame_indices is None or self.episode_data_index is None:
-            # Fallback to generic extractor if we don't have episode info
-            frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
-            return frames, {"fallback_used": True}
+            raise ValueError(
+                "Random anchor sampling requires 'episode_index' and 'frame_index' in the batch "
+                "and a loaded 'episode_data_index'. Ensure episodes.jsonl is available."
+            )

         device = next(self.parameters()).device
         B = len(episode_indices)
         T = self.config.max_seq_len
-        delta_indices = self.config.observation_delta_indices  # [-15, -14, ..., 0]

-        # Get raw image data - assume it's already a temporal sequence from dataset
-        raw_frames = extract_visual_sequence(batch, target_seq_len=None)  # Don't force padding
+        # Get raw image data
+        raw_frames = extract_visual_sequence(batch, target_seq_len=None)
         available_T = raw_frames.shape[1]

-        # For each sample, choose a random anchor and build the window
+        # Sample random anchors and build windows
         sampled_frames = []
         anchor_positions = []
         oob_count = 0
-        padded_count = 0
-        resampled_count = 0

         for b_idx in range(B):
             ep_idx = episode_indices[b_idx].item()
-            frame_idx = frame_indices[b_idx].item()  # Current frame position in episode

             # Get episode boundaries
             ep_start = self.episode_data_index["from"][ep_idx].item()
             ep_end = self.episode_data_index["to"][ep_idx].item()
             ep_length = ep_end - ep_start

-            # Choose random anchor within episode bounds such that we can get a full window
-            # The anchor is the "current" frame (delta=0), so we need at least T-1 frames before it
-            min_anchor = T - 1  # Need 15 frames before for [-15..0] window
-            max_anchor = ep_length - 1  # Episode frame indices are 0-based
-
-            if min_anchor > max_anchor:
-                # Episode too short for full window - use available frames with padding
-                anchor = max_anchor
-                padded_count += 1
-            else:
-                # Sample uniformly from valid range
-                anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
-
+            # Choose random anchor - need at least T-1 frames before it for a [-15..0] window
+            min_anchor = T - 1
+            max_anchor = max(min_anchor, ep_length - 1)
+            anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
             anchor_positions.append(anchor)

-            # Build window indices relative to episode start
-            window_indices = [anchor + delta for delta in delta_indices]
-
-            # Handle out-of-bounds with reflection or clamping
-            valid_indices = []
+            # Build window indices with reflection padding
+            window_indices = []
             had_oob = False
-            for idx in window_indices:
+            for delta in range(-(T - 1), 1):  # [-15, -14, ..., 0] for T=16
+                idx = anchor + delta
                 if idx < 0:
-                    # Reflect at episode boundary
-                    valid_indices.append(-idx)
+                    idx = -idx  # Reflect at start
                     had_oob = True
                 elif idx >= ep_length:
-                    # Reflect at episode end
-                    valid_indices.append(2 * (ep_length - 1) - idx)
+                    idx = 2 * (ep_length - 1) - idx  # Reflect at end
                     had_oob = True
-                else:
-                    valid_indices.append(idx)
+                window_indices.append(min(idx, available_T - 1))

             if had_oob:
                 oob_count += 1
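A worked example of the reflection logic above, with toy values chosen to exercise the start-reflection branch (the min(idx, available_T - 1) clamp is omitted for clarity):

    T, anchor, ep_length = 8, 2, 10
    window = []
    for delta in range(-(T - 1), 1):            # [-7, ..., 0]
        idx = anchor + delta                    # -5 .. 2
        if idx < 0:
            idx = -idx                          # reflect at start
        elif idx >= ep_length:
            idx = 2 * (ep_length - 1) - idx     # reflect at end
        window.append(idx)
    print(window)                               # [5, 4, 3, 2, 1, 0, 1, 2]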
-            # Extract frames at these indices from the raw temporal sequence
-            # Map episode-relative indices to sequence indices
-            frame_tensors = []
-            for ep_rel_idx in valid_indices:
-                if ep_rel_idx < available_T:
-                    frame_tensors.append(raw_frames[b_idx, ep_rel_idx])
-                else:
-                    # Fallback: repeat last available frame
-                    frame_tensors.append(raw_frames[b_idx, -1])
-                    padded_count += 1
-
-            sampled_frames.append(torch.stack(frame_tensors))  # (T, C, H, W)
+            # Extract frames
+            frame_tensors = [raw_frames[b_idx, idx] for idx in window_indices]
+            sampled_frames.append(torch.stack(frame_tensors))

-        frames = torch.stack(sampled_frames, dim=0)  # (B, T, C, H, W)
+        frames = torch.stack(sampled_frames, dim=0)

         anchor_stats = {
             "anchor_mean": float(torch.tensor(anchor_positions).float().mean()),
             "anchor_std": float(torch.tensor(anchor_positions).float().std()),
             "oob_fraction": float(oob_count) / B,
-            "padded_fraction": float(padded_count) / B,
-            "resampled_count": resampled_count,
+            "padded_fraction": 0.0,  # No padding with the reflection approach
             "fallback_used": False
         }

         return frames, anchor_stats

-    def _calculate_anchor_based_progress(self, batch: dict[str, Tensor], anchor_stats: dict, T_eff: int) -> Tensor:
-        """Calculate progress labels based on known random anchors (more efficient)."""
-        episode_indices, _ = self._extract_episode_and_frame_indices(batch)
-        if episode_indices is None:
-            raise ValueError("Need episode_indices for anchor-based progress calculation")
-
+    def _calculate_anchor_based_progress(self, T_eff: int) -> Tensor:
+        """Generate window-relative progress (0 to 1 across the window)."""
         device = next(self.parameters()).device
-        B = len(episode_indices)
-        delta_indices = self.config.observation_delta_indices
-
-        # Build progress for each anchor position in the batch
-        all_progress = []
-
-        for i, delta in enumerate(delta_indices[:T_eff]):  # Only compute for frames we'll actually use
-            frame_progress = []
-            for b_idx in range(B):
-                ep_idx = episode_indices[b_idx].item()
-
-                # Get episode length
-                ep_start = self.episode_data_index["from"][ep_idx].item()
-                ep_end = self.episode_data_index["to"][ep_idx].item()
-                ep_length = ep_end - ep_start
-
-                # The anchor was chosen during window sampling
-                # For anchor-based progress, we use window-relative progress to center around 0.5
-                # This is more stable and matches ReWiND's simple approach
-                window_position = i  # Position in window [0, T_eff-1]
-                progress = window_position / max(1, T_eff - 1)  # 0 to 1 across window
-
-                frame_progress.append(progress)
-
-            all_progress.append(
-                torch.tensor(frame_progress, device=device, dtype=torch.float32)
-            )
-
-        return torch.stack(all_progress, dim=1)  # (B, T_eff)
+        # Simple window-relative progress: 0 to 1 across the temporal window
+        # This centers the mean around 0.5 and is stable regardless of episode length
+        progress = torch.linspace(0, 1, T_eff, device=device)
+        return progress.unsqueeze(0)  # (1, T_eff) - will broadcast to (B, T_eff)

-    def _calculate_episode_progress(self, batch: dict[str, Tensor], episode_indices: Tensor,
-                                    frame_indices: Tensor, T_eff: int, idx: Tensor) -> Tensor:
-        """Calculate progress labels using episode-relative positions (legacy fallback)."""
-        device = next(self.parameters()).device
-        B = len(episode_indices)
-        delta_indices = self.config.observation_delta_indices
-
-        # Calculate progress for each frame in the temporal window
-        all_progress = []
-
-        # DEBUG: Log indexing details for first sample occasionally
-        debug_indexing = torch.rand(1).item() < 0.05  # 5% chance
-        if debug_indexing:
-            print(f"\n=== EPISODE PROGRESS DEBUG ===")
-            print(f"Delta indices: {delta_indices}")
-            print(f"Batch size: {B}, T_eff: {T_eff}")
-
-            # Check if batch samples have diverse frame indices
-            unique_frames = torch.unique(frame_indices).tolist()
-            unique_episodes = torch.unique(episode_indices).tolist()
-            print(f"Unique frame indices in batch: {len(unique_frames)} values")
-            print(f"Unique episode indices in batch: {len(unique_episodes)} values")
-
-            if len(unique_frames) == 1:
-                print("🚨 RED FLAG: All samples have IDENTICAL frame index!")
-
-        for i, delta in enumerate(delta_indices):
-            # For each sample, calculate the progress of the frame at delta offset
-            frame_progress = []
-            for b_idx in range(B):
-                ep_idx = episode_indices[b_idx].item()
-                frame_idx = frame_indices[b_idx].item()
-                # Calculate the actual frame index with delta
-                target_frame_idx = frame_idx + delta
-
-                # Get episode boundaries
-                ep_start = self.episode_data_index["from"][ep_idx].item()
-                ep_end = self.episode_data_index["to"][ep_idx].item()
-                ep_length = ep_end - ep_start
-
-                # Calculate progress with proper boundary handling
-                if target_frame_idx < 0:
-                    prog = target_frame_idx / max(1, ep_length - 1)
-                elif target_frame_idx >= ep_length:
-                    prog = target_frame_idx / max(1, ep_length - 1)
-                else:
-                    prog = target_frame_idx / max(1, ep_length - 1)
-
-                # Clip to reasonable bounds and clamp to [0,1] as recommended
-                prog = max(0.0, min(1.0, prog))
-                frame_progress.append(prog)
-
-            all_progress.append(
-                torch.tensor(frame_progress, device=device, dtype=torch.float32)
-            )
-
-        if debug_indexing:
-            print("=" * 30)
-
-        # Stack to get (B, T) tensor where T is the temporal sequence length
-        target = torch.stack(all_progress, dim=1)  # (B, max_seq_len)
-
-        # Apply stride/dropout indexing to match the processed frames
-        return target[:, idx]

     def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
         import json
@@ -1097,10 +840,7 @@ class RLearNPolicy(PreTrainedPolicy):
             "to": torch.tensor(ends, device=device, dtype=torch.long),
         }

-# Helper functions for ReWiND architecture
-
-
 def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
     """Extract visual sequence from batch and ensure it has the expected temporal length.
@@ -1182,8 +922,8 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
     return frames

-def apply_video_rewind_fixed(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
-    """Apply video rewinding augmentation WITHOUT constant-value padding (FIXED version).
+def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
+    """Apply video rewinding augmentation without constant-value padding.

     This version ensures the rewound sequence is exactly T frames without
     flat plateaus that drag down the target mean.
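End note: the target that the rewind produces is a ramp that rises and then retraces its own values, with no constant-value plateau; an illustration with assumed semantics (hand-built, not output of this code):

    import torch

    # A 10-frame clip plays f0..f6, then rewinds through f5, f4, f3.
    progress = torch.linspace(0, 1, 10)                        # monotone targets f0..f9
    target = torch.cat([progress[:7], progress[3:6].flip(0)])  # rises to 0.667, then retraces
    print(target)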