use only rewind loss

This commit is contained in:
Pepijn
2025-08-28 14:22:57 +02:00
parent a4c88d6340
commit c877e98658
11 changed files with 445 additions and 520 deletions
File diff suppressed because one or more lines are too long
+2
View File
@@ -33,6 +33,8 @@ class DatasetConfig:
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | None = None
episodes: list[int] | None = None
# Percentage of dataset to use (0-100). If set, overrides episodes parameter.
percentage: float | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
revision: str | None = None
use_imagenet_stats: bool = True
+12 -1
View File
@@ -87,10 +87,21 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
# Handle percentage parameter
episodes = cfg.dataset.episodes
if cfg.dataset.percentage is not None:
# Calculate episodes based on percentage
total_episodes = ds_meta.total_episodes
num_episodes_to_use = max(1, int(total_episodes * cfg.dataset.percentage / 100))
episodes = list(range(num_episodes_to_use))
import logging
logging.info(f"Using {cfg.dataset.percentage}% of dataset: {num_episodes_to_use}/{total_episodes} episodes")
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
episodes=episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
+2
View File
@@ -20,6 +20,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .rlearn.configuration_rlearn import RLearNConfig as RLearNConfig
__all__ = [
"ACTConfig",
@@ -28,4 +29,5 @@ __all__ = [
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
"RLearNConfig",
]
+5
View File
@@ -244,6 +244,7 @@ def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
episode_data_index: dict | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
@@ -301,6 +302,10 @@ def make_policy(
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
# Pass episode_data_index for RLearN policy to calculate proper progress
if cfg.type == "rlearn" and episode_data_index is not None:
kwargs["episode_data_index"] = episode_data_index
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
@@ -61,31 +61,14 @@ class RLearNConfig(PreTrainedConfig):
# Training
learning_rate: float = 1e-4
weight_decay: float = 0.01
loss_type: str = "composite" # Always use composite loss with spatial awareness
ranking_margin: float = 0.1
# Composite loss weights (with spatial awareness and ReWiND reversibility)
lambda_prog: float = 1.0 # Progress regression weight
lambda_spatial_nce: float = 0.5 # Spatial-aware InfoNCE weight
lambda_rewind: float = 0.4 # ReWiND reversible ranking weight
# ReWiND-specific parameters
use_video_rewind: bool = True # Enable video rewinding augmentation
rewind_prob: float = 0.5 # Probability of applying rewind to each batch
use_mismatch_loss: bool = True # Enable mismatched language-video loss
# Loss hyperparameters
nce_temperature: float = 0.07 # Temperature for InfoNCE
zscore_eps: float = 1e-5 # Epsilon for z-score normalization
min_rank_gap: int = 1 # Minimum gap for temporal ranking pairs
num_ranking_pairs: int = 64 # Number of (far, near) pairs to sample for ReWiND
last_k_for_nce: int = 3 # Use last k frames for InfoNCE
mismatch_lang_prob: float = 0.2 # Probability of language mismatch augmentation
# Value-based pairwise loss hyperparameters (for value_pairwise mode)
lambda_dir: float = 1.0 # Intra-trajectory directional ranking
lambda_text: float = 0.5 # Inter-instruction contrastive ranking
lambda_flat: float = 0.25 # Flatness under mismatch
dir_margin: float = 0.2 # Margin for directional ranking
text_margin: float = 0.2 # Margin for text contrastive ranking
flat_epsilon: float = 0.05 # Epsilon band for flatness loss
num_pairs_per_loss: int = 64 # Number of pairs to sample per loss term
use_hard_negatives: bool = True # Whether to generate hard negative instructions
# Loss hyperparameters (simplified for ReWiND)
# The main loss is just MSE between predicted and target progress
# Normalization presets
normalization_mapping: dict[str, NormalizationMode] = field(
@@ -116,7 +99,7 @@ class RLearNConfig(PreTrainedConfig):
@property
def reward_delta_indices(self) -> list | None:
# By default we supervise every provided timestep equally.
# ReWiND generates progress labels on-the-fly, doesn't need reward data
return None
def get_optimizer_preset(self): # type: ignore[override]
+45 -31
View File
@@ -241,56 +241,70 @@ class RLearnEvaluator:
@torch.no_grad()
def predict_episode_rewards(self, frames: Tensor, language: str, batch_size: int = 16) -> np.ndarray:
"""
Predict rewards for a single episode.
Predict rewards for a single episode using proper temporal sequences.
Args:
frames: Video frames tensor of shape (T, C, H, W)
language: Language instruction string
batch_size: Maximum sequence length to process at once
batch_size: Maximum number of temporal sequences to process at once
Returns:
Predicted rewards array of shape (T,)
"""
T = frames.shape[0]
max_seq_len = self.model.config.max_seq_len
# Preprocess frames to match model expectations
processed_frames = self._preprocess_frames(frames)
# Process in chunks if episode is very long
if T <= batch_size:
# Single batch
# Create temporal sequences for each frame
# For frame i, we want frames [i-max_seq_len+1, ..., i-1, i]
temporal_sequences = []
for i in range(T):
# Create sequence ending at frame i
seq_frames = []
for j in range(max(0, i - max_seq_len + 1), i + 1):
# Use frame j if available, otherwise repeat the first available frame
frame_idx = max(0, min(j, T - 1))
seq_frames.append(processed_frames[frame_idx])
# Pad sequence to max_seq_len by repeating the first frame if needed
while len(seq_frames) < max_seq_len:
seq_frames.insert(0, seq_frames[0]) # Prepend first frame
# Take only the last max_seq_len frames if we have too many
seq_frames = seq_frames[-max_seq_len:]
temporal_sequences.append(torch.stack(seq_frames)) # (max_seq_len, C, H, W)
# Stack all temporal sequences: (T, max_seq_len, C, H, W)
all_sequences = torch.stack(temporal_sequences)
# Process in batches
rewards = []
for i in range(0, T, batch_size):
end_idx = min(i + batch_size, T)
batch_sequences = all_sequences[i:end_idx].to(self.device) # (B, max_seq_len, C, H, W)
# Create batch for model
batch = {
OBS_IMAGES: processed_frames.unsqueeze(0).to(self.device), # (1, T, C, H, W)
OBS_LANGUAGE: [language],
OBS_IMAGES: batch_sequences, # (B, T, C, H, W) format expected by model
OBS_LANGUAGE: [language] * batch_sequences.shape[0],
}
# Use the new predict_rewards method
values = self.model.predict_rewards(batch) # (1, T')
rewards = values.squeeze(0).cpu().numpy() # (T',)
# Predict rewards - model returns (B, T') but we want the last timestep for each sequence
values = self.model.predict_rewards(batch) # (B, T')
else:
# Process in overlapping chunks to handle very long episodes
rewards = []
stride = batch_size // 2 # 50% overlap
# Take the last timestep prediction for each sequence (represents current frame reward)
if values.dim() == 2:
batch_rewards = values[:, -1].cpu().numpy() # (B,) - last timestep
else:
batch_rewards = values.cpu().numpy() # (B,) - already single timestep
for i in range(0, T, stride):
end_idx = min(i + batch_size, T)
chunk_frames = processed_frames[i:end_idx]
rewards.extend(batch_rewards)
batch = {OBS_IMAGES: chunk_frames.unsqueeze(0).to(self.device), OBS_LANGUAGE: [language]}
chunk_values = self.model.predict_rewards(batch)
chunk_rewards = chunk_values.squeeze(0).cpu().numpy()
# For overlapping chunks, only take the first half (except for the last chunk)
if i + batch_size < T:
rewards.extend(chunk_rewards[:stride])
else:
rewards.extend(chunk_rewards)
rewards = np.array(rewards[:T]) # Ensure exact length
return rewards
return np.array(rewards[:T]) # Ensure exact length
def _preprocess_frames(self, frames: Tensor) -> Tensor:
"""
+210 -375
View File
@@ -15,7 +15,12 @@
# limitations under the License.
"""
RLearN: Video-Language Conditioned Reward Model
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
This implementation follows the ReWiND paper approach:
- Automatically generates linear progress labels (0 to 1) for each episode
- No need for pre-annotated rewards in the dataset
- Applies video rewinding augmentation to create synthetic failure trajectories
Inputs
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
@@ -95,9 +100,10 @@ class RLearNPolicy(PreTrainedPolicy):
config_class = RLearNConfig
name = "rlearn"
def __init__(self, config: RLearNConfig):
def __init__(self, config: RLearNConfig, episode_data_index: dict = None):
super().__init__(config)
self.config = config
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
# Encoders
from transformers import AutoModel, AutoProcessor
@@ -161,15 +167,6 @@ class RLearNPolicy(PreTrainedPolicy):
if config.use_tanh_head:
head_layers.append(nn.Tanh())
self.head = nn.Sequential(*head_layers)
# Projection from scalar value summary to text embedding dim for InfoNCE
self.value_to_text_proj = nn.Linear(1, config.dim_model)
# Spatial attention for InfoNCE
self.spatial_cross_attn = nn.MultiheadAttention(
embed_dim=config.dim_model, num_heads=config.n_heads, batch_first=True
)
self.spatial_norm = nn.LayerNorm(config.dim_model)
# Simple frame dropout probability
self.frame_dropout_p = config.frame_dropout_p
self.stride = max(1, config.stride)
@@ -274,12 +271,14 @@ class RLearNPolicy(PreTrainedPolicy):
return batch
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Compute training loss and logs.
"""Compute ReWiND training loss with on-the-fly progress label generation.
Expected batch keys:
- OBS_IMAGES: list[Tensor] of shape [(B, C, H, W), ...] per time step or stacked (B, T, C, H, W)
- OBS_LANGUAGE: optional string tokens already tokenized externally or raw strings
- REWARD: (B, T) or (B,) target rewards
Note: Progress labels (0 to 1) are generated automatically for each episode.
No REWARD key is needed in the batch.
"""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
@@ -288,6 +287,13 @@ class RLearNPolicy(PreTrainedPolicy):
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape
# Apply video rewinding augmentation during training
if self.training and self.config.use_video_rewind:
frames, augmented_target = apply_video_rewind(frames, rewind_prob=self.config.rewind_prob)
# Use augmented progress labels if rewinding was applied
if REWARD in batch:
target = augmented_target
# Apply stride and frame dropout during training
idx = torch.arange(0, T, self.stride, device=frames.device)
if self.training and self.frame_dropout_p > 0.0 and T > 1:
@@ -328,24 +334,15 @@ class RLearNPolicy(PreTrainedPolicy):
# Encode through vision model
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
# Extract BOTH CLS token and spatial patches
# Extract CLS token for temporal modeling
if hasattr(vision_outputs, "last_hidden_state"):
all_tokens = vision_outputs.last_hidden_state # (BT, num_tokens, D)
cls_tokens = all_tokens[:, 0] # (BT, D) - CLS token for temporal modeling
spatial_tokens = all_tokens[:, 1:] # (BT, num_patches, D) - spatial patches
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
else:
raise RuntimeError("Vision encoder must output last_hidden_state with spatial features")
raise RuntimeError("Vision encoder must output last_hidden_state")
# Project CLS tokens for temporal sequence
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
# Keep spatial features for spatial-aware losses (project them too)
# Assuming 16x16 patches for 256x256 image with patch_size=16
num_patches = spatial_tokens.shape[1]
spatial_features = self.visual_proj(spatial_tokens).reshape(
B, -1, num_patches, self.config.dim_model
) # (B, T', num_patches, D)
# Add temporal positional encodings and optional first-frame bias
pe = (
self.positional_encoding[: visual_seq.shape[1]]
@@ -362,150 +359,156 @@ class RLearNPolicy(PreTrainedPolicy):
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
values = self.head(temporal_features).squeeze(-1) # (B, T')
# Targets
target = batch.get(REWARD, None)
# Generate progress labels on-the-fly (ReWiND approach)
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
loss_dict: dict[str, float] = {}
if target is None:
# If no labels, return zeros loss and logits for inference
# Check if video rewinding already set the target
if self.training and self.config.use_video_rewind and 'augmented_target' in locals():
# Use the augmented target from video rewinding
target = augmented_target
else:
# Calculate true episode progress using episode_index and frame_index from batch
if "episode_index" in batch and "frame_index" in batch and hasattr(self, 'episode_data_index'):
# Get episode indices and frame indices from batch
episode_indices = batch["episode_index"] # Shape: (B,)
frame_indices = batch["frame_index"] # Shape: (B,)
# Calculate progress for the current frame in each sample
progress_values = []
for b_idx in range(B):
ep_idx = episode_indices[b_idx].item()
frame_idx = frame_indices[b_idx].item()
# Get episode boundaries
ep_start = self.episode_data_index["from"][ep_idx].item()
ep_end = self.episode_data_index["to"][ep_idx].item()
ep_length = ep_end - ep_start
# Progress from 0 to 1 within the episode
# frame_index is relative to the episode (0-based within episode)
progress = frame_idx / max(1, ep_length - 1)
progress_values.append(progress)
# Create progress tensor for the current frame (last in temporal sequence)
current_progress = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
# Now calculate progress for ALL frames in the temporal window
# The observation_delta_indices tell us which frames we're looking at
delta_indices = self.config.observation_delta_indices # e.g., [-15, -14, ..., 0]
# Calculate progress for each frame in the temporal window
all_progress = []
for delta in delta_indices:
# For each sample, calculate the progress of the frame at delta offset
frame_progress = []
for b_idx in range(B):
ep_idx = episode_indices[b_idx].item()
frame_idx = frame_indices[b_idx].item()
# Calculate the actual frame index with delta
target_frame_idx = frame_idx + delta
# Get episode boundaries
ep_start = self.episode_data_index["from"][ep_idx].item()
ep_end = self.episode_data_index["to"][ep_idx].item()
ep_length = ep_end - ep_start
# Clamp to episode boundaries (frame_index is relative to episode)
target_frame_idx = max(0, min(ep_length - 1, target_frame_idx))
# Calculate progress for this frame
prog = target_frame_idx / max(1, ep_length - 1)
frame_progress.append(prog)
all_progress.append(torch.tensor(frame_progress, device=values.device, dtype=values.dtype))
# Stack to get (B, T) tensor where T is the temporal sequence length
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)
# Apply stride/dropout indexing to match the processed frames
target = target[:, idx]
elif "index" in batch and hasattr(self, 'episode_data_index'):
# Fallback: Use global index if available
global_indices = batch["index"] # Shape: (B,)
# For each index, find which episode it belongs to and its position
progress_values = []
for global_idx in global_indices:
# Find which episode this index belongs to
episode_starts = self.episode_data_index["from"]
episode_ends = self.episode_data_index["to"]
# Find the episode by checking which range the index falls into
episode_idx = None
frame_in_episode = None
for ep_idx in range(len(episode_starts)):
if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]:
episode_idx = ep_idx
frame_in_episode = global_idx.item() - episode_starts[ep_idx].item()
break
if episode_idx is not None:
# Calculate position within episode
ep_start = episode_starts[episode_idx].item()
ep_end = episode_ends[episode_idx].item()
ep_length = ep_end - ep_start
# Progress from 0 to 1 within the episode
progress = frame_in_episode / max(1, ep_length - 1)
else:
# Fallback if we can't find the episode (shouldn't happen)
progress = 0.5
progress_values.append(progress)
# For temporal window, use simplified linear progress
# (proper calculation would need all frame indices in the window)
T_effective = len(idx)
target = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion
else:
raise ValueError("No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present.")
# During inference, we might not want to compute loss
if not self.training and target is None:
loss = values.mean() * 0.0
loss_dict["has_labels"] = 0.0
return loss, {**loss_dict, "values_mean": values.mean().item()}
# Align target with sampled timesteps
if target.dim() == 1:
target = target.unsqueeze(1) # (B, 1)
# ReWiND Loss (following the paper exactly)
# The core loss is progress regression with video rewinding augmentation
# Handle target padding to match frame sequence if needed
if target.shape[1] < self.config.max_seq_len:
# Pad targets by repeating the first value (assuming it's the earliest)
padding_needed = self.config.max_seq_len - target.shape[1]
first_target = target[:, :1] # (B, 1)
padding = first_target.expand(target.shape[0], padding_needed)
target = torch.cat([padding, target], dim=1) # Prepend padding
# 1) Main progress regression loss for matched sequences
# Target should be normalized progress from 0 to 1 (t/T)
L_progress = F.mse_loss(values, target)
import logging
# 2) Mismatched video-language pairs should predict zero progress
L_mismatch = torch.zeros((), device=values.device)
if self.training and self.config.use_mismatch_loss and values.size(0) > 1:
# Randomly shuffle language instructions within the batch
shuffled_indices = torch.randperm(B, device=values.device)
lang_mismatch = lang_emb[shuffled_indices]
logging.debug(
f"Padded targets from {target.shape[1] - padding_needed} to {self.config.max_seq_len}"
)
# Now safely index with idx
target = target[:, idx]
# Composite loss
# 1) Progress regression on z-scored values to match normalized progress labels y in [0,1]
# Debug: Check if values have enough variance
values_std = values.std()
if values_std < 1e-4:
# Early in training, model outputs are nearly constant
# Use direct MSE loss without z-scoring to encourage variance
import logging
logging.info(f"Low variance in values (std={values_std:.6f}), using direct MSE")
# Apply sigmoid directly to raw values to get them in [0,1] range
prog_pred = torch.sigmoid(values * 10.0) # Scale up to encourage learning
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
else:
# Normal case: use z-score normalization
zV = zscore(values, eps=self.config.zscore_eps)
# Check for NaN after zscore
if torch.isnan(zV).any():
import logging
logging.warning(f"NaN after zscore. Values: {values}, zV: {zV}")
# Fallback to direct sigmoid
prog_pred = torch.sigmoid(values * 10.0)
else:
prog_pred = torch.sigmoid(zV)
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
# Mismatched pairs: randomly shuffle language within batch and require near-zero progress
if self.training and torch.rand(()) < self.config.mismatch_lang_prob and values.size(0) > 1:
shuffled = torch.randperm(B, device=values.device)
lang_mismatch = lang_emb[shuffled]
# Forward pass with mismatched language
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
mismatch_V = self.head(mismatch_feat).squeeze(-1)
L_prog_mismatch = F.mse_loss(
torch.sigmoid(zscore(mismatch_V, eps=self.config.zscore_eps)), torch.zeros_like(target)
)
else:
L_prog_mismatch = torch.zeros((), device=values.device)
mismatch_values = self.head(mismatch_feat).squeeze(-1)
# 2) Spatial-Aware InfoNCE: Use language to attend to relevant spatial regions
# Take late timesteps' spatial features
k = min(self.config.last_k_for_nce, spatial_features.shape[1])
late_spatial = spatial_features[:, -k:].mean(dim=1) # (B, num_patches, D)
# Mismatched pairs should predict zero progress
L_mismatch = F.mse_loss(mismatch_values, torch.zeros_like(target))
# Language queries spatial patches via cross-attention
lang_query = lang_emb.unsqueeze(1) # (B, 1, D)
attended_spatial, spatial_attn_weights = self.spatial_cross_attn(
query=lang_query, key=late_spatial, value=late_spatial, need_weights=True
)
attended_spatial = self.spatial_norm(attended_spatial).squeeze(1) # (B, D)
# Contrastive loss with spatially-attended features
attended_spatial = F.normalize(attended_spatial, dim=-1)
lang_norm = F.normalize(lang_emb, dim=-1)
logits_spatial = (attended_spatial @ lang_norm.t()) / self.config.nce_temperature # (B, B)
targets_nce = torch.arange(B, device=values.device)
L_spatial_nce = F.cross_entropy(logits_spatial, targets_nce)
# 3) ReWiND Reversible Ranking: Learn from both forward and reversed trajectories
# This teaches the model what constitutes progress vs undoing progress
L_rank_forward, L_rank_reverse = reversible_ranking_loss(
values,
target,
margin=self.config.ranking_margin,
num_pairs=self.config.num_ranking_pairs,
min_gap=self.config.min_rank_gap,
)
L_rewind = L_rank_forward + L_rank_reverse
# Check for NaNs in individual loss components
if torch.isnan(L_prog):
import logging
logging.warning(f"NaN in L_prog. Values: {values}, Target: {target}")
# Return a small loss with gradients instead of zero
L_prog = values.mean() * 0.0 + 0.01
if torch.isnan(L_spatial_nce):
import logging
logging.warning("NaN in L_spatial_nce")
# Use a dummy loss that maintains gradients
L_spatial_nce = attended_spatial.mean() * 0.0 + 0.01
if torch.isnan(L_rewind):
import logging
logging.warning("NaN in L_rewind")
# Use values to maintain gradient flow
L_rewind = values.mean() * 0.0 + 0.01
loss = (
self.config.lambda_prog * (L_prog + L_prog_mismatch)
+ self.config.lambda_spatial_nce * L_spatial_nce
+ self.config.lambda_rewind * L_rewind
)
# Final NaN check
if torch.isnan(loss):
import logging
logging.warning("NaN loss detected, using fallback loss")
# Use a small loss that maintains gradients
loss = values.mean() * 0.0 + 0.01
# Total loss is just progress regression (rewinding is handled via data augmentation)
loss = L_progress + L_mismatch
# Log individual loss components
loss_dict.update(
{
"loss_prog": L_prog.item() if not torch.isnan(L_prog) else 0.0,
"loss_prog_mismatch": L_prog_mismatch.item() if not torch.isnan(L_prog_mismatch) else 0.0,
"loss_spatial_nce": L_spatial_nce.item() if not torch.isnan(L_spatial_nce) else 0.0,
"loss_rewind_forward": L_rank_forward.item() if not torch.isnan(L_rank_forward) else 0.0,
"loss_rewind_reverse": L_rank_reverse.item() if not torch.isnan(L_rank_reverse) else 0.0,
"loss_progress": L_progress.item(),
"loss_mismatch": L_mismatch.item(),
}
)
@@ -715,243 +718,75 @@ def encode_language(
return emb
def pairwise_ranking_loss(logits: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 32) -> Tensor:
# logits, target: (B, T)
B, T = logits.shape
if T < 2:
return logits.mean() * 0.0
# Sample pairs i<j and enforce r_j > r_i when target_j > target_i
losses = []
for _ in range(num_pairs):
i = torch.randint(0, T - 1, (B,), device=logits.device)
j = i + torch.randint(1, T - i.max(), (1,), device=logits.device)
j = j.expand_as(i)
li = logits[torch.arange(B), i]
lj = logits[torch.arange(B), j]
yi = target[torch.arange(B), i]
yj = target[torch.arange(B), j]
sign = torch.sign(yj - yi)
# hinge: max(0, margin - sign*(lj-li))
loss = F.relu(margin - sign * (lj - li))
losses.append(loss.mean())
return torch.stack(losses).mean()
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
"""Apply video rewinding augmentation as described in ReWiND paper.
def zscore(x: Tensor, eps: float = 1e-3) -> Tensor:
"""Z-score normalization with numerical stability.
Each video in the batch has an independent chance of being rewound.
Args:
x: Tensor of shape (B, T) where B is batch size, T is sequence length
eps: Small epsilon for numerical stability
frames: Tensor of shape (B, T, C, H, W)
rewind_prob: Probability of applying rewind augmentation to each video
Returns:
Z-scored tensor of same shape as input
Augmented frames and corresponding progress labels
"""
# Handle both (B,) and (B, T) shapes
if x.dim() == 1:
x = x.unsqueeze(1) # Make it (B, 1)
B, T, C, H, W = frames.shape
device = frames.device
B, T = x.shape
# Create default progress labels (linearly increasing from 0 to 1)
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
if T == 1:
# Single timestep: use tanh to bound values instead of z-score
return torch.tanh(x * 0.1)
# Apply rewind augmentation to each sample in batch independently
augmented_frames = []
augmented_progress = []
# Multiple timesteps: compute z-score across time dimension for each batch
mean = x.mean(dim=1, keepdim=True) # (B, 1)
std = x.std(dim=1, keepdim=True, unbiased=False) # (B, 1)
for b in range(B):
# Each video has independent chance of being rewound
should_rewind = torch.rand(1).item() < rewind_prob
# Check if std is valid (not zero or NaN)
std_is_valid = (std > eps) & (~torch.isnan(std))
if not should_rewind or T < 3:
# Keep original sequence
augmented_frames.append(frames[b])
augmented_progress.append(default_progress[b])
continue
# Safe std for division
std_safe = torch.where(std_is_valid, std, torch.ones_like(std))
# Apply rewinding to this video
# Split point i: between frame 2 and T-1
i = torch.randint(2, T, (1,)).item()
# Compute z-score where valid
z = (x - mean) / std_safe
# Rewind length k: between 1 and i-1 frames
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
# For invalid cases (constant values across time), use tanh of centered values
z_fallback = torch.tanh((x - mean) * 0.1)
z = torch.where(std_is_valid.expand_as(z), z, z_fallback)
# Create rewound sequence: o1...oi, oi-1, ..., oi-k
forward_frames = frames[b, :i] # Frames up to split point
reverse_frames = frames[b, max(0, i-k):i].flip(dims=[0]) # Reversed frames
# Final safety clamp
z = torch.clamp(z, min=-5.0, max=5.0)
# Concatenate forward and reverse parts
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
# Check for any remaining NaNs and replace with 0
z = torch.nan_to_num(z, nan=0.0)
# Pad with zeros if needed to maintain shape
if rewound_seq.shape[0] < T:
padding = torch.zeros(T - rewound_seq.shape[0], C, H, W, device=device)
rewound_seq = torch.cat([rewound_seq, padding], dim=0)
elif rewound_seq.shape[0] > T:
rewound_seq = rewound_seq[:T]
return z
# Create corresponding progress labels
# Forward part: increasing progress
forward_progress = torch.linspace(0, i/T, i, device=device)
# Reverse part: decreasing progress
reverse_progress = torch.linspace(i/T, max(0, (i-k)/T), k, device=device)
rewound_progress = torch.cat([forward_progress, reverse_progress])
def temporal_logistic_ranking(
values: Tensor, margin: float = 0.1, min_gap: int = 1, num_pairs: int = 64
) -> Tensor:
"""VLC-style temporal monotonicity: encourage V[j] > V[i] for j>i.
# Pad progress if needed
if rewound_progress.shape[0] < T:
padding = torch.zeros(T - rewound_progress.shape[0], device=device)
rewound_progress = torch.cat([rewound_progress, padding])
elif rewound_progress.shape[0] > T:
rewound_progress = rewound_progress[:T]
Samples pairs (i<j) with a minimum gap and applies softplus(m - (Vj - Vi)).
"""
B, T = values.shape
if T < 2:
return values.mean() * 0.0
losses = []
device = values.device
for _ in range(num_pairs):
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
j = i + torch.randint(min_gap, T - i.max(), (1,), device=device)
j = j.expand_as(i)
vi = values[torch.arange(B), i]
vj = values[torch.arange(B), j]
losses.append(F.softplus(margin - (vj - vi)).mean())
return torch.stack(losses).mean()
augmented_frames.append(rewound_seq)
augmented_progress.append(rewound_progress)
def reversible_ranking_loss(
values: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 64, min_gap: int = 1
) -> tuple[Tensor, Tensor]:
"""ReWiND-style reversible ranking: learn from both forward and reversed trajectories.
Key insight: If a trajectory shows progress forward, its reverse shows undoing progress.
By training on both, the model learns what constitutes progress vs regression.
Args:
values: (B, T) predicted values
target: (B, T) progress labels (0 to 1 for forward progress)
margin: Margin for ranking loss
num_pairs: Number of (far, near) pairs to sample
min_gap: Minimum temporal gap between pairs
Returns:
forward_loss: Loss from forward trajectory pairs
reverse_loss: Loss from reversed trajectory pairs
"""
B, T = values.shape
if T < 2:
zero_loss = values.mean() * 0.0
return zero_loss, zero_loss
device = values.device
# Forward trajectory ranking: later frames should have higher values
forward_losses = []
for _ in range(num_pairs // 2):
# Sample far-near pairs (far is earlier, near is later)
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
near_idx = near_idx.expand_as(far_idx)
v_far = values[torch.arange(B), far_idx]
v_near = values[torch.arange(B), near_idx]
# Near (later) should have higher value than far (earlier)
forward_losses.append(F.softplus(margin - (v_near - v_far)).mean())
# Reversed trajectory ranking: treat reversed sequence with inverted progress
# Reverse both values and targets
reversed_values = values.flip(dims=[1]) # Reverse time dimension
reversed_target = 1.0 - target.flip(dims=[1]) # Invert and reverse progress
reverse_losses = []
for _ in range(num_pairs // 2):
# In reversed trajectory, what was "later" is now "earlier"
far_idx = torch.randint(0, max(1, T - min_gap), (B,), device=device)
near_idx = far_idx + torch.randint(min_gap, T - far_idx.max(), (1,), device=device)
near_idx = near_idx.expand_as(far_idx)
v_far_rev = reversed_values[torch.arange(B), far_idx]
v_near_rev = reversed_values[torch.arange(B), near_idx]
# In reversed trajectory with inverted progress,
# near (which was originally earlier) should still have higher value
reverse_losses.append(F.softplus(margin - (v_near_rev - v_far_rev)).mean())
forward_loss = torch.stack(forward_losses).mean() if forward_losses else values.mean() * 0.0
reverse_loss = torch.stack(reverse_losses).mean() if reverse_losses else values.mean() * 0.0
return forward_loss, reverse_loss
def intra_trajectory_directional_ranking(
values: Tensor, progress: Tensor, margin: float = 0.2, num_pairs: int = 64, min_gap: int = 1
) -> Tensor:
"""Directional ranking within trajectory based on progress labels.
For pairs i<j within a trajectory:
- If progress increases (y_j > y_i), enforce V_j > V_i
- If progress decreases (y_j < y_i), enforce V_j < V_i
- Ignore pairs where progress is unchanged
Uses logistic loss: log(1 + exp(m - s_ij * (V_j - V_i)))
where s_ij = sign(y_j - y_i)
"""
B, T = values.shape
if T < 2:
return values.mean() * 0.0
losses = []
device = values.device
for _ in range(num_pairs):
# Sample time pairs i < j
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
max_j = min(T, i.max() + T - min_gap)
j = i + torch.randint(min_gap, max_j - i.min(), (1,), device=device)
j = j.expand_as(i).clamp(max=T - 1)
# Get values and progress at sampled times
vi = values[torch.arange(B), i]
vj = values[torch.arange(B), j]
yi = progress[torch.arange(B), i]
yj = progress[torch.arange(B), j]
# Compute direction sign
s_ij = torch.sign(yj - yi)
# Only compute loss for non-zero progress differences
mask = s_ij != 0
if mask.any():
diff = vj - vi
loss = torch.log1p(torch.exp(margin - s_ij * diff))
losses.append(loss[mask].mean())
return torch.stack(losses).mean() if losses else values.mean() * 0.0
def inter_instruction_contrastive_ranking(
    values_correct: Tensor, values_incorrect: Tensor, margin: float = 0.2
) -> Tensor:
    """Ranking between correct and incorrect instructions for same frames.

    Enforces V_t(z) > V_t(z') where z is correct instruction and z' is incorrect.
    Uses logistic loss: log(1 + exp(m - (V_t(z) - V_t(z'))))

    Args:
        values_correct: Values predicted with the correct instruction.
        values_incorrect: Values predicted with an incorrect instruction; must broadcast
            against ``values_correct``.
        margin: Margin ``m`` by which the correct value should exceed the incorrect one.

    Returns:
        Scalar loss: the mean logistic loss over all elements.
    """
    diff = values_correct - values_incorrect
    # softplus(x) == log(1 + exp(x)), evaluated without overflowing for large x.
    return F.softplus(margin - diff).mean()
def flatness_under_mismatch(values: Tensor, epsilon: float = 0.05, num_pairs: int = 32) -> Tensor:
    """Enforce flat values over time for mismatched instructions.

    For trajectory with wrong instruction, V should not change much over time.
    Penalises V_j - V_i for sampled pairs i < j with a Huber loss (delta = epsilon) towards zero.

    Args:
        values: Predicted values of shape (B, T).
        epsilon: Huber ``delta``; differences below it are penalised quadratically,
            larger ones linearly.
        num_pairs: Number of sampling rounds; each round draws one (i, j) pair per batch row.

    Returns:
        Scalar loss averaged over sampling rounds, or a zero that stays connected to
        ``values`` when T < 2.
    """
    B, T = values.shape
    if T < 2:
        return values.mean() * 0.0

    device = values.device
    rows = torch.arange(B, device=device)
    losses = []
    for _ in range(num_pairs):
        # Draw j strictly after i for every row, so no pair degenerates to i == j
        # (which would contribute a trivially zero difference).
        i = torch.randint(0, T - 1, (B,), device=device)
        room = T - 1 - i  # number of admissible j per row, always >= 1
        offset = (torch.rand(B, device=device) * room).long()
        # Guard against float rounding pushing the offset to `room` itself.
        offset = torch.minimum(offset, room - 1)
        j = i + 1 + offset

        # Huber loss with small delta for near-zero target
        diff = values[rows, j] - values[rows, i]
        losses.append(F.huber_loss(diff, torch.zeros_like(diff), delta=epsilon))

    return torch.stack(losses).mean()
return torch.stack(augmented_frames), torch.stack(augmented_progress)
+6 -3
View File
@@ -123,11 +123,14 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
- Visualize to check [x]
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visualization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
- Do first training [x]
- Try different losses []
- Only rewind loss then eval []
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
- Try different losses
- Only rewind loss [x]
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot
- Test only rewind loss (evaluate) []
- Check rewind implementation by hand []
- Only vlc loss then eval []
- Vlc + rewind loss then eval []
- Convert 1% of bc-z []
- Cleanup code []
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
- Add more artificial text to dataset generated by vlm (google gemini) []
+3
View File
@@ -136,9 +136,12 @@ def train(cfg: TrainPipelineConfig):
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Creating policy")
# Pass episode_data_index for RLearN to calculate proper progress
episode_data_index = dataset.episode_data_index if hasattr(dataset, 'episode_data_index') else None
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
episode_data_index=episode_data_index,
)
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
@@ -0,0 +1,59 @@
#!/usr/bin/env python
import torch
import numpy as np
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
from lerobot.policies.rlearn.evaluation import RLearnEvaluator
def test_temporal_evaluation():
    """Smoke-test per-frame reward prediction for a long and a short episode.

    Builds a small, randomly initialised RLearN policy, feeds random frames through the
    evaluator, and checks that the number of returned rewards equals the number of input
    frames, both when the episode exceeds ``max_seq_len`` and when it is shorter.
    """
    # Tiny architecture so the check stays quick.
    config = RLearNConfig(max_seq_len=4, dim_model=64, n_heads=2, n_layers=2)

    # Randomly initialised policy, switched to eval mode.
    model = RLearNPolicy(config)
    model.eval()
    evaluator = RLearnEvaluator(model, device="cpu")

    # Episode longer than the model window: 8 frames of 3x64x64 images.
    T, C, H, W = 8, 3, 64, 64
    frames = torch.randn(T, C, H, W)
    language = "test instruction"
    print(f"Input episode shape: {frames.shape}")
    print(f"Model expects sequences of length: {config.max_seq_len}")

    rewards = evaluator.predict_episode_rewards(frames, language, batch_size=4)
    print(f"Output rewards shape: {rewards.shape}")
    print(f"Rewards: {rewards}")

    # One reward is expected for every frame.
    n_rewards = len(rewards)
    assert n_rewards == T, f"Expected {T} rewards, got {n_rewards}"
    print("✅ Test passed! Evaluation correctly processes temporal sequences.")

    # Episode shorter than max_seq_len: only 2 frames.
    short_len = 2
    short_frames = torch.randn(short_len, C, H, W)
    short_rewards = evaluator.predict_episode_rewards(short_frames, language)
    print(f"\nShort episode shape: {short_frames.shape}")
    print(f"Short rewards shape: {short_rewards.shape}")

    n_short = len(short_rewards)
    assert n_short == short_len, f"Expected {short_len} rewards, got {n_short}"
    print("✅ Short episode test passed!")
# Allow running this check directly as a script, without a test runner.
if __name__ == "__main__":
    test_temporal_evaluation()