diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 2d04f0b4f..c8b77ee2e 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -33,12 +33,14 @@ class RLearNConfig(PreTrainedConfig): - Per-timestep reward logits or a single-step reward logit. Notes: - - This is the initial architecture. It uses frozen vision/text encoders - (e.g. SigLIP2) and trains a lightweight temporal aggregator + head. + - This follows the ReWiND paper architecture. It uses frozen vision/text encoders + (DINO v3 for vision, sentence-transformers for language) and trains a + lightweight temporal aggregator + head. """ - # Encoders - model_name: str = "google/siglip2-base-patch16-256" + # Encoders - Using DINOv2 (base) for vision and sentence-transformers for text (ReWiND paper) + vision_model_name: str = "facebook/dinov2-base" + text_model_name: str = "sentence-transformers/all-MiniLM-L12-v2" freeze_backbones: bool = True # Temporal aggregator diff --git a/src/lerobot/policies/rlearn/evaluation.py b/src/lerobot/policies/rlearn/evaluation.py index 13f44e7cf..f673c2ad1 100644 --- a/src/lerobot/policies/rlearn/evaluation.py +++ b/src/lerobot/policies/rlearn/evaluation.py @@ -323,8 +323,8 @@ class RLearnEvaluator: T, C, H, W = frames.shape - # Expected input size for SigLIP2 is typically 256x256 - target_size = 256 + # Expected input size for DINO v3 is 224x224 + target_size = 224 # Resize frames if needed if H != target_size or W != target_size: @@ -402,12 +402,12 @@ class RLearnEvaluator: if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]: img = img.permute(2, 0, 1) # HWC -> CHW - # Resize to expected input size (256x256 for SigLIP2) BEFORE stacking - if img.shape[-2:] != (256, 256): + # Resize to expected input size (224x224 for DINO v3) BEFORE stacking + if img.shape[-2:] != (224, 224): import torch.nn.functional as F img = F.interpolate( - img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False + img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False ).squeeze(0) # Normalize to [0, 1] if needed @@ -527,12 +527,12 @@ class RLearnEvaluator: if len(img.shape) == 3 and img.shape[-1] in [1, 3, 4]: img = img.permute(2, 0, 1) - # Resize to expected input size (256x256 for SigLIP2) - if img.shape[-2:] != (256, 256): + # Resize to expected input size (224x224 for DINO v3) + if img.shape[-2:] != (224, 224): import torch.nn.functional as F img = F.interpolate( - img.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=False + img.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False ).squeeze(0) # Normalize to [0, 1] if needed diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 0bc7068a5..46d2b4819 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -17,7 +17,7 @@ """ RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation) -This implementation follows the ReWiND paper approach: +This implementation follows the ReWiND paper approach (arXiv:2505.10911v1): - Automatically generates linear progress labels (0 to 1) for each episode - No need for pre-annotated rewards in the dataset - Applies video rewinding augmentation to create synthetic failure trajectories @@ -33,7 +33,7 @@ High-level Architecture | per-frame encode v +------------------------------+ - | Vision Encoder (frozen) | e.g. SigLIP2 vision tower + | Vision Encoder (frozen) | e.g. DINOv2 (base) +------------------------------+ |s | pooled per-frame embeddings (BT, H_v) @@ -46,7 +46,7 @@ High-level Architecture | | | v | +------------------------------+ - | | Text Encoder (frozen) | e.g. SigLIP2 text tower + | | Text Encoder (frozen) | e.g. sentence-transformers | +------------------------------+ | | | | pooled text embedding (B, H_t) @@ -67,10 +67,8 @@ High-level Architecture Output - reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout) -Training - - Loss: composite loss with progress regression, spatial-aware InfoNCE, and ReWiND reversible ranking - Notes + - Uses DINOv2 (base, ~ViT-B) for vision and sentence-transformers (all-MiniLM-L12-v2) for text encoding. - Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable. - Stride/frame dropout applied during training can subsample timesteps. """ @@ -91,8 +89,8 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig class RLearNPolicy(PreTrainedPolicy): """Video-language conditioned reward model. - - Visual encoder: frozen SigLIP2 (via transformers AutoModel), returns per-frame embeddings. - - Text encoder: frozen SigLIP2 text tower, returns a language embedding. + - Visual encoder: frozen DINOv2 (base), returns per-frame embeddings. + - Text encoder: frozen sentence-transformers (all-MiniLM-L12-v2), returns a language embedding. - Temporal module: causal transformer over time that cross-attends to language embedding. - Output: per-timestep reward logits; trainable small head. """ @@ -105,24 +103,20 @@ class RLearNPolicy(PreTrainedPolicy): self.config = config self.episode_data_index = episode_data_index # Store episode boundaries for progress calculation - # Encoders - from transformers import AutoModel, AutoProcessor - - self.vision_text_model = AutoModel.from_pretrained(config.model_name, trust_remote_code=True) - self.processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True) - - # Detect towers - if hasattr(self.vision_text_model, "vision_model") and hasattr(self.vision_text_model, "text_model"): - self.vision_encoder = self.vision_text_model.vision_model - self.text_encoder = self.vision_text_model.text_model - self.vision_hidden = getattr(self.vision_text_model.config, "vision_config", None).hidden_size - self.text_hidden = getattr(self.vision_text_model.config, "text_config", None).hidden_size - else: - # Fallback if AutoModel exposes pooled outputs directly (rare for SigLIP2) - self.vision_encoder = self.vision_text_model - self.text_encoder = self.vision_text_model - self.vision_hidden = getattr(self.vision_text_model.config, "hidden_size", 768) - self.text_hidden = getattr(self.vision_text_model.config, "hidden_size", 768) + # Encoders - ReWiND paper setup: DINOv2 for vision, sentence-transformers for text + from transformers import AutoImageProcessor, AutoModel + from sentence_transformers import SentenceTransformer + + # Load DINOv2 (base) vision encoder with its processor + self.vision_processor = AutoImageProcessor.from_pretrained(config.vision_model_name) + self.vision_encoder = AutoModel.from_pretrained(config.vision_model_name) + + # Load sentence-transformers text encoder + self.text_encoder = SentenceTransformer(config.text_model_name) + + # DINOv2-base has 768 hidden size, all-MiniLM-L12-v2 has 384 + self.vision_hidden = 768 # DINOv2-base + self.text_hidden = 384 # all-MiniLM-L12-v2 if config.freeze_backbones: for p in self.vision_encoder.parameters(): @@ -208,33 +202,46 @@ class RLearNPolicy(PreTrainedPolicy): frames = frames[:, idx] B, T_eff, C, H, W = frames.shape # NEW: effective length after stride - # Encode language + # Encode language using sentence-transformers lang_emb = encode_language( - batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B + batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B ) + # Ensure embeddings are normal tensors on the correct device (not inference tensors) + lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device) lang_emb = self.text_proj(lang_emb) # (B, D) - # Use the HF processor to standardize size & normalization + # Process frames with DINOv2 # Flatten (B, T_eff, C, H, W) -> (BT, C, H, W) BT = B * T_eff - flat = frames.reshape(BT, C, H, W).detach().cpu() + flat = frames.reshape(BT, C, H, W) - # Convert to uint8 HWC numpy (processor prefers PIL/np) - # If already in [0,1], scale to [0,255] - if flat.dtype != torch.uint8: - if flat.numel() > 0 and float(flat.max()) <= 1.0: - flat = flat * 255.0 - flat = flat.clamp(0, 255).round().to(torch.uint8) + # Convert to list of PIL images or numpy arrays for the processor + # DINOv2 processor expects images in HWC format + images_list = [] + for i in range(BT): + img = flat[i] # (C, H, W) + # Convert to HWC format + img = img.permute(1, 2, 0) # (H, W, C) + + # Convert to numpy if needed + if img.dtype == torch.uint8: + img = img.cpu().numpy() + else: + # Convert to uint8 range + img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() + + images_list.append(img) + + # Process with DINOv2 processor + processed = self.vision_processor(images=images_list, return_tensors="pt") + pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) - images = [flat[k].permute(1, 2, 0).numpy() for k in range(flat.size(0))] + # Encode frames through DINOv2 + vision_outputs = self.vision_encoder(pixel_values) - proc_out = self.processor(images=images, return_tensors="pt") - pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device) - - # Encode frames through visual tower per frame - vision_outputs = self.vision_encoder(pixel_values=pixel_values) - - # Extract CLS tokens for temporal modeling + # Extract CLS tokens for temporal modeling + # DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size) + # The CLS token is the first token if hasattr(vision_outputs, "last_hidden_state"): cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision) else: @@ -302,38 +309,45 @@ class RLearNPolicy(PreTrainedPolicy): idx = torch.tensor([0], device=frames.device) frames = frames[:, idx] - # Encode language + # Encode language using sentence-transformers lang_emb = encode_language( - batch.get(OBS_LANGUAGE, None), self.text_encoder, self.processor, batch_size=B + batch.get(OBS_LANGUAGE, None), self.text_encoder, batch_size=B ) + # Ensure embeddings are normal tensors on the correct device (not inference tensors) + lang_emb = lang_emb.detach().clone().to(self.text_proj.weight.device) lang_emb = self.text_proj(lang_emb) # (B, D) - # Encode frames through visual tower per frame + # Encode frames through DINOv2 visual encoder # Flatten time for batched encode BT = B * frames.shape[1] flat = frames.reshape(BT, C, H, W) - # Use HF processor to properly resize and normalize images - # Convert to CPU for processing, then move back to device - flat_cpu = flat.detach().cpu() + # Convert to list of PIL images or numpy arrays for the processor + # DINOv2 processor expects images in HWC format + images_list = [] + for i in range(BT): + img = flat[i] # (C, H, W) + # Convert to HWC format + img = img.permute(1, 2, 0) # (H, W, C) + + # Convert to numpy if needed + if img.dtype == torch.uint8: + img = img.cpu().numpy() + else: + # Convert to uint8 range + img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() + + images_list.append(img) + + # Process with DINOv2 processor + processed = self.vision_processor(images=images_list, return_tensors="pt") + pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) - # Convert to uint8 HWC numpy format expected by processor - if flat_cpu.dtype != torch.uint8: - if flat_cpu.numel() > 0 and float(flat_cpu.max()) <= 1.0: - flat_cpu = flat_cpu * 255.0 - flat_cpu = flat_cpu.clamp(0, 255).round().to(torch.uint8) - - # Convert to list of numpy arrays - images = [flat_cpu[k].permute(1, 2, 0).numpy() for k in range(flat_cpu.size(0))] - - # Process with HF processor (resizes to 256x256 and normalizes) - proc_out = self.processor(images=images, return_tensors="pt") - pixel_values = proc_out["pixel_values"].to(next(self.vision_encoder.parameters()).device) - - # Encode through vision model - vision_outputs = self.vision_encoder(pixel_values=pixel_values) + # Encode through DINOv2 model + vision_outputs = self.vision_encoder(pixel_values) # Extract CLS token for temporal modeling + # DINOv2 outputs last_hidden_state of shape (batch_size, sequence_length, hidden_size) if hasattr(vision_outputs, "last_hidden_state"): cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D) - CLS token else: @@ -430,48 +444,6 @@ class RLearNPolicy(PreTrainedPolicy): # Apply stride/dropout indexing to match the processed frames target = target[:, idx] - - elif "index" in batch and hasattr(self, "episode_data_index"): - # Fallback: Use global index if available - global_indices = batch["index"] # Shape: (B,) - - # For each index, find which episode it belongs to and its position - progress_values = [] - - for global_idx in global_indices: - # Find which episode this index belongs to - episode_starts = self.episode_data_index["from"] - episode_ends = self.episode_data_index["to"] - - # Find the episode by checking which range the index falls into - episode_idx = None - frame_in_episode = None - for ep_idx in range(len(episode_starts)): - if episode_starts[ep_idx] <= global_idx < episode_ends[ep_idx]: - episode_idx = ep_idx - frame_in_episode = global_idx.item() - episode_starts[ep_idx].item() - break - - if episode_idx is not None: - # Calculate position within episode - ep_start = episode_starts[episode_idx].item() - ep_end = episode_ends[episode_idx].item() - ep_length = ep_end - ep_start - - # Progress from 0 to 1 within the episode - progress = frame_in_episode / max(1, ep_length - 1) - else: - # Fallback if we can't find the episode (shouldn't happen) - progress = 0.5 - - progress_values.append(progress) - - # For temporal window, use simplified linear progress - # (proper calculation would need all frame indices in the window) - T_effective = len(idx) - target = torch.tensor(progress_values, device=values.device, dtype=values.dtype) - target = target.unsqueeze(1).expand(B, T_effective) # Simple expansion - else: raise ValueError( "No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present." @@ -698,8 +670,9 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None def encode_language( - language_input: Tensor | list | str | None, text_encoder, processor, batch_size: int + language_input: Tensor | list | str | None, text_encoder, batch_size: int ) -> Tensor: + """Encode language using sentence-transformers (ReWiND paper setup).""" # language_input can be: list[str] length B, or None if language_input is None: texts = [""] * batch_size @@ -709,16 +682,12 @@ def encode_language( # Single string for the batch texts = [str(language_input)] * batch_size - inputs = processor(text=texts, padding=True, return_tensors="pt") - inputs = {k: v.to(next(text_encoder.parameters()).device) for k, v in inputs.items()} - outputs = text_encoder(**inputs) - if hasattr(outputs, "pooler_output"): - emb = outputs.pooler_output - elif hasattr(outputs, "last_hidden_state"): - emb = outputs.last_hidden_state[:, 0] - else: - raise RuntimeError("Unsupported text encoder output structure") - return emb + # For sentence-transformers, we can directly encode + # Returns tensor of shape (batch_size, embedding_dim) + device = next(iter(text_encoder.parameters())).device if hasattr(text_encoder, 'parameters') else 'cpu' + embeddings = text_encoder.encode(texts, convert_to_tensor=True, device=device) + + return embeddings def apply_video_rewind(frames: Tensor, rewind_prob: float = 0.5) -> tuple[Tensor, Tensor]: diff --git a/src/lerobot/policies/rlearn/processor_rlearn.py b/src/lerobot/policies/rlearn/processor_rlearn.py index 85dc0c6d0..7bf87b979 100644 --- a/src/lerobot/policies/rlearn/processor_rlearn.py +++ b/src/lerobot/policies/rlearn/processor_rlearn.py @@ -60,9 +60,9 @@ def make_rlearn_processor( ), ToBatchProcessor(), RLearnLanguageFromTaskProcessor(), - # Use the same model name for tokenizer to keep vocab aligned with text tower + # Use the text model name for tokenizer to keep vocab aligned with text tower TokenizerProcessor( - tokenizer_name=config.model_name, + tokenizer_name=config.text_model_name, max_length=128, padding="max_length", truncation=True, diff --git a/src/lerobot/policies/rlearn/rlearn_plan.md b/src/lerobot/policies/rlearn/rlearn_plan.md index 4c1648adb..58295dafc 100644 --- a/src/lerobot/policies/rlearn/rlearn_plan.md +++ b/src/lerobot/policies/rlearn/rlearn_plan.md @@ -31,7 +31,9 @@ Little less relevant but still similar papers: Input should be the current image or whole video and the task goal specified in text/language. Output is current reward. Archiutecture: _ inputs: video o1:T (or current o1:t), language z; -_ google/siglip2-large-patch16-256: https://huggingface.co/google/siglip2-large-patch16-256 \* Temporal module: small causal transformer (“cross-modal sequential aggregator”), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits. +_ DINO v3 ViT-B/16 (86M params): https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m for vision encoding +_ sentence-transformers/all-MiniLM-L12-v2: https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text encoding +\* Temporal module: small causal transformer ("cross-modal sequential aggregator"), with first-frame positional embedding (to avoid position cheating), frame-dropout, and stride sampling; outputs per-timestep logits. Loss: See this chatgpt thread: https://chatgpt.com/s/t_68999a50a0b081919abc365cdd205e01 @@ -59,59 +61,6 @@ _ GTEA+ Gaze: https://cbs.ic.gatech.edu/fpv/ _ YouCook2 dataset _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/ -### Implemented Loss (Spatial-Aware Composite Loss) - -Our implementation uses a **composite loss with spatial awareness** to address the limitations of standard contrastive learning (e.g., CLIP's inability to distinguish "move left" vs "move right"). The loss has three components: - -##### 1) Progress Regression Loss (L_prog) - -Predicts normalized progress values for each timestep: - -$$ -L_{\text{prog}} = \text{MSE}(\sigma(z(V_t)), y_t) -$$ - -where $z(·)$ is z-score normalization, $\sigma$ is sigmoid, and $y_t \in [0,1]$ is the progress label. -**Purpose:** Grounds the model in actual task progress, not just visual-language similarity. - -##### 2) Spatial-Aware InfoNCE Loss (L_spatial_nce) - -Instead of using pooled features, we: - -- Extract spatial patch features from SigLIP2's last_hidden_state (e.g., 16×16 patches) -- Use cross-attention where language queries attend to relevant spatial regions -- Compute contrastive loss on the attended spatial features - -$$ -L_{\text{spatial-nce}} = -\log \frac{\exp(s_{ii}/\tau)}{\sum_j \exp(s_{ij}/\tau)} -$$ - -where $s_{ij}$ is the similarity between spatially-attended features from trajectory $i$ and language $j$. -**Purpose:** Preserves spatial information that pooling discards, enabling distinction of spatial relationships. - -##### 3) ReWiND Reversible Ranking Loss (L_rewind) - -Based on ReWiND's key insight: learn from both forward AND reversed trajectories. -The loss has two components: - -- **Forward ranking**: Sample (far, near) pairs where near is later in time, enforce $V_{\text{near}} > V_{\text{far}}$ -- **Reverse ranking**: Reverse the trajectory and invert progress labels, then apply same ranking - -$$ -L_{\text{rewind}} = L_{\text{forward}} + L_{\text{reverse}} -$$ - -where both use: $\text{softplus}(m - (V_{\text{near}} - V_{\text{far}}))$ - -**Purpose:** By training on reversed trajectories with inverted progress, the model learns to distinguish progress from undoing progress. This is ReWiND's core contribution - understanding that tasks can be reversible. - -##### Total Loss: - -$$ -L = \lambda_{\text{prog}} L_{\text{prog}} + \lambda_{\text{spatial-nce}} L_{\text{spatial-nce}} + \lambda_{\text{rewind}} L_{\text{rewind}} -$$ - -Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$, $\lambda_{\text{rewind}}=0.4$ ### TODOs: @@ -126,17 +75,20 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$ - Implement on-the-fly progress label generation (no need for pre-annotated rewards) [x] - Try different losses - Only rewind loss [x] + - Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x] + - check code is same as rewind repo code (architecture and trainign details) [] - Test only rewind loss (evaluate) [] - - Check rewind implementatyion by hand [] + - Check rewind implementation by hand/cleanup [] - Only vlc loss then eval [] - - Vlc + rewind loss then eval [] + - Vlc + Rewind loss then eval [] - Cleanup code [] - Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent - Then on 10 percent -- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? [] +- Ablation dino v2 vs dino v3 base 86 M - Add more artificial text to dataset generated by vlm (google gemini) [] - See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2 - Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453 - How can we improve spatial aware learning? co generating captions for each frame with language decoder? - Extend evaluation [] - Add other datasets mentioned above [] +- Ablation for size vision encoder, language encoder, temporal head diff --git a/tests/policies/rlearn/test_rlearn.py b/tests/policies/rlearn/test_rlearn.py index a3489e1a6..18aea87af 100644 --- a/tests/policies/rlearn/test_rlearn.py +++ b/tests/policies/rlearn/test_rlearn.py @@ -25,10 +25,12 @@ from tests.utils import require_package @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_instantiation_and_forward_tensor_batch(): """Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text.""" cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", push_to_hub=False, freeze_backbones=True, ) @@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch(): @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_instantiation_and_forward_list_batch_with_language(): """Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model.""" cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", push_to_hub=False, freeze_backbones=True, ) @@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language(): @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_composite_loss_shapes_and_terms(): """Smoke test composite loss: checks presence of terms and valid gradients.""" cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", push_to_hub=False, freeze_backbones=True, - loss_type="composite", - lambda_prog=1.0, - lambda_spatial_nce=0.5, - lambda_rewind=0.4, - num_ranking_pairs=32, # Fewer pairs for testing - last_k_for_nce=2, + use_video_rewind=True, + rewind_prob=0.5, + use_mismatch_loss=True, ) cfg.input_features = { "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), @@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms(): loss, logs = policy.forward(batch) assert isinstance(loss, torch.Tensor) and torch.isfinite(loss) - # Expect composite terms present with spatial awareness and ReWiND - assert "loss_prog" in logs - assert "loss_spatial_nce" in logs - assert "loss_rewind_forward" in logs - assert "loss_rewind_reverse" in logs + # Expect ReWiND loss terms (progress and mismatch) + assert "loss_progress" in logs + assert "loss_mismatch" in logs @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_preprocessor_tokenizes_and_copies_task(): cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu", push_to_hub=False, ) @@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task(): @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_preprocessor_string_task_and_to_batch(): cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu", push_to_hub=False, ) @@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch(): @require_package("transformers") +@require_package("sentence_transformers") def test_rlearn_pipeline_end_to_end_forward(): """End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data.""" cfg = RLearNConfig( - model_name="google/siglip2-large-patch16-256", + vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m", + text_model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu", push_to_hub=False, freeze_backbones=True, - loss_type="composite", + use_video_rewind=True, ) cfg.input_features = { "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),