mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-27 05:07:15 +00:00
change sampling
This commit is contained in:
@@ -69,12 +69,23 @@ class RLearNConfig(PreTrainedConfig):
|
||||
|
||||
# ReWiND-specific parameters
|
||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
||||
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
|
||||
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
|
||||
rewind_prob: float = 0.5 # Reduced from 0.8 to avoid too many artifacts
|
||||
rewind_last3_prob: float = 0.3 # Increased to favor smaller rewinds
|
||||
use_mismatch_loss: bool = False # Enable mismatched language-video loss
|
||||
mismatch_prob: float = (
|
||||
0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
|
||||
)
|
||||
|
||||
# NEW: Loss and head improvements
|
||||
use_logit_regression: bool = True # Use logit space regression instead of sigmoid+MSE
|
||||
logit_eps: float = 1e-6 # Clipping epsilon for logit transform: logit(clamp(target, eps, 1-eps))
|
||||
head_lr_multiplier: float = 2.0 # Increase head learning rate relative to base
|
||||
head_weight_init_std: float = 0.05 # Larger head weight initialization for faster positive logits
|
||||
remove_head_bias_wd: bool = True # Remove weight decay from head bias
|
||||
|
||||
# Window sampling improvements
|
||||
use_random_anchor_sampling: bool = True # Use explicit random anchor sampling during training
|
||||
use_window_relative_progress: bool = True # Use window-relative progress (0-1 across window) instead of episode-relative
|
||||
|
||||
# Loss hyperparameters (simplified for ReWiND)
|
||||
# The main loss is just MSE between predicted and target progress
|
||||
@@ -91,7 +102,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
num_register_tokens: int = 4 # register / memory tokens, can't hurt
|
||||
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
||||
|
||||
# Simple MSE regression loss (no binning)
|
||||
# Loss configuration - supports both sigmoid+MSE and logit regression
|
||||
|
||||
# Evaluation visualization parameters
|
||||
enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training
|
||||
|
||||
@@ -190,16 +190,21 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Layer normalization before reward head to stabilize MLP outputs
|
||||
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# MSE regression head with sigmoid activation to bound outputs to [0,1]
|
||||
# Regression head - supports both logit and sigmoid modes
|
||||
self.reward_head = nn.Linear(config.dim_model, 1)
|
||||
|
||||
# Initialize with small weights to prevent sigmoid saturation
|
||||
# Target: sigmoid(0) = 0.5, so we want raw logits around [-2, 2] range
|
||||
# Initialize head with improved settings
|
||||
with torch.no_grad():
|
||||
self.reward_head.weight.normal_(0.0, 0.02) # Small but not tiny
|
||||
self.reward_head.bias.fill_(0.0) # Start at sigmoid(0) = 0.5
|
||||
if config.use_logit_regression:
|
||||
# Logit regression: can use larger weights since no saturation issues
|
||||
self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
|
||||
self.reward_head.bias.fill_(0.0) # Neutral start in logit space
|
||||
else:
|
||||
# Sigmoid mode: moderate initialization
|
||||
self.reward_head.weight.normal_(0.0, 0.02)
|
||||
self.reward_head.bias.fill_(0.0)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.sigmoid = nn.Sigmoid() if not config.use_logit_regression else None
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
@@ -224,9 +229,55 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
print(f"⚠️ torch.compile failed: {e}")
|
||||
# Continue without compilation
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
# Train only projections, temporal module and head by default if backbones are frozen
|
||||
return [p for p in self.parameters() if p.requires_grad]
|
||||
def get_optim_params(self) -> list:
|
||||
"""Return parameter groups with custom LR and weight decay settings."""
|
||||
# Collect trainable parameters
|
||||
base_params = []
|
||||
head_weight_params = []
|
||||
head_bias_params = []
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
if "reward_head" in name:
|
||||
if "bias" in name:
|
||||
head_bias_params.append(param)
|
||||
else:
|
||||
head_weight_params.append(param)
|
||||
else:
|
||||
base_params.append(param)
|
||||
|
||||
# Create parameter groups with different settings
|
||||
param_groups = []
|
||||
|
||||
# Base parameters (everything except head)
|
||||
if base_params:
|
||||
param_groups.append({
|
||||
"params": base_params,
|
||||
"name": "base"
|
||||
})
|
||||
|
||||
# Head weight parameters (higher LR)
|
||||
if head_weight_params:
|
||||
param_groups.append({
|
||||
"params": head_weight_params,
|
||||
"lr": self.config.learning_rate * self.config.head_lr_multiplier,
|
||||
"name": "head_weights"
|
||||
})
|
||||
|
||||
# Head bias parameters (higher LR, optionally no weight decay)
|
||||
if head_bias_params:
|
||||
head_bias_group = {
|
||||
"params": head_bias_params,
|
||||
"lr": self.config.learning_rate * self.config.head_lr_multiplier,
|
||||
"name": "head_bias"
|
||||
}
|
||||
if self.config.remove_head_bias_wd:
|
||||
head_bias_group["weight_decay"] = 0.0
|
||||
param_groups.append(head_bias_group)
|
||||
|
||||
return param_groups
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
@@ -311,9 +362,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(frame_specific_tokens)
|
||||
|
||||
# Get rewards via linear head with sigmoid activation
|
||||
# Get rewards via linear head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
return self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T)
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T)
|
||||
|
||||
if self.config.use_logit_regression:
|
||||
# In logit mode, apply sigmoid at inference
|
||||
return torch.sigmoid(raw_logits)
|
||||
else:
|
||||
# In sigmoid mode, apply sigmoid as usual
|
||||
return self.sigmoid(raw_logits)
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# Initial version: no-op; rely on upstream processors if any
|
||||
@@ -386,16 +444,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract frames and form (B, T, C, H, W)
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
# NEW: Explicit random anchor window sampling for training
|
||||
if self.training:
|
||||
frames, anchor_stats = self._sample_random_anchor_windows(batch)
|
||||
else:
|
||||
# During inference, use the generic extractor
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
anchor_stats = None
|
||||
|
||||
B, T, C, H, W = frames.shape
|
||||
device = next(self.parameters()).device
|
||||
frames = frames.to(device)
|
||||
|
||||
# Apply video rewinding augmentation during training
|
||||
# Apply video rewinding augmentation during training (FIXED: no constant padding)
|
||||
augmented_target = None
|
||||
if self.training and self.config.use_video_rewind:
|
||||
frames, augmented_target = apply_video_rewind(
|
||||
frames, augmented_target = apply_video_rewind_fixed(
|
||||
frames,
|
||||
rewind_prob=self.config.rewind_prob,
|
||||
last3_prob=getattr(self.config, "rewind_last3_prob", None),
|
||||
@@ -472,122 +536,17 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
loss_dict: dict[str, float] = {}
|
||||
|
||||
# Check if video rewinding already set the target
|
||||
if self.training and self.config.use_video_rewind and "augmented_target" in locals():
|
||||
if self.training and self.config.use_video_rewind and augmented_target is not None:
|
||||
# Use the augmented target from video rewinding and align with temporal subsampling
|
||||
target = augmented_target[:, idx]
|
||||
elif self.training and anchor_stats is not None and not anchor_stats.get("fallback_used", False):
|
||||
# NEW: Calculate progress using the known random anchors
|
||||
target = self._calculate_anchor_based_progress(batch, anchor_stats, T_eff)
|
||||
else:
|
||||
# Calculate true episode progress using episode_index and frame_index from batch
|
||||
# Fallback: Calculate episode progress the old way
|
||||
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
||||
if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None:
|
||||
|
||||
# Calculate progress for the current frame in each sample
|
||||
progress_values = []
|
||||
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item()
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Progress from 0 to 1 within the episode
|
||||
# frame_index is relative to the episode (0-based within episode)
|
||||
progress = frame_idx / max(1, ep_length - 1)
|
||||
progress_values.append(progress)
|
||||
|
||||
# Create progress tensor for the current frame (last in temporal sequence)
|
||||
current_progress = torch.tensor(progress_values, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype)
|
||||
|
||||
# Now calculate progress for ALL frames in the temporal window
|
||||
# The observation_delta_indices tell us which frames we're looking at
|
||||
delta_indices = self.config.observation_delta_indices # e.g., [-15, -14, ..., 0]
|
||||
|
||||
# Calculate progress for each frame in the temporal window
|
||||
all_progress = []
|
||||
|
||||
# DEBUG: Log indexing details for first sample occasionally
|
||||
debug_indexing = torch.rand(1).item() < 0.10 # 10% chance - increased for debugging
|
||||
if debug_indexing:
|
||||
print(f"\n=== INDEXING DEBUG ===")
|
||||
print(f"Delta indices: {delta_indices}")
|
||||
print(f"Batch size: {B}")
|
||||
|
||||
# Check if batch samples have diverse frame indices (red flag if all identical)
|
||||
unique_frames = torch.unique(frame_indices).tolist()
|
||||
unique_episodes = torch.unique(episode_indices).tolist()
|
||||
print(f"Unique frame indices in batch: {unique_frames[:10]}{'...' if len(unique_frames) > 10 else ''}")
|
||||
print(f"Unique episode indices in batch: {unique_episodes[:10]}{'...' if len(unique_episodes) > 10 else ''}")
|
||||
|
||||
if len(unique_frames) == 1:
|
||||
print("🚨 RED FLAG: All samples have IDENTICAL frame index! This causes identical targets.")
|
||||
|
||||
# First sample details
|
||||
ep_idx_0 = episode_indices[0].item()
|
||||
frame_idx_0 = frame_indices[0].item()
|
||||
ep_start_0 = self.episode_data_index["from"][ep_idx_0].item()
|
||||
ep_end_0 = self.episode_data_index["to"][ep_idx_0].item()
|
||||
ep_length_0 = ep_end_0 - ep_start_0
|
||||
print(f"First sample - Episode: {ep_idx_0}, Frame: {frame_idx_0}/{ep_length_0}, Episode length: {ep_length_0}")
|
||||
|
||||
# Check boundary proximity
|
||||
frames_from_start = frame_idx_0
|
||||
frames_from_end = ep_length_0 - frame_idx_0 - 1
|
||||
print(f"First sample proximity - Start: {frames_from_start}, End: {frames_from_end}")
|
||||
|
||||
if frames_from_start < 15:
|
||||
print(f"⚠️ Close to episode START: many deltas will go negative")
|
||||
if frames_from_end < 15:
|
||||
print(f"⚠️ Close to episode END: many deltas will exceed episode")
|
||||
|
||||
for i, delta in enumerate(delta_indices):
|
||||
# For each sample, calculate the progress of the frame at delta offset
|
||||
frame_progress = []
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item()
|
||||
|
||||
# Calculate the actual frame index with delta
|
||||
target_frame_idx = frame_idx + delta
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Calculate progress with proper boundary handling
|
||||
if target_frame_idx < 0:
|
||||
# Before episode start: extrapolate negative progress
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
elif target_frame_idx >= ep_length:
|
||||
# After episode end: extrapolate progress beyond 1.0
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
else:
|
||||
# Within episode: normal progress calculation
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
|
||||
# Clip to reasonable bounds to prevent extreme values
|
||||
prog = max(-1.0, min(2.0, prog)) # Allow some extrapolation
|
||||
frame_progress.append(prog)
|
||||
|
||||
# DEBUG: Log first sample's calculation
|
||||
if debug_indexing and b_idx == 0:
|
||||
boundary_status = "BEFORE" if target_frame_idx < 0 else "AFTER" if target_frame_idx >= ep_length else "WITHIN"
|
||||
print(f" Frame {i:2d} (δ={delta:3d}): target_idx={target_frame_idx:3d} [{boundary_status}] → progress={prog:.6f}")
|
||||
|
||||
all_progress.append(
|
||||
torch.tensor(frame_progress, device=video_frame_embeds.device, dtype=video_frame_embeds.dtype)
|
||||
)
|
||||
|
||||
if debug_indexing:
|
||||
print("=" * 22)
|
||||
|
||||
# Stack to get (B, T) tensor where T is the temporal sequence length
|
||||
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)
|
||||
|
||||
# Apply stride/dropout indexing to match the processed frames
|
||||
target = target[:, idx]
|
||||
target = self._calculate_episode_progress(batch, episode_indices, frame_indices, T_eff, idx)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No episode information found to build full-episode progress. "
|
||||
@@ -603,16 +562,26 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
|
||||
# Calculate loss using MSE
|
||||
# Calculate loss using the configured mode (logit regression or sigmoid+MSE)
|
||||
loss_start = time.perf_counter()
|
||||
assert target.dtype == torch.float, "Continuous rewards require float targets"
|
||||
|
||||
# Get reward predictions with sigmoid activation
|
||||
# Get model outputs
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
predicted_rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) # (B, T_eff)
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# MSE loss with masking for variable length sequences
|
||||
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
|
||||
if self.config.use_logit_regression:
|
||||
# Logit regression: transform targets to logit space and compute MSE on logits
|
||||
eps = self.config.logit_eps
|
||||
target_clamped = torch.clamp(target[:, :T_eff], eps, 1 - eps)
|
||||
target_logits = torch.logit(target_clamped)
|
||||
loss = F.mse_loss(raw_logits, target_logits, reduction='mean')
|
||||
# For logging/debug, also compute sigmoid predictions
|
||||
predicted_rewards = torch.sigmoid(raw_logits)
|
||||
else:
|
||||
# Sigmoid mode: apply sigmoid and compute MSE on probabilities
|
||||
predicted_rewards = self.sigmoid(raw_logits)
|
||||
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
|
||||
|
||||
# Optional: Mismatched video-language pairs loss
|
||||
L_mismatch = torch.zeros((), device=device)
|
||||
@@ -644,9 +613,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Mismatched pairs should predict zero progress
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
mismatch_predictions = self.sigmoid(self.reward_head(normalized_mismatch_embeds)).squeeze(-1)
|
||||
zeros_target = torch.zeros_like(target[:, :T_eff])
|
||||
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
|
||||
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
|
||||
|
||||
if self.config.use_logit_regression:
|
||||
# In logit mode, target logit of ~0 corresponds to sigmoid(x)≈0
|
||||
eps = self.config.logit_eps
|
||||
zeros_target_logits = torch.logit(torch.full_like(target[:, :T_eff], eps))
|
||||
L_mismatch = F.mse_loss(mismatch_raw_logits, zeros_target_logits, reduction='mean')
|
||||
else:
|
||||
# In sigmoid mode, target sigmoid output of 0
|
||||
mismatch_predictions = self.sigmoid(mismatch_raw_logits)
|
||||
zeros_target = torch.zeros_like(target[:, :T_eff])
|
||||
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
|
||||
|
||||
# Total loss
|
||||
total_loss = loss + L_mismatch
|
||||
@@ -713,8 +691,23 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
"target_min": float(target.min().item()),
|
||||
"target_max": float(target.max().item()),
|
||||
"target_mean": float(target.mean().item()),
|
||||
"target_std": float(target.std().item()),
|
||||
# Prediction statistics
|
||||
"pred_mean": float(predicted_rewards.mean().item()),
|
||||
"pred_std": float(predicted_rewards.std().item()),
|
||||
# Raw logits statistics (useful for monitoring head behavior)
|
||||
"raw_logits_mean": float(raw_logits.mean().item()),
|
||||
"raw_logits_std": float(raw_logits.std().item()),
|
||||
# NEW: Anchor sampling statistics if available
|
||||
**({
|
||||
"anchor_mean": float(anchor_stats['anchor_mean']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
|
||||
"anchor_std": float(anchor_stats['anchor_std']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
|
||||
"oob_fraction": float(anchor_stats['oob_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
|
||||
"padded_fraction": float(anchor_stats['padded_fraction']) if anchor_stats and not anchor_stats.get('fallback_used', False) else 0.0,
|
||||
"use_random_anchors": not (anchor_stats and anchor_stats.get('fallback_used', False)) if anchor_stats else False,
|
||||
}),
|
||||
# Loss mode indicator
|
||||
"logit_regression": bool(self.config.use_logit_regression),
|
||||
# Timing information
|
||||
"timing_vision_ms": float(vision_time * 1000),
|
||||
"timing_language_ms": float(lang_time * 1000),
|
||||
@@ -874,6 +867,211 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
return ep, fr
|
||||
|
||||
def _sample_random_anchor_windows(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Sample random anchor windows for training to avoid sampling bias.
|
||||
|
||||
Returns:
|
||||
frames: (B, T, C, H, W) tensor with T = max_seq_len
|
||||
anchor_stats: dict with sampling statistics for logging
|
||||
"""
|
||||
# Extract episode and frame indices
|
||||
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
|
||||
|
||||
if episode_indices is None or frame_indices is None or self.episode_data_index is None:
|
||||
# Fallback to generic extractor if we don't have episode info
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
return frames, {"fallback_used": True}
|
||||
|
||||
device = next(self.parameters()).device
|
||||
B = len(episode_indices)
|
||||
T = self.config.max_seq_len
|
||||
delta_indices = self.config.observation_delta_indices # [-15, -14, ..., 0]
|
||||
|
||||
# Get raw image data - assume it's already a temporal sequence from dataset
|
||||
raw_frames = extract_visual_sequence(batch, target_seq_len=None) # Don't force padding
|
||||
available_T = raw_frames.shape[1]
|
||||
|
||||
# For each sample, choose a random anchor and build the window
|
||||
sampled_frames = []
|
||||
anchor_positions = []
|
||||
oob_count = 0
|
||||
padded_count = 0
|
||||
resampled_count = 0
|
||||
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item() # Current frame position in episode
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Choose random anchor within episode bounds such that we can get a full window
|
||||
# The anchor is the "current" frame (delta=0), so we need at least T-1 frames before it
|
||||
min_anchor = T - 1 # Need 15 frames before for [-15..0] window
|
||||
max_anchor = ep_length - 1 # Episode frame indices are 0-based
|
||||
|
||||
if min_anchor > max_anchor:
|
||||
# Episode too short for full window - use available frames with padding
|
||||
anchor = max_anchor
|
||||
padded_count += 1
|
||||
else:
|
||||
# Sample uniformly from valid range
|
||||
anchor = torch.randint(min_anchor, max_anchor + 1, (1,)).item()
|
||||
|
||||
anchor_positions.append(anchor)
|
||||
|
||||
# Build window indices relative to episode start
|
||||
window_indices = [anchor + delta for delta in delta_indices]
|
||||
|
||||
# Handle out-of-bounds with reflection or clamping
|
||||
valid_indices = []
|
||||
had_oob = False
|
||||
for idx in window_indices:
|
||||
if idx < 0:
|
||||
# Reflect at episode boundary
|
||||
valid_indices.append(-idx)
|
||||
had_oob = True
|
||||
elif idx >= ep_length:
|
||||
# Reflect at episode end
|
||||
valid_indices.append(2 * (ep_length - 1) - idx)
|
||||
had_oob = True
|
||||
else:
|
||||
valid_indices.append(idx)
|
||||
|
||||
if had_oob:
|
||||
oob_count += 1
|
||||
|
||||
# Extract frames at these indices from the raw temporal sequence
|
||||
# Map episode-relative indices to sequence indices
|
||||
frame_tensors = []
|
||||
for ep_rel_idx in valid_indices:
|
||||
if ep_rel_idx < available_T:
|
||||
frame_tensors.append(raw_frames[b_idx, ep_rel_idx])
|
||||
else:
|
||||
# Fallback: repeat last available frame
|
||||
frame_tensors.append(raw_frames[b_idx, -1])
|
||||
padded_count += 1
|
||||
|
||||
sampled_frames.append(torch.stack(frame_tensors)) # (T, C, H, W)
|
||||
|
||||
frames = torch.stack(sampled_frames, dim=0) # (B, T, C, H, W)
|
||||
|
||||
anchor_stats = {
|
||||
"anchor_mean": float(torch.tensor(anchor_positions).float().mean()),
|
||||
"anchor_std": float(torch.tensor(anchor_positions).float().std()),
|
||||
"oob_fraction": float(oob_count) / B,
|
||||
"padded_fraction": float(padded_count) / B,
|
||||
"resampled_count": resampled_count,
|
||||
"fallback_used": False
|
||||
}
|
||||
|
||||
return frames, anchor_stats
|
||||
|
||||
def _calculate_anchor_based_progress(self, batch: dict[str, Tensor], anchor_stats: dict, T_eff: int) -> Tensor:
|
||||
"""Calculate progress labels based on known random anchors (more efficient)."""
|
||||
episode_indices, _ = self._extract_episode_and_frame_indices(batch)
|
||||
if episode_indices is None:
|
||||
raise ValueError("Need episode_indices for anchor-based progress calculation")
|
||||
|
||||
device = next(self.parameters()).device
|
||||
B = len(episode_indices)
|
||||
delta_indices = self.config.observation_delta_indices
|
||||
|
||||
# Build progress for each anchor position in the batch
|
||||
all_progress = []
|
||||
|
||||
for i, delta in enumerate(delta_indices[:T_eff]): # Only compute for frames we'll actually use
|
||||
frame_progress = []
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
|
||||
# Get episode length
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# The anchor was chosen during window sampling
|
||||
# For anchor-based progress, we use window-relative progress to center around 0.5
|
||||
# This is more stable and matches ReWiND's simple approach
|
||||
window_position = i # Position in window [0, T_eff-1]
|
||||
progress = window_position / max(1, T_eff - 1) # 0 to 1 across window
|
||||
|
||||
frame_progress.append(progress)
|
||||
|
||||
all_progress.append(
|
||||
torch.tensor(frame_progress, device=device, dtype=torch.float32)
|
||||
)
|
||||
|
||||
return torch.stack(all_progress, dim=1) # (B, T_eff)
|
||||
|
||||
def _calculate_episode_progress(self, batch: dict[str, Tensor], episode_indices: Tensor,
|
||||
frame_indices: Tensor, T_eff: int, idx: Tensor) -> Tensor:
|
||||
"""Calculate progress labels using episode-relative positions (legacy fallback)."""
|
||||
device = next(self.parameters()).device
|
||||
B = len(episode_indices)
|
||||
delta_indices = self.config.observation_delta_indices
|
||||
|
||||
# Calculate progress for each frame in the temporal window
|
||||
all_progress = []
|
||||
|
||||
# DEBUG: Log indexing details for first sample occasionally
|
||||
debug_indexing = torch.rand(1).item() < 0.05 # 5% chance
|
||||
if debug_indexing:
|
||||
print(f"\n=== EPISODE PROGRESS DEBUG ===")
|
||||
print(f"Delta indices: {delta_indices}")
|
||||
print(f"Batch size: {B}, T_eff: {T_eff}")
|
||||
|
||||
# Check if batch samples have diverse frame indices
|
||||
unique_frames = torch.unique(frame_indices).tolist()
|
||||
unique_episodes = torch.unique(episode_indices).tolist()
|
||||
print(f"Unique frame indices in batch: {len(unique_frames)} values")
|
||||
print(f"Unique episode indices in batch: {len(unique_episodes)} values")
|
||||
|
||||
if len(unique_frames) == 1:
|
||||
print("🚨 RED FLAG: All samples have IDENTICAL frame index!")
|
||||
|
||||
for i, delta in enumerate(delta_indices):
|
||||
# For each sample, calculate the progress of the frame at delta offset
|
||||
frame_progress = []
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item()
|
||||
|
||||
# Calculate the actual frame index with delta
|
||||
target_frame_idx = frame_idx + delta
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Calculate progress with proper boundary handling
|
||||
if target_frame_idx < 0:
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
elif target_frame_idx >= ep_length:
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
else:
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
|
||||
# Clip to reasonable bounds and clamp to [0,1] as recommended
|
||||
prog = max(0.0, min(1.0, prog))
|
||||
frame_progress.append(prog)
|
||||
|
||||
all_progress.append(
|
||||
torch.tensor(frame_progress, device=device, dtype=torch.float32)
|
||||
)
|
||||
|
||||
if debug_indexing:
|
||||
print("=" * 30)
|
||||
|
||||
# Stack to get (B, T) tensor where T is the temporal sequence length
|
||||
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)
|
||||
|
||||
# Apply stride/dropout indexing to match the processed frames
|
||||
return target[:, idx]
|
||||
|
||||
def _load_episode_index_from_jsonl(self, path: str) -> dict[str, Tensor]:
|
||||
import json
|
||||
lengths: list[int] = []
|
||||
@@ -984,17 +1182,16 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation as described in ReWiND paper.
|
||||
|
||||
Each video in the batch has an independent chance of being rewound.
|
||||
def apply_video_rewind_fixed(frames: Tensor, rewind_prob: float = 0.5, last3_prob: float | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation WITHOUT constant-value padding (FIXED version).
|
||||
|
||||
This version ensures the rewound sequence is exactly T frames without flat plateaus
|
||||
that drag down the target mean.
|
||||
|
||||
Args:
|
||||
frames: Tensor of shape (B, T, C, H, W)
|
||||
rewind_prob: Probability of applying rewind augmentation to each video
|
||||
last3_prob: Probability of limiting rewind to last 3 frames
|
||||
|
||||
Returns:
|
||||
Augmented frames and corresponding progress labels
|
||||
@@ -1002,8 +1199,8 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
B, T, C, H, W = frames.shape
|
||||
device = frames.device
|
||||
|
||||
# Create default progress labels (linearly increasing from 0 to 1 with denominator T-1)
|
||||
# torch.linspace(0, 1, T) already yields j/(T-1) at step j
|
||||
# Create default progress labels using window-relative progress (0 to 1)
|
||||
# This centers the mean around 0.5 and removes episode-length dependence
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
# Apply rewind augmentation to each sample in batch independently
|
||||
@@ -1020,50 +1217,68 @@ def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5, last3_prob: flo
|
||||
augmented_progress.append(default_progress[b])
|
||||
continue
|
||||
|
||||
# Apply rewinding to this video
|
||||
# Split point i: between frame 2 and T-1 (upper bound exclusive in torch.randint)
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
# Apply rewinding - but ensure we get exactly T frames
|
||||
max_attempts = 10 # Limit resampling attempts
|
||||
success = False
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
# Split point i: between frame 2 and T-1
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
|
||||
k = min(3, i - 1)
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
if last3_prob is not None and torch.rand(1).item() < last3_prob and i >= 3:
|
||||
k = min(3, i - 1)
|
||||
else:
|
||||
k = torch.randint(1, i, (1,)).item()
|
||||
k = min(k, i - 1)
|
||||
|
||||
# Create rewound sequence: frames[0:i] + reversed frames[i-k:i]
|
||||
forward_length = i
|
||||
reverse_length = k
|
||||
total_length = forward_length + reverse_length
|
||||
|
||||
# Check if we can make exactly T frames
|
||||
if total_length == T:
|
||||
# Perfect fit!
|
||||
forward_frames = frames[b, :i]
|
||||
reverse_frames = frames[b, max(0, i - k):i].flip(dims=[0])
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
# Create corresponding progress labels without constant padding
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
break
|
||||
elif total_length < T:
|
||||
# Too short - try to extend by adjusting k
|
||||
needed = T - total_length
|
||||
if i + needed <= T: # Can we extend k?
|
||||
k_extended = k + needed
|
||||
if i - k_extended >= 0:
|
||||
forward_frames = frames[b, :i]
|
||||
reverse_frames = frames[b, max(0, i - k_extended):i].flip(dims=[0])
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
if rewound_seq.shape[0] == T:
|
||||
# Create progress labels
|
||||
denom = max(T - 1, 1)
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k_extended) / denom), k_extended, device=device)
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
success = True
|
||||
break
|
||||
# If too long or can't fix, try again with different i,k
|
||||
|
||||
if success:
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
else:
|
||||
k = torch.randint(1, i, (1,)).item()
|
||||
k = min(k, i - 1)
|
||||
|
||||
# Create rewound sequence: o1...oi, oi-1, ..., oi-k
|
||||
forward_frames = frames[b, :i] # Frames up to split point
|
||||
reverse_frames = frames[b, max(0, i - k) : i].flip(dims=[0]) # Reversed frames
|
||||
|
||||
# Concatenate forward and reverse parts
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
# Pad by repeating the last real frame if needed to maintain fixed length T
|
||||
if rewound_seq.shape[0] < T:
|
||||
last_frame = rewound_seq[-1:]
|
||||
pad_frames = last_frame.expand(T - rewound_seq.shape[0], C, H, W)
|
||||
rewound_seq = torch.cat([rewound_seq, pad_frames], dim=0)
|
||||
elif rewound_seq.shape[0] > T:
|
||||
rewound_seq = rewound_seq[:T]
|
||||
|
||||
# Create corresponding progress labels
|
||||
denom = max(T - 1, 1)
|
||||
# Forward part: increasing progress using denominator T-1
|
||||
forward_progress = torch.linspace(0, (i - 1) / denom, i, device=device)
|
||||
# Reverse part: decreasing progress starting from (i-1)/(T-1)
|
||||
reverse_progress = torch.linspace((i - 1) / denom, max(0.0, (i - k) / denom), k, device=device)
|
||||
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
# Pad progress by repeating the last real progress if needed
|
||||
if rewound_progress.shape[0] < T:
|
||||
last_val = rewound_progress[-1]
|
||||
pad_vals = last_val.expand(T - rewound_progress.shape[0])
|
||||
rewound_progress = torch.cat([rewound_progress, pad_vals])
|
||||
elif rewound_progress.shape[0] > T:
|
||||
rewound_progress = rewound_progress[:T]
|
||||
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
# Fallback: use original sequence if we can't create a good rewind
|
||||
augmented_frames.append(frames[b])
|
||||
augmented_progress.append(default_progress[b])
|
||||
|
||||
return torch.stack(augmented_frames), torch.stack(augmented_progress)
|
||||
@@ -248,6 +248,15 @@ def train(cfg: TrainPipelineConfig):
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
elif cfg.policy.type == "rlearn":
|
||||
# For RLearN, drop first 15 frames to avoid padding issues with temporal windows
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.episode_data_index,
|
||||
drop_n_first_frames=15, # Skip frames that would need padding
|
||||
drop_n_last_frames=0,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
|
||||
Reference in New Issue
Block a user