diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index 3b7961db0..d1d5291b7 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -67,7 +67,7 @@ class RLearNConfig(PreTrainedConfig):
     logit_eps: float = 1e-4
 
     # Performance optimizations
-    use_amp: bool = True
+    use_amp: bool = False
    compile_model: bool = True
 
     # ReWiND augmentation
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index f2a70aa8f..d5225d29d 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -162,13 +162,13 @@ class RLearNPolicy(PreTrainedPolicy):
         raise NotImplementedError("RLearN is a reward model and does not select actions")
 
     def _encode_video_frames(self, frames: Tensor) -> Tensor:
-        """Encode video frames through DinoV3 to get per-frame PATCH embeddings.
+        """Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
 
         Args:
             frames: (B, T, C, H, W)
 
         Returns:
-            (B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS)
+            (B, T, D_vision) CLS token per frame
         """
         B, T, C, H, W = frames.shape
         flat = frames.reshape(B * T, C, H, W)
@@ -195,40 +195,38 @@ class RLearNPolicy(PreTrainedPolicy):
         # Process in batch through SigLIP2 vision tower
         vision_outputs = self.vision_model(**inputs)
 
-        # Prefer patch tokens from last_hidden_state (exclude CLS at index 0)
+        # Prefer CLS token from last_hidden_state at index 0
         if hasattr(vision_outputs, 'last_hidden_state') and vision_outputs.last_hidden_state is not None:
             tokens = vision_outputs.last_hidden_state  # (BT, N_tokens, D)
-            if tokens.dim() == 3 and tokens.shape[1] > 1:
-                patch_tokens_flat = tokens[:, 1:, :]  # (BT, P, D)
+            if tokens.dim() == 3 and tokens.shape[1] >= 1:
+                cls_tokens_flat = tokens[:, 0, :]  # (BT, D)
             else:
-                # Only one token available → treat as single patch
-                patch_tokens_flat = tokens[:, :1, :]
+                # Fallback to pooler if structure unexpected
+                cls_tokens_flat = getattr(vision_outputs, 'pooler_output')
         elif hasattr(vision_outputs, 'pooler_output') and vision_outputs.pooler_output is not None:
-            # No per-patch tokens available, synthesize single patch from pooler
-            patch_tokens_flat = vision_outputs.pooler_output[:, None, :]  # (BT, 1, D)
+            # Use pooled output
+            cls_tokens_flat = vision_outputs.pooler_output  # (BT, D)
         else:
             raise RuntimeError("SigLIP2 vision outputs do not contain last_hidden_state or pooler_output")
 
-        # Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean)
+        # Robustly reshape CLS to (B, T, D): detect correct flatten order by maximizing temporal variance
         try:
-            P = patch_tokens_flat.shape[1]
-            cand1 = patch_tokens_flat.reshape(B, T, P, -1)
-            cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3)
-            def mean_time_diff_4d(x):
+            D = cls_tokens_flat.shape[-1]
+            cand1 = cls_tokens_flat.reshape(B, T, D)
+            cand2 = cls_tokens_flat.reshape(T, B, D).permute(1, 0, 2)
+            def mean_time_diff_3d(x):
                 if T <= 1:
                     return torch.tensor(0.0, device=x.device)
-                x_mean = x.mean(dim=2)  # (B, T, D)
-                diffs = (x_mean[:, 1:, :] - x_mean[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
+                diffs = (x[:, 1:, :] - x[:, :-1, :]).pow(2).sum(dim=-1).sqrt()
                 return diffs.mean()
-            diff1 = mean_time_diff_4d(cand1)
-            diff2 = mean_time_diff_4d(cand2)
-            patch_features = cand1 if diff1 >= diff2 else cand2
+            diff1 = mean_time_diff_3d(cand1)
+            diff2 = mean_time_diff_3d(cand2)
+            frame_features = cand1 if diff1 >= diff2 else cand2
             if self.training and torch.rand(1).item() < 0.05:
                 print(f"SigLIP reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}")
         except Exception:
             # Fallback to default
-            P = patch_tokens_flat.shape[1]
-            patch_features = patch_tokens_flat.view(B, T, P, -1)
+            frame_features = cls_tokens_flat.view(B, T, -1)
 
         # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability)
         if self.training and torch.rand(1).item() < 0.1:  # 10% of training steps for more frequent debugging
@@ -277,8 +275,8 @@ class RLearNPolicy(PreTrainedPolicy):
                 else:
                     print(f" ✓ Batch samples have different first frames. Diff: {batch_first_frame_diff:.6f}")
 
-            # Check feature statistics (pooled over patches)
-            vision_features = patch_features.mean(dim=2)  # (B, T, D)
+            # Check feature statistics
+            vision_features = frame_features  # (B, T, D)
             feature_mean = vision_features.mean().item()
             feature_std = vision_features.std().item()
             print(f"Feature stats: mean={feature_mean:.4f}, std={feature_std:.4f}")
@@ -286,21 +284,13 @@ class RLearNPolicy(PreTrainedPolicy):
 
             # Extra DIAGNOSTIC: CLS vs patch mean/max deltas for one sample, two far-apart frames
             try:
                 if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2:
-                    # Recover CLS tokens
-                    cls_flat = tokens[:, 0, :]  # (BT, D)
-                    cls = cls_flat.view(B, T, -1)
+                    # Recover CLS tokens (already computed as frame_features)
+                    cls = frame_features
                     b0 = 0
                     f0, f1 = 0, T - 1
                     # L2 between CLS at two frames
                     cls_l2 = (cls[b0, f1] - cls[b0, f0]).pow(2).sum().sqrt().item()
-                    # Patch mean L2
-                    pm_f0 = patch_features[b0, f0].mean(dim=0)
-                    pm_f1 = patch_features[b0, f1].mean(dim=0)
-                    pm_l2 = (pm_f1 - pm_f0).pow(2).sum().sqrt().item()
-                    # Max over patches L2
-                    per_patch_l2 = (patch_features[b0, f1] - patch_features[b0, f0]).pow(2).sum(dim=1).sqrt()
-                    max_p_l2 = per_patch_l2.max().item()
-                    print(f"CLS ΔL2: {cls_l2:.6f} | mean(patches) ΔL2: {pm_l2:.6f} | max(patch) ΔL2: {max_p_l2:.6f}")
+                    print(f"CLS ΔL2: {cls_l2:.6f}")
             except Exception as _:
                 pass
@@ -344,7 +334,7 @@ class RLearNPolicy(PreTrainedPolicy):
 
             print("=" * 50)
 
-        return patch_features
+        return frame_features
 
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Compute ReWiND training loss with on-the-fly progress label generation.
@@ -406,8 +396,8 @@ class RLearNPolicy(PreTrainedPolicy):
         # Token preparation
         # Project embeddings
         lang_tokens = self.to_lang_tokens(lang_embeds)  # (B, L, D)
-        # Collapse patches to per-frame tokens then project
-        video_frame_embeds = video_patch_embeds.mean(dim=2)  # (B, T_eff, D_vision)
+        # SigLIP2 already returns per-frame CLS embeddings
+        video_frame_embeds = video_patch_embeds  # (B, T_eff, D_vision)
         video_tokens = self.to_video_tokens(video_frame_embeds)  # (B, T_eff, D)
         # First-frame positional embedding only
         video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
@@ -658,7 +648,7 @@ class RLearNPolicy(PreTrainedPolicy):
         """Return (embeddings, mask) for language tokens using SigLIP2.
-        embeddings: (B, L, D); mask: (B, L) True for valid tokens.
+        embeddings: (B, 1, D) single CLS token; mask: (B, 1) all True.
         """
-        # Optimized: Process all commands in batch (much faster than individual processing)
+        # Optimized: Process all commands in batch and take CLS token
         proc = self.processor(
             text=commands,
             return_tensors='pt',
@@ -688,9 +678,10 @@ class RLearNPolicy(PreTrainedPolicy):
 
         # Batch encode through text model
         outputs = self.text_model.text_model(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden = outputs.last_hidden_state
-        mask = attention_mask.bool()
-        return last_hidden, mask
+        # Use CLS token (position 0) as single language token
+        cls_only = outputs.last_hidden_state[:, :1, :]
+        mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
+        return cls_only, mask
 
     def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
         """Try to extract (episode_index, frame_index) tensors from batch or complementary data.