mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
dino v2
This commit is contained in:
@@ -33,12 +33,14 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
- Per-timestep reward logits or a single-step reward logit.
|
- Per-timestep reward logits or a single-step reward logit.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This is the initial architecture. It uses frozen vision/text encoders
|
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders
|
||||||
(e.g. SigLIP2) and trains a lightweight temporal aggregator + head.
|
(DINO v3 for vision, sentence-transformers for language) and trains a
|
||||||
|
lightweight temporal aggregator + head.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Encoders
|
# Encoders - Using DINOv2 (base) for vision and sentence-transformers for text (ReWiND paper)
|
||||||
model_name: str = "google/siglip2-base-patch16-256"
|
vision_model_name: str = "facebook/dinov2-base"
|
||||||
|
text_model_name: str = "sentence-transformers/all-MiniLM-L12-v2"
|
||||||
freeze_backbones: bool = True
|
freeze_backbones: bool = True
|
||||||
|
|
||||||
# Temporal aggregator
|
# Temporal aggregator
|
||||||
|
|||||||
@@ -323,8 +323,8 @@ class RLearnEvaluator:
|
|||||||
|
|
||||||
T, C, H, W = frames.shape
|
T, C, H, W = frames.shape
|
||||||
|
|
||||||
# Expected input size for SigLIP2 is typically 256x256
|
# Expected input size for DINO v3 is 224x224
|
||||||
target_size = 256
|
target_size = 224
|
||||||
|
|
||||||
# Resize frames if needed
|
# Resize frames if needed
|
||||||
if H != target_size or W != target_size:
|
if H != target_size or W != target_size:
|
||||||
@@ -402,12 +402,12 @@ class RLearnEvaluator:
|
|||||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||||
img = img.permute(2, 0, 1) # HWC -> CHW
|
img = img.permute(2, 0, 1) # HWC -> CHW
|
||||||
|
|
||||||
# Resize to expected input size (256x256 for SigLIP2) BEFORE stacking
|
# Resize to expected input size (224x224 for DINO v3) BEFORE stacking
|
||||||
if img.shape[-2:] != (256, 256):
|
if img.shape[-2:] != (224, 224):
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
img = F.interpolate(
|
img = F.interpolate(
|
||||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
# Normalize to [0, 1] if needed
|
# Normalize to [0, 1] if needed
|
||||||
@@ -527,12 +527,12 @@ class RLearnEvaluator:
|
|||||||
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]:
|
||||||
img = img.permute(2, 0, 1)
|
img = img.permute(2, 0, 1)
|
||||||
|
|
||||||
# Resize to expected input size (256x256 for SigLIP2)
|
# Resize to expected input size (224x224 for DINO v3)
|
||||||
if img.shape[-2:] != (256, 256):
|
if img.shape[-2:] != (224, 224):
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
img = F.interpolate(
|
img = F.interpolate(
|
||||||
img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False
|
img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
|
||||||
).squeeze(0)
|
).squeeze(0)
|
||||||
|
|
||||||
# Normalize to [0, 1] if needed
|
# Normalize to [0, 1] if needed
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
"""
|
"""
|
||||||
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
||||||
|
|
||||||
This implementation follows the ReWiND paper approach:
|
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
|
||||||
- Automatically generates linear progress labels (0 to 1) for each episode
|
- Automatically generates linear progress labels (0 to 1) for each episode
|
||||||
- No need for pre-annotated rewards in the dataset
|
- No need for pre-annotated rewards in the dataset
|
||||||
- Applies video rewinding augmentation to create synthetic failure trajectories
|
- Applies video rewinding augmentation to create synthetic failure trajectories
|
||||||
@@ -33,7 +33,7 @@ High-level Architecture
|
|||||||
| per-frame encode
|
| per-frame encode
|
||||||
v
|
v
|
||||||
+------------------------------+
|
+------------------------------+
|
||||||
| Vision Encoder (frozen) | e.g. SigLIP2 vision tower
|
| Vision Encoder (frozen) | e.g. DINOv2 (base)
|
||||||
+------------------------------+
|
+------------------------------+
|
||||||
|s
|
|s
|
||||||
| pooled per-frame embeddings (BT, H_v)
|
| pooled per-frame embeddings (BT, H_v)
|
||||||
@@ -46,7 +46,7 @@ High-level Architecture
|
|||||||
| |
|
| |
|
||||||
| v
|
| v
|
||||||
| +------------------------------+
|
| +------------------------------+
|
||||||
| | Text Encoder (frozen) | e.g. SigLIP2 text tower
|
| | Text Encoder (frozen) | e.g. sentence-transformers
|
||||||
| +------------------------------+
|
| +------------------------------+
|
||||||
| |
|
| |
|
||||||
| | pooled text embedding (B, H_t)
|
| | pooled text embedding (B, H_t)
|
||||||
@@ -67,10 +67,8 @@ High-level Architecture
|
|||||||
Output
|
Output
|
||||||
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
||||||
|
|
||||||
Training
|
|
||||||
- Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking
|
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
|
- Uses DINOv2 (base, ~ViT-B) for vision and sentence-transformers (all-MiniLM-L12-v2) for text encoding.
|
||||||
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
||||||
- Stride/frame dropout applied during training can subsample timesteps.
|
- Stride/frame dropout applied during training can subsample timesteps.
|
||||||
"""
|
"""
|
||||||
@@ -91,8 +89,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
|||||||
class RLearNPolicy(PreTrainedPolicy):
|
class RLearNPolicy(PreTrainedPolicy):
|
||||||
"""Video-language conditioned reward model.
|
"""Video-language conditioned reward model.
|
||||||
|
|
||||||
- Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings.
|
- Visual encoder: frozen DINOv2 (base), returns per-frame embeddings.
|
||||||
- Text encoder: frozen SigLIP2 text tower, returns a language embedding.
|
- Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding.
|
||||||
- Temporal module: causal transformer over time that cross-attends to language embedding.
|
- Temporal module: causal transformer over time that cross-attends to language embedding.
|
||||||
- Output: per-timestep reward logits; trainable small head.
|
- Output: per-timestep reward logits; trainable small head.
|
||||||
"""
|
"""
|
||||||
@@ -105,24 +103,20 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation
|
||||||
|
|
||||||
# Encoders
|
# Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text
|
||||||
from transformers import AutoModel, AutoProcessor
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
|
# Load DINOv2 (base) vision encoder with its processor
|
||||||
|
self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)
|
||||||
# Detect towers
|
self.vision_encoder = AutoModel.from_pretrained(config.vision_model_name)
|
||||||
if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"):
|
|
||||||
self.vision_encoder = self.vision_text_model.vision_model
|
# Load sentence-transformers text encoder
|
||||||
self.text_encoder = self.vision_text_model.text_model
|
self.text_encoder = SentenceTransformer(config.text_model_name)
|
||||||
self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size
|
|
||||||
self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size
|
# DINOv2-base has 768 hidden size, all-MiniLM-L12-v2 has 384
|
||||||
else:
|
self.vision_hidden = 768 # DINOv2-base
|
||||||
# Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2)
|
self.text_hidden = 384 # all-MiniLM-L12-v2
|
||||||
self.vision_encoder = self.vision_text_model
|
|
||||||
self.text_encoder = self.vision_text_model
|
|
||||||
self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
|
||||||
self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768)
|
|
||||||
|
|
||||||
if config.freeze_backbones:
|
if config.freeze_backbones:
|
||||||
for p in self.vision_encoder.parameters():
|
for p in self.vision_encoder.parameters():
|
||||||
@@ -208,33 +202,46 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
frames = frames[:, idx]
|
frames = frames[:, idx]
|
||||||
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
|
B, T_eff, C, H, W = frames.shape # NEW: effective length after stride
|
||||||
|
|
||||||
# Encode language
|
# Encode language using sentence-transformers
|
||||||
lang_emb = encode_language(
|
lang_emb = encode_language(
|
||||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||||
)
|
)
|
||||||
|
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||||
|
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||||
|
|
||||||
# Use the HF processor to standardize size & normalization
|
# Process frames with DINOv2
|
||||||
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
|
# Flatten (B, T_eff, C, H, W) -> (BT, C, H, W)
|
||||||
BT = B * T_eff
|
BT = B * T_eff
|
||||||
flat = frames.reshape(BT, C, H, W).detach().cpu()
|
flat = frames.reshape(BT, C, H, W)
|
||||||
|
|
||||||
# Convert to uint8 HWC numpy (processor prefers PIL/np)
|
# Convert to list of PIL images or numpy arrays for the processor
|
||||||
# If already in [0,1], scale to [0,255]
|
# DINOv2 processor expects images in HWC format
|
||||||
if flat.dtype != torch.uint8:
|
images_list = []
|
||||||
if flat.numel() > 0 and float(flat.max()) <= 1.0:
|
for i in range(BT):
|
||||||
flat = flat * 255.0
|
img = flat[i] # (C, H, W)
|
||||||
flat = flat.clamp(0, 255).round().to(torch.uint8)
|
# Convert to HWC format
|
||||||
|
img = img.permute(1, 2, 0) # (H, W, C)
|
||||||
|
|
||||||
|
# Convert to numpy if needed
|
||||||
|
if img.dtype == torch.uint8:
|
||||||
|
img = img.cpu().numpy()
|
||||||
|
else:
|
||||||
|
# Convert to uint8 range
|
||||||
|
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||||
|
|
||||||
|
images_list.append(img)
|
||||||
|
|
||||||
|
# Process with DINOv2 processor
|
||||||
|
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||||
|
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||||
|
|
||||||
images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))]
|
# Encode frames through DINOv2
|
||||||
|
vision_outputs = self.vision_encoder(pixel_values)
|
||||||
|
|
||||||
proc_out = self.processor(images=images, return_tensors="pt")
|
# Extract CLS tokens for temporal modeling
|
||||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||||
|
# The CLS token is the first token
|
||||||
# Encode frames through visual tower per frame
|
|
||||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
|
||||||
|
|
||||||
# Extract CLS tokens for temporal modeling
|
|
||||||
if hasattr(vision_outputs, "last_hidden_state"):
|
if hasattr(vision_outputs, "last_hidden_state"):
|
||||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||||
else:
|
else:
|
||||||
@@ -302,38 +309,45 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
idx = torch.tensor([0], device=frames.device)
|
idx = torch.tensor([0], device=frames.device)
|
||||||
frames = frames[:, idx]
|
frames = frames[:, idx]
|
||||||
|
|
||||||
# Encode language
|
# Encode language using sentence-transformers
|
||||||
lang_emb = encode_language(
|
lang_emb = encode_language(
|
||||||
batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B
|
batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B
|
||||||
)
|
)
|
||||||
|
# Ensure embeddings are normal tensors on the correct device (not inference tensors)
|
||||||
|
lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device)
|
||||||
lang_emb = self.text_proj(lang_emb) # (B, D)
|
lang_emb = self.text_proj(lang_emb) # (B, D)
|
||||||
|
|
||||||
# Encode frames through visual tower per frame
|
# Encode frames through DINOv2 visual encoder
|
||||||
# Flatten time for batched encode
|
# Flatten time for batched encode
|
||||||
BT = B * frames.shape[1]
|
BT = B * frames.shape[1]
|
||||||
flat = frames.reshape(BT, C, H, W)
|
flat = frames.reshape(BT, C, H, W)
|
||||||
|
|
||||||
# Use HF processor to properly resize and normalize images
|
# Convert to list of PIL images or numpy arrays for the processor
|
||||||
# Convert to CPU for processing, then move back to device
|
# DINOv2 processor expects images in HWC format
|
||||||
flat_cpu = flat.detach().cpu()
|
images_list = []
|
||||||
|
for i in range(BT):
|
||||||
|
img = flat[i] # (C, H, W)
|
||||||
|
# Convert to HWC format
|
||||||
|
img = img.permute(1, 2, 0) # (H, W, C)
|
||||||
|
|
||||||
|
# Convert to numpy if needed
|
||||||
|
if img.dtype == torch.uint8:
|
||||||
|
img = img.cpu().numpy()
|
||||||
|
else:
|
||||||
|
# Convert to uint8 range
|
||||||
|
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||||
|
|
||||||
|
images_list.append(img)
|
||||||
|
|
||||||
|
# Process with DINOv2 processor
|
||||||
|
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||||
|
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||||
|
|
||||||
# Convert to uint8 HWC numpy format expected by processor
|
# Encode through DINOv2 model
|
||||||
if flat_cpu.dtype != torch.uint8:
|
vision_outputs = self.vision_encoder(pixel_values)
|
||||||
if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0:
|
|
||||||
flat_cpu = flat_cpu * 255.0
|
|
||||||
flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8)
|
|
||||||
|
|
||||||
# Convert to list of numpy arrays
|
|
||||||
images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))]
|
|
||||||
|
|
||||||
# Process with HF processor (resizes to 256x256 and normalizes)
|
|
||||||
proc_out = self.processor(images=images, return_tensors="pt")
|
|
||||||
pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
|
||||||
|
|
||||||
# Encode through vision model
|
|
||||||
vision_outputs = self.vision_encoder(pixel_values=pixel_values)
|
|
||||||
|
|
||||||
# Extract CLS token for temporal modeling
|
# Extract CLS token for temporal modeling
|
||||||
|
# DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size)
|
||||||
if hasattr(vision_outputs, "last_hidden_state"):
|
if hasattr(vision_outputs, "last_hidden_state"):
|
||||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
|
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token
|
||||||
else:
|
else:
|
||||||
@@ -430,48 +444,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Apply stride/dropout indexing to match the processed frames
|
# Apply stride/dropout indexing to match the processed frames
|
||||||
target = target[:, idx]
|
target = target[:, idx]
|
||||||
|
|
||||||
elif "index" in batch and hasattr(self, "episode_data_index"):
|
|
||||||
# Fallback: Use global index if available
|
|
||||||
global_indices = batch["index"] # Shape: (B,)
|
|
||||||
|
|
||||||
# For each index, find which episode it belongs to and its position
|
|
||||||
progress_values = []
|
|
||||||
|
|
||||||
for global_idx in global_indices:
|
|
||||||
# Find which episode this index belongs to
|
|
||||||
episode_starts = self.episode_data_index["from"]
|
|
||||||
episode_ends = self.episode_data_index["to"]
|
|
||||||
|
|
||||||
# Find the episode by checking which range the index falls into
|
|
||||||
episode_idx = None
|
|
||||||
frame_in_episode = None
|
|
||||||
for ep_idx in range(len(episode_starts)):
|
|
||||||
if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]:
|
|
||||||
episode_idx = ep_idx
|
|
||||||
frame_in_episode = global_idx.item() - episode_starts[ep_idx].item()
|
|
||||||
break
|
|
||||||
|
|
||||||
if episode_idx is not None:
|
|
||||||
# Calculate position within episode
|
|
||||||
ep_start = episode_starts[episode_idx].item()
|
|
||||||
ep_end = episode_ends[episode_idx].item()
|
|
||||||
ep_length = ep_end - ep_start
|
|
||||||
|
|
||||||
# Progress from 0 to 1 within the episode
|
|
||||||
progress = frame_in_episode / max(1, ep_length - 1)
|
|
||||||
else:
|
|
||||||
# Fallback if we can't find the episode (shouldn't happen)
|
|
||||||
progress = 0.5
|
|
||||||
|
|
||||||
progress_values.append(progress)
|
|
||||||
|
|
||||||
# For temporal window, use simplified linear progress
|
|
||||||
# (proper calculation would need all frame indices in the window)
|
|
||||||
T_effective = len(idx)
|
|
||||||
target = torch.tensor(progress_values, device=values.device, dtype=values.dtype)
|
|
||||||
target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
|
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
|
||||||
@@ -698,8 +670,9 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
|||||||
|
|
||||||
|
|
||||||
def encode_language(
|
def encode_language(
|
||||||
language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int
|
language_input: Tensor | list | str | None, text_encoder, batch_size: int
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Encode language using sentence-transformers (ReWiND paper setup)."""
|
||||||
# language_input can be: list[str] length B, or None
|
# language_input can be: list[str] length B, or None
|
||||||
if language_input is None:
|
if language_input is None:
|
||||||
texts = [""] * batch_size
|
texts = [""] * batch_size
|
||||||
@@ -709,16 +682,12 @@ def encode_language(
|
|||||||
# Single string for the batch
|
# Single string for the batch
|
||||||
texts = [str(language_input)] * batch_size
|
texts = [str(language_input)] * batch_size
|
||||||
|
|
||||||
inputs = processor(text=texts, padding=True, return_tensors="pt")
|
# For sentence-transformers, we can directly encode
|
||||||
inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()}
|
# Returns tensor of shape (batch_size, embedding_dim)
|
||||||
outputs = text_encoder(**inputs)
|
device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu'
|
||||||
if hasattr(outputs, "pooler_output"):
|
embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device)
|
||||||
emb = outputs.pooler_output
|
|
||||||
elif hasattr(outputs, "last_hidden_state"):
|
return embeddings
|
||||||
emb = outputs.last_hidden_state[:, 0]
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Unsupported text encoder output structure")
|
|
||||||
return emb
|
|
||||||
|
|
||||||
|
|
||||||
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
|
def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]:
|
||||||
|
|||||||
@@ -60,9 +60,9 @@ def make_rlearn_processor(
|
|||||||
),
|
),
|
||||||
ToBatchProcessor(),
|
ToBatchProcessor(),
|
||||||
RLearnLanguageFromTaskProcessor(),
|
RLearnLanguageFromTaskProcessor(),
|
||||||
# Use the same model name for tokenizer to keep vocab aligned with text tower
|
# Use the text model name for tokenizer to keep vocab aligned with text tower
|
||||||
TokenizerProcessor(
|
TokenizerProcessor(
|
||||||
tokenizer_name=config.model_name,
|
tokenizer_name=config.text_model_name,
|
||||||
max_length=128,
|
max_length=128,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
truncation=True,
|
truncation=True,
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ Little less relevant but still similar papers:
|
|||||||
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
|
Input should be the current image or whole video and the task goal specified in text/language. Output is current reward.
|
||||||
Archiutecture:
|
Archiutecture:
|
||||||
_ inputs: video o1:T (or current o1:t), language z;
|
_ inputs: video o1:T (or current o1:t), language z;
|
||||||
_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding
|
||||||
|
_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding
|
||||||
|
\* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits.
|
||||||
|
|
||||||
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01
|
||||||
|
|
||||||
@@ -59,59 +61,6 @@ _ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/
|
|||||||
_ YouCook2 dataset
|
_ YouCook2 dataset
|
||||||
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
_ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
||||||
|
|
||||||
### Implemented Loss (Spatial-Aware Composite Loss)
|
|
||||||
|
|
||||||
Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components:
|
|
||||||
|
|
||||||
##### 1) Progress Regression Loss (L_prog)
|
|
||||||
|
|
||||||
Predicts normalized progress values for each timestep:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t)
|
|
||||||
$$
|
|
||||||
|
|
||||||
where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label.
|
|
||||||
**Purpose:** Grounds the model in actual task progress, not just visual-language similarity.
|
|
||||||
|
|
||||||
##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce)
|
|
||||||
|
|
||||||
Instead of using pooled features, we:
|
|
||||||
|
|
||||||
- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches)
|
|
||||||
- Use cross-attention where language queries attend to relevant spatial regions
|
|
||||||
- Compute contrastive loss on the attended spatial features
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)}
|
|
||||||
$$
|
|
||||||
|
|
||||||
where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$.
|
|
||||||
**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships.
|
|
||||||
|
|
||||||
##### 3) ReWiND Reversible Ranking Loss (L_rewind)
|
|
||||||
|
|
||||||
Based on ReWiND's key insight: learn from both forward AND reversed trajectories.
|
|
||||||
The loss has two components:
|
|
||||||
|
|
||||||
- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$
|
|
||||||
- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking
|
|
||||||
|
|
||||||
$$
|
|
||||||
L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}}
|
|
||||||
$$
|
|
||||||
|
|
||||||
where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$
|
|
||||||
|
|
||||||
**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible.
|
|
||||||
|
|
||||||
##### Total Loss:
|
|
||||||
|
|
||||||
$$
|
|
||||||
L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}}
|
|
||||||
$$
|
|
||||||
|
|
||||||
Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$
|
|
||||||
|
|
||||||
### TODOs:
|
### TODOs:
|
||||||
|
|
||||||
@@ -126,17 +75,20 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
|
|||||||
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
- Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x]
|
||||||
- Try different losses
|
- Try different losses
|
||||||
- Only rewind loss [x]
|
- Only rewind loss [x]
|
||||||
|
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||||
|
- check code is same as rewind repo code (architecture and trainign details) []
|
||||||
- Test only rewind loss (evaluate) []
|
- Test only rewind loss (evaluate) []
|
||||||
- Check rewind implementatyion by hand []
|
- Check rewind implementation by hand/cleanup []
|
||||||
- Only vlc loss then eval []
|
- Only vlc loss then eval []
|
||||||
- Vlc + rewind loss then eval []
|
- Vlc + Rewind loss then eval []
|
||||||
- Cleanup code []
|
- Cleanup code []
|
||||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||||
- Then on 10 percent
|
- Then on 10 percent
|
||||||
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
|
- Ablation dino v2 vs dino v3 base 86 M
|
||||||
- Add more artificial text to dataset generated by vlm (google gemini) []
|
- Add more artificial text to dataset generated by vlm (google gemini) []
|
||||||
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
|
- See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
|
||||||
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
||||||
- How can we improve spatial aware learning? co generating captions for each frame with language decoder?
|
- How can we improve spatial aware learning? co generating captions for each frame with language decoder?
|
||||||
- Extend evaluation []
|
- Extend evaluation []
|
||||||
- Add other datasets mentioned above []
|
- Add other datasets mentioned above []
|
||||||
|
- Ablation for size vision encoder, language encoder, temporal head
|
||||||
|
|||||||
@@ -25,10 +25,12 @@ from tests.utils import require_package
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_instantiation_and_forward_tensor_batch():
|
def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||||
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
|
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
freeze_backbones=True,
|
freeze_backbones=True,
|
||||||
)
|
)
|
||||||
@@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||||
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
|
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
freeze_backbones=True,
|
freeze_backbones=True,
|
||||||
)
|
)
|
||||||
@@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_composite_loss_shapes_and_terms():
|
def test_rlearn_composite_loss_shapes_and_terms():
|
||||||
"""Smoke test composite loss: checks presence of terms and valid gradients."""
|
"""Smoke test composite loss: checks presence of terms and valid gradients."""
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
freeze_backbones=True,
|
freeze_backbones=True,
|
||||||
loss_type="composite",
|
use_video_rewind=True,
|
||||||
lambda_prog=1.0,
|
rewind_prob=0.5,
|
||||||
lambda_spatial_nce=0.5,
|
use_mismatch_loss=True,
|
||||||
lambda_rewind=0.4,
|
|
||||||
num_ranking_pairs=32, # Fewer pairs for testing
|
|
||||||
last_k_for_nce=2,
|
|
||||||
)
|
)
|
||||||
cfg.input_features = {
|
cfg.input_features = {
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
@@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms():
|
|||||||
|
|
||||||
loss, logs = policy.forward(batch)
|
loss, logs = policy.forward(batch)
|
||||||
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
||||||
# Expect composite terms present with spatial awareness and ReWiND
|
# Expect ReWiND loss terms (progress and mismatch)
|
||||||
assert "loss_prog" in logs
|
assert "loss_progress" in logs
|
||||||
assert "loss_spatial_nce" in logs
|
assert "loss_mismatch" in logs
|
||||||
assert "loss_rewind_forward" in logs
|
|
||||||
assert "loss_rewind_reverse" in logs
|
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
device="cpu",
|
device="cpu",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
)
|
)
|
||||||
@@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_preprocessor_string_task_and_to_batch():
|
def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
device="cpu",
|
device="cpu",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
)
|
)
|
||||||
@@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch():
|
|||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
|
@require_package("sentence_transformers")
|
||||||
def test_rlearn_pipeline_end_to_end_forward():
|
def test_rlearn_pipeline_end_to_end_forward():
|
||||||
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
|
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
|
||||||
cfg = RLearNConfig(
|
cfg = RLearNConfig(
|
||||||
model_name="google/siglip2-large-patch16-256",
|
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||||
|
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||||
device="cpu",
|
device="cpu",
|
||||||
push_to_hub=False,
|
push_to_hub=False,
|
||||||
freeze_backbones=True,
|
freeze_backbones=True,
|
||||||
loss_type="composite",
|
use_video_rewind=True,
|
||||||
)
|
)
|
||||||
cfg.input_features = {
|
cfg.input_features = {
|
||||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||||
|
|||||||
Reference in New Issue
Block a user