mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
use only rewind loss
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -33,6 +33,8 @@ class DatasetConfig:
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
# Percentage of dataset to use (0-100). If set, overrides episodes parameter.
|
||||
percentage: float | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
|
||||
@@ -87,10 +87,21 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
# Handle percentage parameter
|
||||
episodes = cfg.dataset.episodes
|
||||
if cfg.dataset.percentage is not None:
|
||||
# Calculate episodes based on percentage
|
||||
total_episodes = ds_meta.total_episodes
|
||||
num_episodes_to_use = max(1, int(total_episodes * cfg.dataset.percentage / 100))
|
||||
episodes = list(range(num_episodes_to_use))
|
||||
import logging
|
||||
logging.info(f"Using {cfg.dataset.percentage}% of dataset: {num_episodes_to_use}/{total_episodes} episodes")
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
episodes=episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
|
||||
@@ -20,6 +20,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .rlearn.configuration_rlearn import RLearNConfig as RLearNConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
@@ -28,4 +29,5 @@ __all__ = [
|
||||
"SmolVLAConfig",
|
||||
"TDMPCConfig",
|
||||
"VQBeTConfig",
|
||||
"RLearNConfig",
|
||||
]
|
||||
|
||||
@@ -244,6 +244,7 @@ def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
env_cfg: EnvConfig | None = None,
|
||||
episode_data_index: dict | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
"""Make an instance of a policy class.
|
||||
|
||||
@@ -301,6 +302,10 @@ def make_policy(
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass episode_data_index for RLearN policy to calculate proper progress
|
||||
if cfg.type == "rlearn" and episode_data_index is not None:
|
||||
kwargs["episode_data_index"] = episode_data_index
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
|
||||
@@ -61,31 +61,14 @@ class RLearNConfig(PreTrainedConfig):
|
||||
# Training
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.01
|
||||
loss_type: str = "composite" # Always use composite loss with spatial awareness
|
||||
ranking_margin: float = 0.1
|
||||
|
||||
# Composite loss weights (with spatial awareness and ReWiND reversibility)
|
||||
lambda_prog: float = 1.0 # Progress regression weight
|
||||
lambda_spatial_nce: float = 0.5 # Spatial-aware InfoNCE weight
|
||||
lambda_rewind: float = 0.4 # ReWiND reversible ranking weight
|
||||
# ReWiND-specific parameters
|
||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
||||
rewind_prob: float = 0.5 # Probability of applying rewind to each batch
|
||||
use_mismatch_loss: bool = True # Enable mismatched language-video loss
|
||||
|
||||
# Loss hyperparameters
|
||||
nce_temperature: float = 0.07 # Temperature for InfoNCE
|
||||
zscore_eps: float = 1e-5 # Epsilon for z-score normalization
|
||||
min_rank_gap: int = 1 # Minimum gap for temporal ranking pairs
|
||||
num_ranking_pairs: int = 64 # Number of (far, near) pairs to sample for ReWiND
|
||||
last_k_for_nce: int = 3 # Use last k frames for InfoNCE
|
||||
mismatch_lang_prob: float = 0.2 # Probability of language mismatch augmentation
|
||||
|
||||
# Value-based pairwise loss hyperparameters (for value_pairwise mode)
|
||||
lambda_dir: float = 1.0 # Intra-trajectory directional ranking
|
||||
lambda_text: float = 0.5 # Inter-instruction contrastive ranking
|
||||
lambda_flat: float = 0.25 # Flatness under mismatch
|
||||
dir_margin: float = 0.2 # Margin for directional ranking
|
||||
text_margin: float = 0.2 # Margin for text contrastive ranking
|
||||
flat_epsilon: float = 0.05 # Epsilon band for flatness loss
|
||||
num_pairs_per_loss: int = 64 # Number of pairs to sample per loss term
|
||||
use_hard_negatives: bool = True # Whether to generate hard negative instructions
|
||||
# Loss hyperparameters (simplified for ReWiND)
|
||||
# The main loss is just MSE between predicted and target progress
|
||||
|
||||
# Normalization presets
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
@@ -116,7 +99,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
# By default we supervise every provided timestep equally.
|
||||
# ReWiND generates progress labels on-the-fly, doesn't need reward data
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self): # type: ignore[override]
|
||||
|
||||
@@ -241,56 +241,70 @@ class RLearnEvaluator:
|
||||
@torch.no_grad()
|
||||
def predict_episode_rewards(self, frames: Tensor, language: str, batch_size: int = 16) -> np.ndarray:
|
||||
"""
|
||||
Predict rewards for a single episode.
|
||||
Predict rewards for a single episode using proper temporal sequences.
|
||||
|
||||
Args:
|
||||
frames: Video frames tensor of shape (T, C, H, W)
|
||||
language: Language instruction string
|
||||
batch_size: Maximum sequence length to process at once
|
||||
batch_size: Maximum number of temporal sequences to process at once
|
||||
|
||||
Returns:
|
||||
Predicted rewards array of shape (T,)
|
||||
"""
|
||||
T = frames.shape[0]
|
||||
max_seq_len = self.model.config.max_seq_len
|
||||
|
||||
# Preprocess frames to match model expectations
|
||||
processed_frames = self._preprocess_frames(frames)
|
||||
|
||||
# Process in chunks if episode is very long
|
||||
if T <= batch_size:
|
||||
# Single batch
|
||||
# Create temporal sequences for each frame
|
||||
# For frame i, we want frames [i-max_seq_len+1, ..., i-1, i]
|
||||
temporal_sequences = []
|
||||
|
||||
for i in range(T):
|
||||
# Create sequence ending at frame i
|
||||
seq_frames = []
|
||||
for j in range(max(0, i - max_seq_len + 1), i + 1):
|
||||
# Use frame j if available, otherwise repeat the first available frame
|
||||
frame_idx = max(0, min(j, T - 1))
|
||||
seq_frames.append(processed_frames[frame_idx])
|
||||
|
||||
# Pad sequence to max_seq_len by repeating the first frame if needed
|
||||
while len(seq_frames) < max_seq_len:
|
||||
seq_frames.insert(0, seq_frames[0]) # Prepend first frame
|
||||
|
||||
# Take only the last max_seq_len frames if we have too many
|
||||
seq_frames = seq_frames[-max_seq_len:]
|
||||
|
||||
temporal_sequences.append(torch.stack(seq_frames)) # (max_seq_len, C, H, W)
|
||||
|
||||
# Stack all temporal sequences: (T, max_seq_len, C, H, W)
|
||||
all_sequences = torch.stack(temporal_sequences)
|
||||
|
||||
# Process in batches
|
||||
rewards = []
|
||||
for i in range(0, T, batch_size):
|
||||
end_idx = min(i + batch_size, T)
|
||||
batch_sequences = all_sequences[i:end_idx].to(self.device) # (B, max_seq_len, C, H, W)
|
||||
|
||||
# Create batch for model
|
||||
batch = {
|
||||
OBS_IMAGES: processed_frames.unsqueeze(0).to(self.device), # (1, T, C, H, W)
|
||||
OBS_LANGUAGE: [language],
|
||||
OBS_IMAGES: batch_sequences, # (B, T, C, H, W) format expected by model
|
||||
OBS_LANGUAGE: [language] * batch_sequences.shape[0],
|
||||
}
|
||||
|
||||
# Use the new predict_rewards method
|
||||
values = self.model.predict_rewards(batch) # (1, T')
|
||||
rewards = values.squeeze(0).cpu().numpy() # (T',)
|
||||
# Predict rewards - model returns (B, T') but we want the last timestep for each sequence
|
||||
values = self.model.predict_rewards(batch) # (B, T')
|
||||
|
||||
else:
|
||||
# Process in overlapping chunks to handle very long episodes
|
||||
rewards = []
|
||||
stride = batch_size // 2 # 50% overlap
|
||||
# Take the last timestep prediction for each sequence (represents current frame reward)
|
||||
if values.dim() == 2:
|
||||
batch_rewards = values[:, -1].cpu().numpy() # (B,) - last timestep
|
||||
else:
|
||||
batch_rewards = values.cpu().numpy() # (B,) - already single timestep
|
||||
|
||||
for i in range(0, T, stride):
|
||||
end_idx = min(i + batch_size, T)
|
||||
chunk_frames = processed_frames[i:end_idx]
|
||||
rewards.extend(batch_rewards)
|
||||
|
||||
batch = {OBS_IMAGES: chunk_frames.unsqueeze(0).to(self.device), OBS_LANGUAGE: [language]}
|
||||
|
||||
chunk_values = self.model.predict_rewards(batch)
|
||||
chunk_rewards = chunk_values.squeeze(0).cpu().numpy()
|
||||
|
||||
# For overlapping chunks, only take the first half (except for the last chunk)
|
||||
if i + batch_size < T:
|
||||
rewards.extend(chunk_rewards[:stride])
|
||||
else:
|
||||
rewards.extend(chunk_rewards)
|
||||
|
||||
rewards = np.array(rewards[:T]) # Ensure exact length
|
||||
|
||||
return rewards
|
||||
return np.array(rewards[:T]) # Ensure exact length
|
||||
|
||||
def _preprocess_frames(self, frames: Tensor) -> Tensor:
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RLearN: Video-Language Conditioned Reward Model
|
||||
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
||||
|
||||
This implementation follows the ReWiND paper approach:
|
||||
- Automatically generates linear progress labels (0 to 1) for each episode
|
||||
- No need for pre-annotated rewards in the dataset
|
||||
- Applies video rewinding augmentation to create synthetic failure trajectories
|
||||
|
||||
Inputs
|
||||
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
|
||||
@@ -95,9 +100,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
config_class = RLearNConfig
|
||||
name = "rlearn"
|
||||
|
||||
def __init__(self, config: RLearNConfig):
|
||||
def __init__(self, config: RLearNConfig, episode_data_index: dict = None):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
||||
|
||||
# Encoders
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
@@ -161,15 +167,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
if config.use_tanh_head:
|
||||
head_layers.append(nn.Tanh())
|
||||
self.head = nn.Sequential(*head_layers)
|
||||
# Projection from scalar value summary to text embedding dim for InfoNCE
|
||||
self.value_to_text_proj = nn.Linear(1, config.dim_model)
|
||||
|
||||
# Spatial attention for InfoNCE
|
||||
self.spatial_cross_attn = nn.MultiheadAttention(
|
||||
embed_dim=config.dim_model, num_heads=config.n_heads, batch_first=True
|
||||
)
|
||||
self.spatial_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
self.stride = max(1, config.stride)
|
||||
@@ -274,12 +271,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
return batch
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Compute training loss and logs.
|
||||
"""Compute ReWiND training loss with on-the-fly progress label generation.
|
||||
|
||||
Expected batch keys:
|
||||
- OBS_IMAGES: list[Tensor] of shape [(B, C, H, W), ...] per time step or stacked (B, T, C, H, W)
|
||||
- OBS_LANGUAGE: optional string tokens already tokenized externally or raw strings
|
||||
- REWARD: (B, T) or (B,) target rewards
|
||||
|
||||
Note: Progress labels (0 to 1) are generated automatically for each episode.
|
||||
No REWARD key is needed in the batch.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
@@ -288,6 +287,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||
B, T, C, H, W = frames.shape
|
||||
|
||||
# Apply video rewinding augmentation during training
|
||||
if self.training and self.config.use_video_rewind:
|
||||
frames, augmented_target = apply_video_rewind(frames, rewind_prob=self.config.rewind_prob)
|
||||
# Use augmented progress labels if rewinding was applied
|
||||
if REWARD in batch:
|
||||
target = augmented_target
|
||||
|
||||
# Apply stride and frame dropout during training
|
||||
idx = torch.arange(0, T, self.stride, device=frames.device)
|
||||
if self.training and self.frame_dropout_p > 0.0 and T > 1:
|
||||
@@ -328,24 +334,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Encode through vision model
|
||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
||||
|
||||
# Extract BOTH CLS token and spatial patches
|
||||
# Extract CLS token for temporal modeling
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
all_tokens = vision_outputs.last_hidden_state # (BT, num_tokens, D)
|
||||
cls_tokens = all_tokens[:, 0] # (BT, D) - CLS token for temporal modeling
|
||||
spatial_tokens = all_tokens[:, 1:] # (BT, num_patches, D) - spatial patches
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
|
||||
else:
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state with spatial features")
|
||||
raise RuntimeError("Vision encoder must output last_hidden_state")
|
||||
|
||||
# Project CLS tokens for temporal sequence
|
||||
visual_seq = self.visual_proj(cls_tokens).reshape(B, -1, self.config.dim_model) # (B, T', D)
|
||||
|
||||
# Keep spatial features for spatial-aware losses (project them too)
|
||||
# Assuming 16x16 patches for 256x256 image with patch_size=16
|
||||
num_patches = spatial_tokens.shape[1]
|
||||
spatial_features = self.visual_proj(spatial_tokens).reshape(
|
||||
B, -1, num_patches, self.config.dim_model
|
||||
) # (B, T', num_patches, D)
|
||||
|
||||
# Add temporal positional encodings and optional first-frame bias
|
||||
pe = (
|
||||
self.positional_encoding[: visual_seq.shape[1]]
|
||||
@@ -362,150 +359,156 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
temporal_features = self.temporal(visual_seq, lang_emb, return_features=True) # (B, T', D)
|
||||
values = self.head(temporal_features).squeeze(-1) # (B, T')
|
||||
|
||||
# Targets
|
||||
target = batch.get(REWARD, None)
|
||||
# Generate progress labels on-the-fly (ReWiND approach)
|
||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||
loss_dict: dict[str, float] = {}
|
||||
if target is None:
|
||||
# If no labels, return zeros loss and logits for inference
|
||||
|
||||
# Check if video rewinding already set the target
|
||||
if self.training and self.config.use_video_rewind and 'augmented_target' in locals():
|
||||
# Use the augmented target from video rewinding
|
||||
target = augmented_target
|
||||
else:
|
||||
# Calculate true episode progress using episode_index and frame_index from batch
|
||||
if "episode_index" in batch and "frame_index" in batch and hasattr(self, 'episode_data_index'):
|
||||
# Get episode indices and frame indices from batch
|
||||
episode_indices = batch["episode_index"] # Shape: (B,)
|
||||
frame_indices = batch["frame_index"] # Shape: (B,)
|
||||
|
||||
# Calculate progress for the current frame in each sample
|
||||
progress_values = []
|
||||
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item()
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Progress from 0 to 1 within the episode
|
||||
# frame_index is relative to the episode (0-based within episode)
|
||||
progress = frame_idx / max(1, ep_length - 1)
|
||||
progress_values.append(progress)
|
||||
|
||||
# Create progress tensor for the current frame (last in temporal sequence)
|
||||
current_progress = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
|
||||
|
||||
# Now calculate progress for ALL frames in the temporal window
|
||||
# The observation_delta_indices tell us which frames we're looking at
|
||||
delta_indices = self.config.observation_delta_indices # e.g., [-15, -14, ..., 0]
|
||||
|
||||
# Calculate progress for each frame in the temporal window
|
||||
all_progress = []
|
||||
for delta in delta_indices:
|
||||
# For each sample, calculate the progress of the frame at delta offset
|
||||
frame_progress = []
|
||||
for b_idx in range(B):
|
||||
ep_idx = episode_indices[b_idx].item()
|
||||
frame_idx = frame_indices[b_idx].item()
|
||||
|
||||
# Calculate the actual frame index with delta
|
||||
target_frame_idx = frame_idx + delta
|
||||
|
||||
# Get episode boundaries
|
||||
ep_start = self.episode_data_index["from"][ep_idx].item()
|
||||
ep_end = self.episode_data_index["to"][ep_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Clamp to episode boundaries (frame_index is relative to episode)
|
||||
target_frame_idx = max(0, min(ep_length - 1, target_frame_idx))
|
||||
|
||||
# Calculate progress for this frame
|
||||
prog = target_frame_idx / max(1, ep_length - 1)
|
||||
frame_progress.append(prog)
|
||||
|
||||
all_progress.append(torch.tensor(frame_progress, device=values.device, dtype=values.dtype))
|
||||
|
||||
# Stack to get (B, T) tensor where T is the temporal sequence length
|
||||
target = torch.stack(all_progress, dim=1) # (B, max_seq_len)
|
||||
|
||||
# Apply stride/dropout indexing to match the processed frames
|
||||
target = target[:, idx]
|
||||
|
||||
elif "index" in batch and hasattr(self, 'episode_data_index'):
|
||||
# Fallback: Use global index if available
|
||||
global_indices = batch["index"] # Shape: (B,)
|
||||
|
||||
# For each index, find which episode it belongs to and its position
|
||||
progress_values = []
|
||||
|
||||
for global_idx in global_indices:
|
||||
# Find which episode this index belongs to
|
||||
episode_starts = self.episode_data_index["from"]
|
||||
episode_ends = self.episode_data_index["to"]
|
||||
|
||||
# Find the episode by checking which range the index falls into
|
||||
episode_idx = None
|
||||
frame_in_episode = None
|
||||
for ep_idx in range(len(episode_starts)):
|
||||
if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]:
|
||||
episode_idx = ep_idx
|
||||
frame_in_episode = global_idx.item() - episode_starts[ep_idx].item()
|
||||
break
|
||||
|
||||
if episode_idx is not None:
|
||||
# Calculate position within episode
|
||||
ep_start = episode_starts[episode_idx].item()
|
||||
ep_end = episode_ends[episode_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Progress from 0 to 1 within the episode
|
||||
progress = frame_in_episode / max(1, ep_length - 1)
|
||||
else:
|
||||
# Fallback if we can't find the episode (shouldn't happen)
|
||||
progress = 0.5
|
||||
|
||||
progress_values.append(progress)
|
||||
|
||||
# For temporal window, use simplified linear progress
|
||||
# (proper calculation would need all frame indices in the window)
|
||||
T_effective = len(idx)
|
||||
target = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
|
||||
target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion
|
||||
|
||||
else:
|
||||
raise ValueError("No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present.")
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
loss = values.mean() * 0.0
|
||||
loss_dict["has_labels"] = 0.0
|
||||
return loss, {**loss_dict, "values_mean": values.mean().item()}
|
||||
|
||||
# Align target with sampled timesteps
|
||||
if target.dim() == 1:
|
||||
target = target.unsqueeze(1) # (B, 1)
|
||||
# ReWiND Loss (following the paper exactly)
|
||||
# The core loss is progress regression with video rewinding augmentation
|
||||
|
||||
# Handle target padding to match frame sequence if needed
|
||||
if target.shape[1] < self.config.max_seq_len:
|
||||
# Pad targets by repeating the first value (assuming it's the earliest)
|
||||
padding_needed = self.config.max_seq_len - target.shape[1]
|
||||
first_target = target[:, :1] # (B, 1)
|
||||
padding = first_target.expand(target.shape[0], padding_needed)
|
||||
target = torch.cat([padding, target], dim=1) # Prepend padding
|
||||
# 1) Main progress regression loss for matched sequences
|
||||
# Target should be normalized progress from 0 to 1 (t/T)
|
||||
L_progress = F.mse_loss(values, target)
|
||||
|
||||
import logging
|
||||
# 2) Mismatched video-language pairs should predict zero progress
|
||||
L_mismatch = torch.zeros((), device=values.device)
|
||||
if self.training and self.config.use_mismatch_loss and values.size(0) > 1:
|
||||
# Randomly shuffle language instructions within the batch
|
||||
shuffled_indices = torch.randperm(B, device=values.device)
|
||||
lang_mismatch = lang_emb[shuffled_indices]
|
||||
|
||||
logging.debug(
|
||||
f"Padded targets from {target.shape[1] - padding_needed} to {self.config.max_seq_len}"
|
||||
)
|
||||
|
||||
# Now safely index with idx
|
||||
target = target[:, idx]
|
||||
|
||||
# Composite loss
|
||||
# 1) Progress regression on z-scored values to match normalized progress labels y in [0,1]
|
||||
|
||||
# Debug: Check if values have enough variance
|
||||
values_std = values.std()
|
||||
if values_std < 1e-4:
|
||||
# Early in training, model outputs are nearly constant
|
||||
# Use direct MSE loss without z-scoring to encourage variance
|
||||
import logging
|
||||
|
||||
logging.info(f"Low variance in values (std={values_std:.6f}), using direct MSE")
|
||||
# Apply sigmoid directly to raw values to get them in [0,1] range
|
||||
prog_pred = torch.sigmoid(values * 10.0) # Scale up to encourage learning
|
||||
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
|
||||
else:
|
||||
# Normal case: use z-score normalization
|
||||
zV = zscore(values, eps=self.config.zscore_eps)
|
||||
# Check for NaN after zscore
|
||||
if torch.isnan(zV).any():
|
||||
import logging
|
||||
|
||||
logging.warning(f"NaN after zscore. Values: {values}, zV: {zV}")
|
||||
# Fallback to direct sigmoid
|
||||
prog_pred = torch.sigmoid(values * 10.0)
|
||||
else:
|
||||
prog_pred = torch.sigmoid(zV)
|
||||
|
||||
L_prog = F.mse_loss(prog_pred, torch.clamp(target, 0.0, 1.0))
|
||||
|
||||
# Mismatched pairs: randomly shuffle language within batch and require near-zero progress
|
||||
if self.training and torch.rand(()) < self.config.mismatch_lang_prob and values.size(0) > 1:
|
||||
shuffled = torch.randperm(B, device=values.device)
|
||||
lang_mismatch = lang_emb[shuffled]
|
||||
# Forward pass with mismatched language
|
||||
mismatch_feat = self.temporal(visual_seq, lang_mismatch, return_features=True)
|
||||
mismatch_V = self.head(mismatch_feat).squeeze(-1)
|
||||
L_prog_mismatch = F.mse_loss(
|
||||
torch.sigmoid(zscore(mismatch_V, eps=self.config.zscore_eps)), torch.zeros_like(target)
|
||||
)
|
||||
else:
|
||||
L_prog_mismatch = torch.zeros((), device=values.device)
|
||||
mismatch_values = self.head(mismatch_feat).squeeze(-1)
|
||||
|
||||
# 2) Spatial-Aware InfoNCE: Use language to attend to relevant spatial regions
|
||||
# Take late timesteps' spatial features
|
||||
k = min(self.config.last_k_for_nce, spatial_features.shape[1])
|
||||
late_spatial = spatial_features[:, -k:].mean(dim=1) # (B, num_patches, D)
|
||||
# Mismatched pairs should predict zero progress
|
||||
L_mismatch = F.mse_loss(mismatch_values, torch.zeros_like(target))
|
||||
|
||||
# Language queries spatial patches via cross-attention
|
||||
lang_query = lang_emb.unsqueeze(1) # (B, 1, D)
|
||||
attended_spatial, spatial_attn_weights = self.spatial_cross_attn(
|
||||
query=lang_query, key=late_spatial, value=late_spatial, need_weights=True
|
||||
)
|
||||
attended_spatial = self.spatial_norm(attended_spatial).squeeze(1) # (B, D)
|
||||
|
||||
# Contrastive loss with spatially-attended features
|
||||
attended_spatial = F.normalize(attended_spatial, dim=-1)
|
||||
lang_norm = F.normalize(lang_emb, dim=-1)
|
||||
logits_spatial = (attended_spatial @ lang_norm.t()) / self.config.nce_temperature # (B, B)
|
||||
targets_nce = torch.arange(B, device=values.device)
|
||||
L_spatial_nce = F.cross_entropy(logits_spatial, targets_nce)
|
||||
|
||||
# 3) ReWiND Reversible Ranking: Learn from both forward and reversed trajectories
|
||||
# This teaches the model what constitutes progress vs undoing progress
|
||||
L_rank_forward, L_rank_reverse = reversible_ranking_loss(
|
||||
values,
|
||||
target,
|
||||
margin=self.config.ranking_margin,
|
||||
num_pairs=self.config.num_ranking_pairs,
|
||||
min_gap=self.config.min_rank_gap,
|
||||
)
|
||||
L_rewind = L_rank_forward + L_rank_reverse
|
||||
|
||||
# Check for NaNs in individual loss components
|
||||
if torch.isnan(L_prog):
|
||||
import logging
|
||||
|
||||
logging.warning(f"NaN in L_prog. Values: {values}, Target: {target}")
|
||||
# Return a small loss with gradients instead of zero
|
||||
L_prog = values.mean() * 0.0 + 0.01
|
||||
|
||||
if torch.isnan(L_spatial_nce):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN in L_spatial_nce")
|
||||
# Use a dummy loss that maintains gradients
|
||||
L_spatial_nce = attended_spatial.mean() * 0.0 + 0.01
|
||||
|
||||
if torch.isnan(L_rewind):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN in L_rewind")
|
||||
# Use values to maintain gradient flow
|
||||
L_rewind = values.mean() * 0.0 + 0.01
|
||||
|
||||
loss = (
|
||||
self.config.lambda_prog * (L_prog + L_prog_mismatch)
|
||||
+ self.config.lambda_spatial_nce * L_spatial_nce
|
||||
+ self.config.lambda_rewind * L_rewind
|
||||
)
|
||||
|
||||
# Final NaN check
|
||||
if torch.isnan(loss):
|
||||
import logging
|
||||
|
||||
logging.warning("NaN loss detected, using fallback loss")
|
||||
# Use a small loss that maintains gradients
|
||||
loss = values.mean() * 0.0 + 0.01
|
||||
# Total loss is just progress regression (rewinding is handled via data augmentation)
|
||||
loss = L_progress + L_mismatch
|
||||
|
||||
# Log individual loss components
|
||||
loss_dict.update(
|
||||
{
|
||||
"loss_prog": L_prog.item() if not torch.isnan(L_prog) else 0.0,
|
||||
"loss_prog_mismatch": L_prog_mismatch.item() if not torch.isnan(L_prog_mismatch) else 0.0,
|
||||
"loss_spatial_nce": L_spatial_nce.item() if not torch.isnan(L_spatial_nce) else 0.0,
|
||||
"loss_rewind_forward": L_rank_forward.item() if not torch.isnan(L_rank_forward) else 0.0,
|
||||
"loss_rewind_reverse": L_rank_reverse.item() if not torch.isnan(L_rank_reverse) else 0.0,
|
||||
"loss_progress": L_progress.item(),
|
||||
"loss_mismatch": L_mismatch.item(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -715,243 +718,75 @@ def encode_language(
|
||||
return emb
|
||||
|
||||
|
||||
def pairwise_ranking_loss(logits: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 32) -> Tensor:
|
||||
# logits, target: (B, T)
|
||||
B, T = logits.shape
|
||||
if T < 2:
|
||||
return logits.mean() * 0.0
|
||||
# Sample pairs i<j and enforce r_j > r_i when target_j > target_i
|
||||
losses = []
|
||||
for _ in range(num_pairs):
|
||||
i = torch.randint(0, T - 1, (B,), device=logits.device)
|
||||
j = i + torch.randint(1, T - i.max(), (1,), device=logits.device)
|
||||
j = j.expand_as(i)
|
||||
li = logits[torch.arange(B), i]
|
||||
lj = logits[torch.arange(B), j]
|
||||
yi = target[torch.arange(B), i]
|
||||
yj = target[torch.arange(B), j]
|
||||
sign = torch.sign(yj - yi)
|
||||
# hinge: max(0, margin - sign*(lj-li))
|
||||
loss = F.relu(margin - sign * (lj - li))
|
||||
losses.append(loss.mean())
|
||||
return torch.stack(losses).mean()
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
|
||||
"""Apply video rewinding augmentation as described in ReWiND paper.
|
||||
|
||||
|
||||
def zscore(x: Tensor, eps: float = 1e-3) -> Tensor:
|
||||
"""Z-score normalization with numerical stability.
|
||||
Each video in the batch has an independent chance of being rewound.
|
||||
|
||||
Args:
|
||||
x: Tensor of shape (B, T) where B is batch size, T is sequence length
|
||||
eps: Small epsilon for numerical stability
|
||||
frames: Tensor of shape (B, T, C, H, W)
|
||||
rewind_prob: Probability of applying rewind augmentation to each video
|
||||
|
||||
Returns:
|
||||
Z-scored tensor of same shape as input
|
||||
Augmented frames and corresponding progress labels
|
||||
"""
|
||||
# Handle both (B,) and (B, T) shapes
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(1) # Make it (B, 1)
|
||||
B, T, C, H, W = frames.shape
|
||||
device = frames.device
|
||||
|
||||
B, T = x.shape
|
||||
# Create default progress labels (linearly increasing from 0 to 1)
|
||||
default_progress = torch.linspace(0, 1, T, device=device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
if T == 1:
|
||||
# Single timestep: use tanh to bound values instead of z-score
|
||||
return torch.tanh(x * 0.1)
|
||||
# Apply rewind augmentation to each sample in batch independently
|
||||
augmented_frames = []
|
||||
augmented_progress = []
|
||||
|
||||
# Multiple timesteps: compute z-score across time dimension for each batch
|
||||
mean = x.mean(dim=1, keepdim=True) # (B, 1)
|
||||
std = x.std(dim=1, keepdim=True, unbiased=False) # (B, 1)
|
||||
for b in range(B):
|
||||
# Each video has independent chance of being rewound
|
||||
should_rewind = torch.rand(1).item() < rewind_prob
|
||||
|
||||
# Check if std is valid (not zero or NaN)
|
||||
std_is_valid = (std > eps) & (~torch.isnan(std))
|
||||
if not should_rewind or T < 3:
|
||||
# Keep original sequence
|
||||
augmented_frames.append(frames[b])
|
||||
augmented_progress.append(default_progress[b])
|
||||
continue
|
||||
|
||||
# Safe std for division
|
||||
std_safe = torch.where(std_is_valid, std, torch.ones_like(std))
|
||||
# Apply rewinding to this video
|
||||
# Split point i: between frame 2 and T-1
|
||||
i = torch.randint(2, T, (1,)).item()
|
||||
|
||||
# Compute z-score where valid
|
||||
z = (x - mean) / std_safe
|
||||
# Rewind length k: between 1 and i-1 frames
|
||||
k = torch.randint(1, min(i, T - i + 1), (1,)).item()
|
||||
|
||||
# For invalid cases (constant values across time), use tanh of centered values
|
||||
z_fallback = torch.tanh((x - mean) * 0.1)
|
||||
z = torch.where(std_is_valid.expand_as(z), z, z_fallback)
|
||||
# Create rewound sequence: o1...oi, oi-1, ..., oi-k
|
||||
forward_frames = frames[b, :i] # Frames up to split point
|
||||
reverse_frames = frames[b, max(0, i-k):i].flip(dims=[0]) # Reversed frames
|
||||
|
||||
# Final safety clamp
|
||||
z = torch.clamp(z, min=-5.0, max=5.0)
|
||||
# Concatenate forward and reverse parts
|
||||
rewound_seq = torch.cat([forward_frames, reverse_frames], dim=0)
|
||||
|
||||
# Check for any remaining NaNs and replace with 0
|
||||
z = torch.nan_to_num(z, nan=0.0)
|
||||
# Pad with zeros if needed to maintain shape
|
||||
if rewound_seq.shape[0] < T:
|
||||
padding = torch.zeros(T - rewound_seq.shape[0], C, H, W, device=device)
|
||||
rewound_seq = torch.cat([rewound_seq, padding], dim=0)
|
||||
elif rewound_seq.shape[0] > T:
|
||||
rewound_seq = rewound_seq[:T]
|
||||
|
||||
return z
|
||||
# Create corresponding progress labels
|
||||
# Forward part: increasing progress
|
||||
forward_progress = torch.linspace(0, i/T, i, device=device)
|
||||
# Reverse part: decreasing progress
|
||||
reverse_progress = torch.linspace(i/T, max(0, (i-k)/T), k, device=device)
|
||||
|
||||
rewound_progress = torch.cat([forward_progress, reverse_progress])
|
||||
|
||||
def temporal_logistic_ranking(
|
||||
values: Tensor, margin: float = 0.1, min_gap: int = 1, num_pairs: int = 64
|
||||
) -> Tensor:
|
||||
"""VLC-style temporal monotonicity: encourage V[j] > V[i] for j>i.
|
||||
# Pad progress if needed
|
||||
if rewound_progress.shape[0] < T:
|
||||
padding = torch.zeros(T - rewound_progress.shape[0], device=device)
|
||||
rewound_progress = torch.cat([rewound_progress, padding])
|
||||
elif rewound_progress.shape[0] > T:
|
||||
rewound_progress = rewound_progress[:T]
|
||||
|
||||
Samples pairs (i<j) with a minimum gap and applies softplus(m - (Vj - Vi)).
|
||||
"""
|
||||
B, T = values.shape
|
||||
if T < 2:
|
||||
return values.mean() * 0.0
|
||||
losses = []
|
||||
device = values.device
|
||||
for _ in range(num_pairs):
|
||||
i = torch.randint(0, max(1, T - min_gap), (B,), device=device)
|
||||
j = i + torch.randint(min_gap, T - i.max(), (1,), device=device)
|
||||
j = j.expand_as(i)
|
||||
vi = values[torch.arange(B), i]
|
||||
vj = values[torch.arange(B), j]
|
||||
losses.append(F.softplus(margin - (vj - vi)).mean())
|
||||
return torch.stack(losses).mean()
|
||||
augmented_frames.append(rewound_seq)
|
||||
augmented_progress.append(rewound_progress)
|
||||
|
||||
|
||||
def reversible_ranking_loss(
    values: Tensor, target: Tensor, margin: float = 0.1, num_pairs: int = 64, min_gap: int = 1
) -> tuple[Tensor, Tensor]:
    """ReWiND-style reversible ranking over forward and time-reversed value sequences.

    Forward term: for sampled index pairs (far < near) the later value must exceed
    the earlier one by ``margin``. Reverse term: the value sequence is flipped along
    time and must *decrease* over the flipped index, i.e. playing the trajectory
    backwards undoes progress. Both terms use ``softplus(margin - gap)``.

    Args:
        values: (B, T) predicted values.
        target: (B, T) progress labels. Accepted for interface compatibility; the
            pairwise ordering is derived from time indices only, so it is not read.
        margin: Margin for the ranking loss.
        num_pairs: Total number of sampled pairs; half go to each direction.
        min_gap: Minimum temporal gap between the two indices of a pair.

    Returns:
        forward_loss: Loss from forward trajectory pairs.
        reverse_loss: Loss from reversed trajectory pairs.
        Both are zero (but still attached to ``values``) when no valid pair exists.
    """
    B, T = values.shape
    zero_loss = values.mean() * 0.0
    # No pair with the requested gap fits in the sequence: nothing to rank.
    if T < 2 or T <= min_gap:
        return zero_loss, zero_loss

    device = values.device
    rows = torch.arange(B, device=device)
    pairs_per_direction = num_pairs // 2

    def _sample_pair() -> tuple[Tensor, Tensor]:
        # Per-row start index, then one shared gap small enough that every row's
        # second index stays inside [0, T).
        far_idx = torch.randint(0, T - min_gap, (B,), device=device)
        gap = torch.randint(min_gap, T - int(far_idx.max()), (1,), device=device)
        return far_idx, (far_idx + gap).expand_as(far_idx)

    # Forward trajectory: later (near) frames should score higher than earlier (far) ones.
    forward_losses = []
    for _ in range(pairs_per_direction):
        far_idx, near_idx = _sample_pair()
        v_far = values[rows, far_idx]
        v_near = values[rows, near_idx]
        forward_losses.append(F.softplus(margin - (v_near - v_far)).mean())

    # Reversed trajectory: index k of the flipped sequence is original frame T-1-k,
    # so a larger flipped index means *less* progress and must score lower. Asking
    # for the opposite would directly contradict the forward term.
    reversed_values = values.flip(dims=[1])
    reverse_losses = []
    for _ in range(pairs_per_direction):
        far_idx, near_idx = _sample_pair()
        v_far_rev = reversed_values[rows, far_idx]
        v_near_rev = reversed_values[rows, near_idx]
        reverse_losses.append(F.softplus(margin - (v_far_rev - v_near_rev)).mean())

    forward_loss = torch.stack(forward_losses).mean() if forward_losses else zero_loss
    reverse_loss = torch.stack(reverse_losses).mean() if reverse_losses else zero_loss

    return forward_loss, reverse_loss
|
||||
|
||||
|
||||
def intra_trajectory_directional_ranking(
    values: Tensor, progress: Tensor, margin: float = 0.2, num_pairs: int = 64, min_gap: int = 1
) -> Tensor:
    """Directional ranking within a trajectory, driven by the progress labels.

    For sampled pairs i < j (at least ``min_gap`` apart) within each row:
    - if progress increases (y_j > y_i), enforce V_j > V_i
    - if progress decreases (y_j < y_i), enforce V_j < V_i
    - pairs with unchanged progress are ignored

    The per-pair loss is the logistic loss ``log(1 + exp(m - s_ij * (V_j - V_i)))``
    with ``s_ij = sign(y_j - y_i)``, evaluated through ``F.softplus`` so that large
    violations do not overflow to ``inf``.

    Args:
        values: (B, T) predicted values.
        progress: (B, T) progress labels, indexed the same way as ``values``.
        margin: Margin ``m`` of the logistic loss.
        num_pairs: Number of sampling rounds; each round draws one pair per row.
        min_gap: Minimum temporal distance between ``i`` and ``j``.

    Returns:
        Scalar loss; zero (attached to ``values``) when no informative pair exists.
    """
    B, T = values.shape
    # No pair with the requested gap fits in the sequence.
    if T < 2 or T <= min_gap:
        return values.mean() * 0.0

    losses = []
    device = values.device
    rows = torch.arange(B, device=device)

    for _ in range(num_pairs):
        # Sample i, then j uniformly in [i + min_gap, T - 1] independently per row.
        i = torch.randint(0, T - min_gap, (B,), device=device)
        span = (T - min_gap - i).to(torch.float32)  # number of valid j for each row, >= 1
        offset = (torch.rand(B, device=device) * span).long()
        # The clamp only guards against float rounding pushing offset up to span.
        j = (i + min_gap + offset).clamp(max=T - 1)

        # Values and progress at the sampled times.
        vi = values[rows, i]
        vj = values[rows, j]
        yi = progress[rows, i]
        yj = progress[rows, j]

        # Direction of progress between the two frames.
        s_ij = torch.sign(yj - yi)

        # Only pairs whose progress actually changed carry a ranking signal.
        mask = s_ij != 0
        if mask.any():
            loss = F.softplus(margin - s_ij * (vj - vi))
            losses.append(loss[mask].mean())

    return torch.stack(losses).mean() if losses else values.mean() * 0.0
|
||||
|
||||
|
||||
def inter_instruction_contrastive_ranking(
    values_correct: Tensor, values_incorrect: Tensor, margin: float = 0.2
) -> Tensor:
    """Rank values under the correct instruction above values under an incorrect one.

    Enforces V_t(z) > V_t(z') for the same frames, where z is the correct
    instruction and z' an incorrect one, using the logistic loss
    ``log(1 + exp(m - (V_t(z) - V_t(z'))))``. The loss is evaluated with
    ``F.softplus`` so that large violations stay finite instead of overflowing.

    Args:
        values_correct: Values predicted with the correct instruction.
        values_incorrect: Values predicted with the incorrect instruction; must
            broadcast against ``values_correct``.
        margin: Margin ``m`` of the logistic loss.

    Returns:
        Scalar mean loss over all elements.
    """
    diff = values_correct - values_incorrect
    return F.softplus(margin - diff).mean()
|
||||
|
||||
|
||||
def flatness_under_mismatch(values: Tensor, epsilon: float = 0.05, num_pairs: int = 32) -> Tensor:
    """Penalise value changes over time for mismatched instructions.

    For a trajectory paired with the wrong instruction, V should stay roughly
    constant. Pairs of distinct time steps (i < j) are sampled per row and the
    difference ``V_j - V_i`` is pulled towards zero with a Huber loss whose
    ``delta`` is ``epsilon``: quadratic for differences up to ``epsilon``, linear
    beyond it.

    Args:
        values: (B, T) predicted values.
        epsilon: Huber ``delta``; the transition between the quadratic and linear regime.
        num_pairs: Number of sampling rounds; each round draws one pair per row.

    Returns:
        Scalar loss; zero (attached to ``values``) when T < 2 or ``num_pairs`` <= 0.
    """
    B, T = values.shape
    if T < 2 or num_pairs <= 0:
        return values.mean() * 0.0

    losses = []
    device = values.device
    rows = torch.arange(B, device=device)

    for _ in range(num_pairs):
        # Draw j strictly after i for every row, so no sample is wasted on a
        # self-pair (j == i) whose difference is trivially zero.
        i = torch.randint(0, T - 1, (B,), device=device)
        span = (T - 1 - i).to(torch.float32)  # number of valid j for each row, >= 1
        offset = (torch.rand(B, device=device) * span).long()
        # The clamp only guards against float rounding pushing offset up to span.
        j = (i + 1 + offset).clamp(max=T - 1)

        diff = values[rows, j] - values[rows, i]
        losses.append(F.huber_loss(diff, torch.zeros_like(diff), delta=epsilon))

    return torch.stack(losses).mean()
|
||||
return torch.stack(augmented_frames), torch.stack(augmented_progress)
|
||||
|
||||
@@ -123,11 +123,14 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
|
||||
- Visualize to check [x]
|
||||
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visualization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
|
||||
- Do first training [x]
|
||||
- Try different losses []
|
||||
- Only rewind loss then eval []
|
||||
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
||||
- Try different losses
|
||||
- Only rewind loss [x]
|
||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot
|
||||
- Test only rewind loss (evaluate) []
|
||||
- Check rewind implementation by hand []
|
||||
- Only vlc loss then eval []
|
||||
- Vlc + rewind loss then eval []
|
||||
- Convert 1% of bc-z []
|
||||
- Cleanup code []
|
||||
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
|
||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||
|
||||
@@ -136,9 +136,12 @@ def train(cfg: TrainPipelineConfig):
|
||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
|
||||
logging.info("Creating policy")
|
||||
# Pass episode_data_index for RLearN to calculate proper progress
|
||||
episode_data_index = dataset.episode_data_index if hasattr(dataset, 'episode_data_index') else None
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
episode_data_index=episode_data_index,
|
||||
)
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||||
from lerobot.policies.rlearn.evaluation import RLearnEvaluator
|
||||
|
||||
|
||||
def test_temporal_evaluation():
    """Check that the evaluator returns one reward per frame of an episode.

    Runs a small, randomly initialised RLearN policy over an episode longer than
    ``max_seq_len`` and over one shorter than it, and checks the number of
    returned rewards in both cases.
    """
    # Deliberately tiny model so the check stays fast.
    cfg = RLearNConfig(max_seq_len=4, dim_model=64, n_heads=2, n_layers=2)
    policy = RLearNPolicy(cfg)
    policy.eval()
    evaluator = RLearnEvaluator(policy, device="cpu")

    # Episode of 8 random 3x64x64 frames, i.e. twice max_seq_len.
    num_frames, channels, height, width = 8, 3, 64, 64
    instruction = "test instruction"
    episode = torch.randn(num_frames, channels, height, width)

    print(f"Input episode shape: {episode.shape}")
    print(f"Model expects sequences of length: {cfg.max_seq_len}")

    episode_rewards = evaluator.predict_episode_rewards(episode, instruction, batch_size=4)

    print(f"Output rewards shape: {episode_rewards.shape}")
    print(f"Rewards: {episode_rewards}")

    # One reward per input frame.
    n_rewards = len(episode_rewards)
    assert n_rewards == num_frames, f"Expected {num_frames} rewards, got {n_rewards}"

    print("✅ Test passed! Evaluation correctly processes temporal sequences.")

    # Episode with only 2 frames, fewer than max_seq_len.
    short_episode = torch.randn(2, channels, height, width)
    short_episode_rewards = evaluator.predict_episode_rewards(short_episode, instruction)

    print(f"\nShort episode shape: {short_episode.shape}")
    print(f"Short rewards shape: {short_episode_rewards.shape}")
    n_short = len(short_episode_rewards)
    assert n_short == 2, f"Expected 2 rewards, got {n_short}"

    print("✅ Short episode test passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_temporal_evaluation()
|
||||
Reference in New Issue
Block a user