This commit is contained in:
Pepijn
2025-08-28 19:23:17 +02:00
parent bead25a58a
commit cc05067a76
6 changed files with 137 additions and 207 deletions
@@ -33,12 +33,14 @@ class RLearNConfig(PreTrainedConfig):
- Per-timestep reward logits or a single-step reward logit. - Per-timestep reward logits or a single-step reward logit.
Notes: Notes:
- This is the initial architecture. It uses frozen vision/text encoders - This follows the ReWiND paper architecture. It uses frozen vision/text encoders
(e.g. SigLIP2) and trains a lightweight temporal aggregator + head. (DINO v3 for vision, sentence-transformers for language) and trains a
lightweight temporal aggregator + head.
""" """
# Encoders # Encoders - Using DINOv2 (base) for vision and sentence-transformers for text (ReWiND paper)
model_name: str = "google/siglip2-base-patch16-256" vision_model_name: str = "facebook/dinov2-base"
text_model_name: str = "sentence-transformers/all-MiniLM-L12-v2"
freeze_backbones: bool = True freeze_backbones: bool = True
# Temporal aggregator # Temporal aggregator
+8 -8
View File
@@ -323,8 +323,8 @@ class RLearnEvaluator:
T, C, H, W = frames.shape T, C, H, W = frames.shape
# Expected input size for SigLIP2 is typically 256x256 # Expected input size for DINO v3 is 224x224
target_size = 256 target_size = 224
# Resize frames if needed # Resize frames if needed
if H != target_size or W != target_size: if H != target_size or W != target_size:
@@ -402,12 +402,12 @@ class RLearnEvaluator:
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]: if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
img = img.permute(2, 0, 1) # HWC -> CHW img = img.permute(2, 0, 1) # HWC -> CHW
# Resize to expected input size (256x256 for SigLIP2) BEFORE stacking # Resize to expected input size (224x224 for DINO v3) BEFORE stacking
if img.shape[-2:] != (256, 256): if img.shape[-2:] != (224, 224):
import torch.nn.functional as F import torch.nn.functional as F
img = F.interpolate( img = F.interpolate(
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
).squeeze(0) ).squeeze(0)
# Normalize to [0, 1] if needed # Normalize to [0, 1] if needed
@@ -527,12 +527,12 @@ class RLearnEvaluator:
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]: if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
img = img.permute(2, 0, 1) img = img.permute(2, 0, 1)
# Resize to expected input size (256x256 for SigLIP2) # Resize to expected input size (224x224 for DINO v3)
if img.shape[-2:] != (256, 256): if img.shape[-2:] != (224, 224):
import torch.nn.functional as F import torch.nn.functional as F
img = F.interpolate( img = F.interpolate(
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
).squeeze(0) ).squeeze(0)
# Normalize to [0, 1] if needed # Normalize to [0, 1] if needed
+87 -118
View File
@@ -17,7 +17,7 @@
""" """
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation) RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
This implementation follows the ReWiND paper approach: This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
- Automatically generates linear progress labels (0 to 1) for each episode - Automatically generates linear progress labels (0 to 1) for each episode
- No need for pre-annotated rewards in the dataset - No need for pre-annotated rewards in the dataset
- Applies video rewinding augmentation to create synthetic failure trajectories - Applies video rewinding augmentation to create synthetic failure trajectories
@@ -33,7 +33,7 @@ High-level Architecture
| per-frame encode | per-frame encode
v v
+------------------------------+ +------------------------------+
| Vision Encoder (frozen) | e.g. SigLIP2 vision tower | Vision Encoder (frozen) | e.g. DINOv2 (base)
+------------------------------+ +------------------------------+
|s |s
| pooled per-frame embeddings (BT, H_v) | pooled per-frame embeddings (BT, H_v)
@@ -46,7 +46,7 @@ High-level Architecture
| | | |
| v | v
| +------------------------------+ | +------------------------------+
| | Text Encoder (frozen) | e.g. SigLIP2 text tower | | Text Encoder (frozen) | e.g. sentence-transformers
| +------------------------------+ | +------------------------------+
| | | |
| | pooled text embedding (B, H_t) | | pooled text embedding (B, H_t)
@@ -67,10 +67,8 @@ High-level Architecture
Output Output
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout) - reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
Training
- Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking
Notes Notes
- Uses DINOv2 (base, ~ViT-B) for vision and sentence-transformers (all-MiniLM-L12-v2) for text encoding.
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable. - Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
- Stride/frame dropout applied during training can subsample timesteps. - Stride/frame dropout applied during training can subsample timesteps.
""" """
@@ -91,8 +89,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy): class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model. """Video-language conditioned reward model.
- Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings. - Visual encoder: frozen DINOv2 (base), returns per-frame embeddings.
- Text encoder: frozen SigLIP2 text tower, returns a language embedding. - Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding.
- Temporal module: causal transformer over time that cross-attends to language embedding. - Temporal module: causal transformer over time that cross-attends to language embedding.
- Output: per-timestep reward logits; trainable small head. - Output: per-timestep reward logits; trainable small head.
""" """
@@ -105,24 +103,20 @@ class RLearNPolicy(PreTrainedPolicy):
self.config = config self.config = config
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
# Encoders # Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text
from transformers import AutoModel, AutoProcessor from transformers import AutoImageProcessor, AutoModel
from sentence_transformers import SentenceTransformer
self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True) # Load DINOv2 (base) vision encoder with its processor
self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
# Detect towers self.vision_encoder = AutoModel.from_pretrained(config.vision_model_name)
if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"):
self.vision_encoder = self.vision_text_model.vision_model # Load sentence-transformers text encoder
self.text_encoder = self.vision_text_model.text_model self.text_encoder = SentenceTransformer(config.text_model_name)
self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size
self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size # DINOv2-base has 768 hidden size, all-MiniLM-L12-v2 has 384
else: self.vision_hidden = 768 # DINOv2-base
# Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2) self.text_hidden = 384 # all-MiniLM-L12-v2
self.vision_encoder = self.vision_text_model
self.text_encoder = self.vision_text_model
self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
if config.freeze_backbones: if config.freeze_backbones:
for p in self.vision_encoder.parameters(): for p in self.vision_encoder.parameters():
@@ -208,33 +202,46 @@ class RLearNPolicy(PreTrainedPolicy):
frames = frames[:, idx] frames = frames[:, idx]
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
# Encode language # Encode language using sentence-transformers
lang_emb = encode_language( lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
) )
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
lang_emb = self.text_proj(lang_emb) # (B, D) lang_emb = self.text_proj(lang_emb) # (B, D)
# Use the HF processor to standardize size & normalization # Process frames with DINOv2
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W) # Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
BT = B * T_eff BT = B * T_eff
flat = frames.reshape(BT, C, H, W).detach().cpu() flat = frames.reshape(BT, C, H, W)
# Convert to uint8 HWC numpy (processor prefers PIL/np) # Convert to list of PIL images or numpy arrays for the processor
# If already in [0,1], scale to [0,255] # DINOv2 processor expects images in HWC format
if flat.dtype != torch.uint8: images_list = []
if flat.numel() > 0 and float(flat.max()) <= 1.0: for i in range(BT):
flat = flat * 255.0 img = flat[i] # (C, H, W)
flat = flat.clamp(0, 255).round().to(torch.uint8) # Convert to HWC format
img = img.permute(1, 2, 0) # (H, W, C)
# Convert to numpy if needed
if img.dtype == torch.uint8:
img = img.cpu().numpy()
else:
# Convert to uint8 range
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
images_list.append(img)
# Process with DINOv2 processor
processed = self.vision_processor(images=images_list, return_tensors="pt")
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))] # Encode frames through DINOv2
vision_outputs = self.vision_encoder(pixel_values)
proc_out = self.processor(images=images, return_tensors="pt") # Extract CLS tokens for temporal modeling
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device) # DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
# The CLS token is the first token
# Encode frames through visual tower per frame
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
# Extract CLS tokens for temporal modeling
if hasattr(vision_outputs, "last_hidden_state"): if hasattr(vision_outputs, "last_hidden_state"):
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision) cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
else: else:
@@ -302,38 +309,45 @@ class RLearNPolicy(PreTrainedPolicy):
idx = torch.tensor([0], device=frames.device) idx = torch.tensor([0], device=frames.device)
frames = frames[:, idx] frames = frames[:, idx]
# Encode language # Encode language using sentence-transformers
lang_emb = encode_language( lang_emb = encode_language(
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
) )
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
lang_emb = self.text_proj(lang_emb) # (B, D) lang_emb = self.text_proj(lang_emb) # (B, D)
# Encode frames through visual tower per frame # Encode frames through DINOv2 visual encoder
# Flatten time for batched encode # Flatten time for batched encode
BT = B * frames.shape[1] BT = B * frames.shape[1]
flat = frames.reshape(BT, C, H, W) flat = frames.reshape(BT, C, H, W)
# Use HF processor to properly resize and normalize images # Convert to list of PIL images or numpy arrays for the processor
# Convert to CPU for processing, then move back to device # DINOv2 processor expects images in HWC format
flat_cpu = flat.detach().cpu() images_list = []
for i in range(BT):
img = flat[i] # (C, H, W)
# Convert to HWC format
img = img.permute(1, 2, 0) # (H, W, C)
# Convert to numpy if needed
if img.dtype == torch.uint8:
img = img.cpu().numpy()
else:
# Convert to uint8 range
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
images_list.append(img)
# Process with DINOv2 processor
processed = self.vision_processor(images=images_list, return_tensors="pt")
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# Convert to uint8 HWC numpy format expected by processor # Encode through DINOv2 model
if flat_cpu.dtype != torch.uint8: vision_outputs = self.vision_encoder(pixel_values)
if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0:
flat_cpu = flat_cpu * 255.0
flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8)
# Convert to list of numpy arrays
images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))]
# Process with HF processor (resizes to 256x256 and normalizes)
proc_out = self.processor(images=images, return_tensors="pt")
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
# Encode through vision model
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
# Extract CLS token for temporal modeling # Extract CLS token for temporal modeling
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
if hasattr(vision_outputs, "last_hidden_state"): if hasattr(vision_outputs, "last_hidden_state"):
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
else: else:
@@ -430,48 +444,6 @@ class RLearNPolicy(PreTrainedPolicy):
# Apply stride/dropout indexing to match the processed frames # Apply stride/dropout indexing to match the processed frames
target = target[:, idx] target = target[:, idx]
elif "index" in batch and hasattr(self, "episode_data_index"):
# Fallback: Use global index if available
global_indices = batch["index"] # Shape: (B,)
# For each index, find which episode it belongs to and its position
progress_values = []
for global_idx in global_indices:
# Find which episode this index belongs to
episode_starts = self.episode_data_index["from"]
episode_ends = self.episode_data_index["to"]
# Find the episode by checking which range the index falls into
episode_idx = None
frame_in_episode = None
for ep_idx in range(len(episode_starts)):
if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]:
episode_idx = ep_idx
frame_in_episode = global_idx.item() - episode_starts[ep_idx].item()
break
if episode_idx is not None:
# Calculate position within episode
ep_start = episode_starts[episode_idx].item()
ep_end = episode_ends[episode_idx].item()
ep_length = ep_end - ep_start
# Progress from 0 to 1 within the episode
progress = frame_in_episode / max(1, ep_length - 1)
else:
# Fallback if we can't find the episode (shouldn't happen)
progress = 0.5
progress_values.append(progress)
# For temporal window, use simplified linear progress
# (proper calculation would need all frame indices in the window)
T_effective = len(idx)
target = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion
else: else:
raise ValueError( raise ValueError(
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present." "No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
@@ -698,8 +670,9 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
def encode_language( def encode_language(
language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int language_input: Tensor | list | str | None, text_encoder, batch_size: int
) -> Tensor: ) -> Tensor:
"""Encode language using sentence-transformers (ReWiND paper setup)."""
# language_input can be: list[str] length B, or None # language_input can be: list[str] length B, or None
if language_input is None: if language_input is None:
texts = [""] * batch_size texts = [""] * batch_size
@@ -709,16 +682,12 @@ def encode_language(
# Single string for the batch # Single string for the batch
texts = [str(language_input)] * batch_size texts = [str(language_input)] * batch_size
inputs = processor(text=texts, padding=True, return_tensors="pt") # For sentence-transformers, we can directly encode
inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()} # Returns tensor of shape (batch_size, embedding_dim)
outputs = text_encoder(**inputs) device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu'
if hasattr(outputs, "pooler_output"): embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device)
emb = outputs.pooler_output
elif hasattr(outputs, "last_hidden_state"): return embeddings
emb = outputs.last_hidden_state[:, 0]
else:
raise RuntimeError("Unsupported text encoder output structure")
return emb
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]: def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
@@ -60,9 +60,9 @@ def make_rlearn_processor(
), ),
ToBatchProcessor(), ToBatchProcessor(),
RLearnLanguageFromTaskProcessor(), RLearnLanguageFromTaskProcessor(),
# Use the same model name for tokenizer to keep vocab aligned with text tower # Use the text model name for tokenizer to keep vocab aligned with text tower
TokenizerProcessor( TokenizerProcessor(
tokenizer_name=config.model_name, tokenizer_name=config.text_model_name,
max_length=128, max_length=128,
padding="max_length", padding="max_length",
truncation=True, truncation=True,
+9 -57
View File
@@ -31,7 +31,9 @@ Little less relevant but still similar papers:
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward. Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
Archiutecture: Archiutecture:
_ inputs: video o1:T (or current o1:t), language z; _ inputs: video o1:T (or current o1:t), language z;
_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits. _ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding
\* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01 Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
@@ -59,59 +61,6 @@ _ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
_ YouCook2 dataset _ YouCook2 dataset
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
### Implemented Loss (Spatial-Aware Composite Loss)
Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components:
##### 1) Progress Regression Loss (L_prog)
Predicts normalized progress values for each timestep:
$$
L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t)
$$
where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label.
**Purpose:** Grounds the model in actual task progress, not just visual-language similarity.
##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce)
Instead of using pooled features, we:
- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches)
- Use cross-attention where language queries attend to relevant spatial regions
- Compute contrastive loss on the attended spatial features
$$
L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)}
$$
where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$.
**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships.
##### 3) ReWiND Reversible Ranking Loss (L_rewind)
Based on ReWiND's key insight: learn from both forward AND reversed trajectories.
The loss has two components:
- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$
- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking
$$
L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}}
$$
where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$
**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible.
##### Total Loss:
$$
L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}}
$$
Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$
### TODOs: ### TODOs:
@@ -126,17 +75,20 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x] - Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
- Try different losses - Try different losses
- Only rewind loss [x] - Only rewind loss [x]
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
- check code is same as rewind repo code (architecture and trainign details) []
- Test only rewind loss (evaluate) [] - Test only rewind loss (evaluate) []
- Check rewind implementatyion by hand [] - Check rewind implementation by hand/cleanup []
- Only vlc loss then eval [] - Only vlc loss then eval []
- Vlc + rewind loss then eval [] - Vlc + Rewind loss then eval []
- Cleanup code [] - Cleanup code []
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent - Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
- Then on 10 percent - Then on 10 percent
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? [] - Ablation dino v2 vs dino v3 base 86 M
- Add more artificial text to dataset generated by vlm (google gemini) [] - Add more artificial text to dataset generated by vlm (google gemini) []
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2 - See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453 - Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
- How can we improve spatial aware learning? co generating captions for each frame with language decoder? - How can we improve spatial aware learning? co generating captions for each frame with language decoder?
- Extend evaluation [] - Extend evaluation []
- Add other datasets mentioned above [] - Add other datasets mentioned above []
- Ablation for size vision encoder, language encoder, temporal head
+25 -18
View File
@@ -25,10 +25,12 @@ from tests.utils import require_package
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_instantiation_and_forward_tensor_batch(): def test_rlearn_instantiation_and_forward_tensor_batch():
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text.""" """Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False, push_to_hub=False,
freeze_backbones=True, freeze_backbones=True,
) )
@@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch():
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_instantiation_and_forward_list_batch_with_language(): def test_rlearn_instantiation_and_forward_list_batch_with_language():
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model.""" """Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False, push_to_hub=False,
freeze_backbones=True, freeze_backbones=True,
) )
@@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language():
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_composite_loss_shapes_and_terms(): def test_rlearn_composite_loss_shapes_and_terms():
"""Smoke test composite loss: checks presence of terms and valid gradients.""" """Smoke test composite loss: checks presence of terms and valid gradients."""
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False, push_to_hub=False,
freeze_backbones=True, freeze_backbones=True,
loss_type="composite", use_video_rewind=True,
lambda_prog=1.0, rewind_prob=0.5,
lambda_spatial_nce=0.5, use_mismatch_loss=True,
lambda_rewind=0.4,
num_ranking_pairs=32, # Fewer pairs for testing
last_k_for_nce=2,
) )
cfg.input_features = { cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
@@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms():
loss, logs = policy.forward(batch) loss, logs = policy.forward(batch)
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss) assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
# Expect composite terms present with spatial awareness and ReWiND # Expect ReWiND loss terms (progress and mismatch)
assert "loss_prog" in logs assert "loss_progress" in logs
assert "loss_spatial_nce" in logs assert "loss_mismatch" in logs
assert "loss_rewind_forward" in logs
assert "loss_rewind_reverse" in logs
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_preprocessor_tokenizes_and_copies_task(): def test_rlearn_preprocessor_tokenizes_and_copies_task():
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu", device="cpu",
push_to_hub=False, push_to_hub=False,
) )
@@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task():
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_preprocessor_string_task_and_to_batch(): def test_rlearn_preprocessor_string_task_and_to_batch():
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu", device="cpu",
push_to_hub=False, push_to_hub=False,
) )
@@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch():
@require_package("transformers") @require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_pipeline_end_to_end_forward(): def test_rlearn_pipeline_end_to_end_forward():
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data.""" """End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
cfg = RLearNConfig( cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256", vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu", device="cpu",
push_to_hub=False, push_to_hub=False,
freeze_backbones=True, freeze_backbones=True,
loss_type="composite", use_video_rewind=True,
) )
cfg.input_features = { cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),