mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
dino v2
This commit is contained in:
@@ -33,12 +33,14 @@ class RLearNConfig(PreTrainedConfig):
|
||||
- Per-timestep reward logits or a single-step reward logit.
|
||||
|
||||
Notes:
|
||||
- This is the initial architecture. It uses frozen vision/text encoders
|
||||
(e.g. SigLIP2) and trains a lightweight temporal aggregator + head.
|
||||
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders
|
||||
(DINO v3 for vision, sentence-transformers for language) and trains a
|
||||
lightweight temporal aggregator + head.
|
||||
"""
|
||||
|
||||
# Encoders
|
||||
model_name: str = "google/siglip2-base-patch16-256"
|
||||
# Encoders - Using DINOv2 (base) for vision and sentence-transformers for text (ReWiND paper)
|
||||
vision_model_name: str = "facebook/dinov2-base"
|
||||
text_model_name: str = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
freeze_backbones: bool = True
|
||||
|
||||
# Temporal aggregator
|
||||
|
||||
@@ -323,8 +323,8 @@ class RLearnEvaluator:
|
||||
|
||||
T, C, H, W = frames.shape
|
||||
|
||||
# Expected input size for SigLIP2 is typically 256x256
|
||||
target_size = 256
|
||||
# Expected input size for DINO v3 is 224x224
|
||||
target_size = 224
|
||||
|
||||
# Resize frames if needed
|
||||
if H != target_size or W != target_size:
|
||||
@@ -402,12 +402,12 @@ class RLearnEvaluator:
|
||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||
img = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
# Resize to expected input size (256x256 for SigLIP2) BEFORE stacking
|
||||
if img.shape[-2:] != (256, 256):
|
||||
# Resize to expected input size (224x224 for DINO v3) BEFORE stacking
|
||||
if img.shape[-2:] != (224, 224):
|
||||
import torch.nn.functional as F
|
||||
|
||||
img = F.interpolate(
|
||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
||||
img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Normalize to [0, 1] if needed
|
||||
@@ -527,12 +527,12 @@ class RLearnEvaluator:
|
||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||
img = img.permute(2, 0, 1)
|
||||
|
||||
# Resize to expected input size (256x256 for SigLIP2)
|
||||
if img.shape[-2:] != (256, 256):
|
||||
# Resize to expected input size (224x224 for DINO v3)
|
||||
if img.shape[-2:] != (224, 224):
|
||||
import torch.nn.functional as F
|
||||
|
||||
img = F.interpolate(
|
||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
||||
img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# Normalize to [0, 1] if needed
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
||||
|
||||
This implementation follows the ReWiND paper approach:
|
||||
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
|
||||
- Automatically generates linear progress labels (0 to 1) for each episode
|
||||
- No need for pre-annotated rewards in the dataset
|
||||
- Applies video rewinding augmentation to create synthetic failure trajectories
|
||||
@@ -33,7 +33,7 @@ High-level Architecture
|
||||
| per-frame encode
|
||||
v
|
||||
+------------------------------+
|
||||
| Vision Encoder (frozen) | e.g. SigLIP2 vision tower
|
||||
| Vision Encoder (frozen) | e.g. DINOv2 (base)
|
||||
+------------------------------+
|
||||
|s
|
||||
| pooled per-frame embeddings (BT, H_v)
|
||||
@@ -46,7 +46,7 @@ High-level Architecture
|
||||
| |
|
||||
| v
|
||||
| +------------------------------+
|
||||
| | Text Encoder (frozen) | e.g. SigLIP2 text tower
|
||||
| | Text Encoder (frozen) | e.g. sentence-transformers
|
||||
| +------------------------------+
|
||||
| |
|
||||
| | pooled text embedding (B, H_t)
|
||||
@@ -67,10 +67,8 @@ High-level Architecture
|
||||
Output
|
||||
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
||||
|
||||
Training
|
||||
- Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking
|
||||
|
||||
Notes
|
||||
- Uses DINOv2 (base, ~ViT-B) for vision and sentence-transformers (all-MiniLM-L12-v2) for text encoding.
|
||||
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
||||
- Stride/frame dropout applied during training can subsample timesteps.
|
||||
"""
|
||||
@@ -91,8 +89,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
class RLearNPolicy(PreTrainedPolicy):
|
||||
"""Video-language conditioned reward model.
|
||||
|
||||
- Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings.
|
||||
- Text encoder: frozen SigLIP2 text tower, returns a language embedding.
|
||||
- Visual encoder: frozen DINOv2 (base), returns per-frame embeddings.
|
||||
- Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding.
|
||||
- Temporal module: causal transformer over time that cross-attends to language embedding.
|
||||
- Output: per-timestep reward logits; trainable small head.
|
||||
"""
|
||||
@@ -105,24 +103,20 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
||||
|
||||
# Encoders
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
# Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
|
||||
self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
|
||||
# Load DINOv2 (base) vision encoder with its processor
|
||||
self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
|
||||
self.vision_encoder = AutoModel.from_pretrained(config.vision_model_name)
|
||||
|
||||
# Detect towers
|
||||
if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"):
|
||||
self.vision_encoder = self.vision_text_model.vision_model
|
||||
self.text_encoder = self.vision_text_model.text_model
|
||||
self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size
|
||||
self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size
|
||||
else:
|
||||
# Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2)
|
||||
self.vision_encoder = self.vision_text_model
|
||||
self.text_encoder = self.vision_text_model
|
||||
self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
||||
self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
||||
# Load sentence-transformers text encoder
|
||||
self.text_encoder = SentenceTransformer(config.text_model_name)
|
||||
|
||||
# DINOv2-base has 768 hidden size, all-MiniLM-L12-v2 has 384
|
||||
self.vision_hidden = 768 # DINOv2-base
|
||||
self.text_hidden = 384 # all-MiniLM-L12-v2
|
||||
|
||||
if config.freeze_backbones:
|
||||
for p in self.vision_encoder.parameters():
|
||||
@@ -208,33 +202,46 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frames = frames[:, idx]
|
||||
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
|
||||
|
||||
# Encode language
|
||||
# Encode language using sentence-transformers
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||
)
|
||||
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
|
||||
# Use the HF processor to standardize size & normalization
|
||||
# Process frames with DINOv2
|
||||
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
|
||||
BT = B * T_eff
|
||||
flat = frames.reshape(BT, C, H, W).detach().cpu()
|
||||
flat = frames.reshape(BT, C, H, W)
|
||||
|
||||
# Convert to uint8 HWC numpy (processor prefers PIL/np)
|
||||
# If already in [0,1], scale to [0,255]
|
||||
if flat.dtype != torch.uint8:
|
||||
if flat.numel() > 0 and float(flat.max()) <= 1.0:
|
||||
flat = flat * 255.0
|
||||
flat = flat.clamp(0, 255).round().to(torch.uint8)
|
||||
# Convert to list of PIL images or numpy arrays for the processor
|
||||
# DINOv2 processor expects images in HWC format
|
||||
images_list = []
|
||||
for i in range(BT):
|
||||
img = flat[i] # (C, H, W)
|
||||
# Convert to HWC format
|
||||
img = img.permute(1, 2, 0) # (H, W, C)
|
||||
|
||||
images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))]
|
||||
# Convert to numpy if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.cpu().numpy()
|
||||
else:
|
||||
# Convert to uint8 range
|
||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
proc_out = self.processor(images=images, return_tensors="pt")
|
||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
images_list.append(img)
|
||||
|
||||
# Encode frames through visual tower per frame
|
||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
||||
# Process with DINOv2 processor
|
||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
|
||||
# Encode frames through DINOv2
|
||||
vision_outputs = self.vision_encoder(pixel_values)
|
||||
|
||||
# Extract CLS tokens for temporal modeling
|
||||
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||
# The CLS token is the first token
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||
else:
|
||||
@@ -302,38 +309,45 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
idx = torch.tensor([0], device=frames.device)
|
||||
frames = frames[:, idx]
|
||||
|
||||
# Encode language
|
||||
# Encode language using sentence-transformers
|
||||
lang_emb = encode_language(
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||
)
|
||||
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||
|
||||
# Encode frames through visual tower per frame
|
||||
# Encode frames through DINOv2 visual encoder
|
||||
# Flatten time for batched encode
|
||||
BT = B * frames.shape[1]
|
||||
flat = frames.reshape(BT, C, H, W)
|
||||
|
||||
# Use HF processor to properly resize and normalize images
|
||||
# Convert to CPU for processing, then move back to device
|
||||
flat_cpu = flat.detach().cpu()
|
||||
# Convert to list of PIL images or numpy arrays for the processor
|
||||
# DINOv2 processor expects images in HWC format
|
||||
images_list = []
|
||||
for i in range(BT):
|
||||
img = flat[i] # (C, H, W)
|
||||
# Convert to HWC format
|
||||
img = img.permute(1, 2, 0) # (H, W, C)
|
||||
|
||||
# Convert to uint8 HWC numpy format expected by processor
|
||||
if flat_cpu.dtype != torch.uint8:
|
||||
if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0:
|
||||
flat_cpu = flat_cpu * 255.0
|
||||
flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8)
|
||||
# Convert to numpy if needed
|
||||
if img.dtype == torch.uint8:
|
||||
img = img.cpu().numpy()
|
||||
else:
|
||||
# Convert to uint8 range
|
||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||
|
||||
# Convert to list of numpy arrays
|
||||
images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))]
|
||||
images_list.append(img)
|
||||
|
||||
# Process with HF processor (resizes to 256x256 and normalizes)
|
||||
proc_out = self.processor(images=images, return_tensors="pt")
|
||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
# Process with DINOv2 processor
|
||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||
|
||||
# Encode through vision model
|
||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
||||
# Encode through DINOv2 model
|
||||
vision_outputs = self.vision_encoder(pixel_values)
|
||||
|
||||
# Extract CLS token for temporal modeling
|
||||
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||
if hasattr(vision_outputs, "last_hidden_state"):
|
||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
|
||||
else:
|
||||
@@ -430,48 +444,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Apply stride/dropout indexing to match the processed frames
|
||||
target = target[:, idx]
|
||||
|
||||
elif "index" in batch and hasattr(self, "episode_data_index"):
|
||||
# Fallback: Use global index if available
|
||||
global_indices = batch["index"] # Shape: (B,)
|
||||
|
||||
# For each index, find which episode it belongs to and its position
|
||||
progress_values = []
|
||||
|
||||
for global_idx in global_indices:
|
||||
# Find which episode this index belongs to
|
||||
episode_starts = self.episode_data_index["from"]
|
||||
episode_ends = self.episode_data_index["to"]
|
||||
|
||||
# Find the episode by checking which range the index falls into
|
||||
episode_idx = None
|
||||
frame_in_episode = None
|
||||
for ep_idx in range(len(episode_starts)):
|
||||
if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]:
|
||||
episode_idx = ep_idx
|
||||
frame_in_episode = global_idx.item() - episode_starts[ep_idx].item()
|
||||
break
|
||||
|
||||
if episode_idx is not None:
|
||||
# Calculate position within episode
|
||||
ep_start = episode_starts[episode_idx].item()
|
||||
ep_end = episode_ends[episode_idx].item()
|
||||
ep_length = ep_end - ep_start
|
||||
|
||||
# Progress from 0 to 1 within the episode
|
||||
progress = frame_in_episode / max(1, ep_length - 1)
|
||||
else:
|
||||
# Fallback if we can't find the episode (shouldn't happen)
|
||||
progress = 0.5
|
||||
|
||||
progress_values.append(progress)
|
||||
|
||||
# For temporal window, use simplified linear progress
|
||||
# (proper calculation would need all frame indices in the window)
|
||||
T_effective = len(idx)
|
||||
target = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
|
||||
target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
|
||||
@@ -698,8 +670,9 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
||||
|
||||
|
||||
def encode_language(
|
||||
language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int
|
||||
language_input: Tensor | list | str | None, text_encoder, batch_size: int
|
||||
) -> Tensor:
|
||||
"""Encode language using sentence-transformers (ReWiND paper setup)."""
|
||||
# language_input can be: list[str] length B, or None
|
||||
if language_input is None:
|
||||
texts = [""] * batch_size
|
||||
@@ -709,16 +682,12 @@ def encode_language(
|
||||
# Single string for the batch
|
||||
texts = [str(language_input)] * batch_size
|
||||
|
||||
inputs = processor(text=texts, padding=True, return_tensors="pt")
|
||||
inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()}
|
||||
outputs = text_encoder(**inputs)
|
||||
if hasattr(outputs, "pooler_output"):
|
||||
emb = outputs.pooler_output
|
||||
elif hasattr(outputs, "last_hidden_state"):
|
||||
emb = outputs.last_hidden_state[:, 0]
|
||||
else:
|
||||
raise RuntimeError("Unsupported text encoder output structure")
|
||||
return emb
|
||||
# For sentence-transformers, we can directly encode
|
||||
# Returns tensor of shape (batch_size, embedding_dim)
|
||||
device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu'
|
||||
embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
|
||||
|
||||
@@ -60,9 +60,9 @@ def make_rlearn_processor(
|
||||
),
|
||||
ToBatchProcessor(),
|
||||
RLearnLanguageFromTaskProcessor(),
|
||||
# Use the same model name for tokenizer to keep vocab aligned with text tower
|
||||
# Use the text model name for tokenizer to keep vocab aligned with text tower
|
||||
TokenizerProcessor(
|
||||
tokenizer_name=config.model_name,
|
||||
tokenizer_name=config.text_model_name,
|
||||
max_length=128,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
|
||||
@@ -31,7 +31,9 @@ Little less relevant but still similar papers:
|
||||
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
|
||||
Archiutecture:
|
||||
_ inputs: video o1:T (or current o1:t), language z;
|
||||
_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||
_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
|
||||
_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding
|
||||
\* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||
|
||||
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
||||
|
||||
@@ -59,59 +61,6 @@ _ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
||||
_ YouCook2 dataset
|
||||
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||
|
||||
### Implemented Loss (Spatial-Aware Composite Loss)
|
||||
|
||||
Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components:
|
||||
|
||||
##### 1) Progress Regression Loss (L_prog)
|
||||
|
||||
Predicts normalized progress values for each timestep:
|
||||
|
||||
$$
|
||||
L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t)
|
||||
$$
|
||||
|
||||
where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label.
|
||||
**Purpose:** Grounds the model in actual task progress, not just visual-language similarity.
|
||||
|
||||
##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce)
|
||||
|
||||
Instead of using pooled features, we:
|
||||
|
||||
- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches)
|
||||
- Use cross-attention where language queries attend to relevant spatial regions
|
||||
- Compute contrastive loss on the attended spatial features
|
||||
|
||||
$$
|
||||
L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)}
|
||||
$$
|
||||
|
||||
where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$.
|
||||
**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships.
|
||||
|
||||
##### 3) ReWiND Reversible Ranking Loss (L_rewind)
|
||||
|
||||
Based on ReWiND's key insight: learn from both forward AND reversed trajectories.
|
||||
The loss has two components:
|
||||
|
||||
- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$
|
||||
- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking
|
||||
|
||||
$$
|
||||
L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}}
|
||||
$$
|
||||
|
||||
where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$
|
||||
|
||||
**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible.
|
||||
|
||||
##### Total Loss:
|
||||
|
||||
$$
|
||||
L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}}
|
||||
$$
|
||||
|
||||
Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$
|
||||
|
||||
### TODOs:
|
||||
|
||||
@@ -126,17 +75,20 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
|
||||
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
||||
- Try different losses
|
||||
- Only rewind loss [x]
|
||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||
- check code is same as rewind repo code (architecture and trainign details) []
|
||||
- Test only rewind loss (evaluate) []
|
||||
- Check rewind implementatyion by hand []
|
||||
- Check rewind implementation by hand/cleanup []
|
||||
- Only vlc loss then eval []
|
||||
- Vlc + rewind loss then eval []
|
||||
- Vlc + Rewind loss then eval []
|
||||
- Cleanup code []
|
||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||
- Then on 10 percent
|
||||
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
|
||||
- Ablation dino v2 vs dino v3 base 86 M
|
||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
|
||||
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
||||
- How can we improve spatial aware learning? co generating captions for each frame with language decoder?
|
||||
- Extend evaluation []
|
||||
- Add other datasets mentioned above []
|
||||
- Ablation for size vision encoder, language encoder, temporal head
|
||||
|
||||
@@ -25,10 +25,12 @@ from tests.utils import require_package
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
@@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
@@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_composite_loss_shapes_and_terms():
|
||||
"""Smoke test composite loss: checks presence of terms and valid gradients."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
lambda_prog=1.0,
|
||||
lambda_spatial_nce=0.5,
|
||||
lambda_rewind=0.4,
|
||||
num_ranking_pairs=32, # Fewer pairs for testing
|
||||
last_k_for_nce=2,
|
||||
use_video_rewind=True,
|
||||
rewind_prob=0.5,
|
||||
use_mismatch_loss=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
@@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms():
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
||||
# Expect composite terms present with spatial awareness and ReWiND
|
||||
assert "loss_prog" in logs
|
||||
assert "loss_spatial_nce" in logs
|
||||
assert "loss_rewind_forward" in logs
|
||||
assert "loss_rewind_reverse" in logs
|
||||
# Expect ReWiND loss terms (progress and mismatch)
|
||||
assert "loss_progress" in logs
|
||||
assert "loss_mismatch" in logs
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
@@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
@@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_pipeline_end_to_end_forward():
|
||||
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
use_video_rewind=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
|
||||
Reference in New Issue
Block a user