From ce5b27d255708de02e074818c6677fc77ab5c554 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 10:55:12 +0200 Subject: [PATCH] siglip again --- .../policies/rlearn/modeling_rlearn.py | 81 ++++++++++--------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index c61658735..738bbf611 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -22,17 +22,6 @@ import torch import torch.nn.functional as F from torch import Tensor, nn -# ReWiND dependencies -try: - from x_transformers import Decoder - import einx - from einops import rearrange, repeat, pack, unpack -except ImportError as e: - raise ImportError( - "ReWiND dependencies not installed. Please install: " - "pip install x-transformers einx einops x-mlps-pytorch" - ) from e - from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig @@ -95,17 +84,17 @@ class RLearNPolicy(PreTrainedPolicy): self.first_frame_pos = nn.Parameter(torch.zeros(1, 1, config.dim_model)) # Cross-modal sequential aggregator – causal transformer over - # [language tokens | video frame tokens] - self.decoder = Decoder( - dim=config.dim_model, - depth=config.num_layers, - heads=config.num_heads, - ff_mult=config.ff_mult, - attn_dropout=config.dropout, - ff_dropout=config.dropout, - cross_attend=False, - causal=True, + # [language tokens | video frame tokens] using PyTorch TransformerEncoder + encoder_layer = nn.TransformerEncoderLayer( + d_model=config.dim_model, + nhead=config.num_heads, + dim_feedforward=config.dim_model * config.ff_mult, + dropout=config.dropout, + activation="gelu", + batch_first=True, + norm_first=True, ) + self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers) # Per-frame predictor head self.frame_mlp = nn.Sequential( @@ -138,7 +127,7 @@ class RLearNPolicy(PreTrainedPolicy): try: self.vision_model = torch.compile(self.vision_model, mode="reduce-overhead") self.text_model = torch.compile(self.text_model, mode="reduce-overhead") - self.decoder = torch.compile(self.decoder, mode="reduce-overhead") + self.aggregator = torch.compile(self.aggregator, mode="reduce-overhead") print("✅ Applied torch.compile to encoders and transformer") except Exception as e: print(f"⚠️ torch.compile failed: {e}") @@ -182,7 +171,7 @@ class RLearNPolicy(PreTrainedPolicy): (B, T, P, D_vision) where P is number of patch tokens per frame (excludes CLS) """ B, T, C, H, W = frames.shape - flat = rearrange(frames, 'b t c h w -> (b t) c h w') + flat = frames.reshape(B * T, C, H, W) # Optimized: Process tensor directly without numpy conversion device = next(self.vision_model.parameters()).device @@ -222,8 +211,9 @@ class RLearNPolicy(PreTrainedPolicy): # Robustly reshape to (B, T, P, D): detect correct flatten order by maximizing temporal variance (on patch-mean) try: - cand1 = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T) - cand2 = rearrange(patch_tokens_flat, '(t b) p d -> b t p d', t=T, b=B) + P = patch_tokens_flat.shape[1] + cand1 = patch_tokens_flat.reshape(B, T, P, -1) + cand2 = patch_tokens_flat.reshape(T, B, P, -1).permute(1, 0, 2, 3) def mean_time_diff_4d(x): if T <= 1: return torch.tensor(0.0, device=x.device) @@ -237,7 +227,8 @@ class RLearNPolicy(PreTrainedPolicy): print(f"DINO reshape choice: {'(b t)->b t' if diff1 >= diff2 else '(t b)->b t'} | diff1={diff1.item():.6f}, diff2={diff2.item():.6f}") except Exception: # Fallback to default - patch_features = rearrange(patch_tokens_flat, '(b t) p d -> b t p d', b=B, t=T) + P = patch_tokens_flat.shape[1] + patch_features = patch_tokens_flat.view(B, T, P, -1) # DEBUG: Analyze vision feature variability (use per-frame pooled features for readability) if self.training and torch.rand(1).item() < 0.1: # 10% of training steps for more frequent debugging @@ -297,7 +288,7 @@ class RLearNPolicy(PreTrainedPolicy): if 'last_hidden_state' in vision_outputs.__dict__ and T >= 2: # Recover CLS tokens cls_flat = tokens[:, 0, :] # (BT, D) - cls = rearrange(cls_flat, '(b t) d -> b t d', b=B, t=T) + cls = cls_flat.view(B, T, -1) b0 = 0 f0, f1 = 0, T - 1 # L2 between CLS at two frames @@ -421,16 +412,26 @@ class RLearNPolicy(PreTrainedPolicy): # First-frame positional embedding only video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos - # Build attention mask for decoder (True = keep) - # Language mask from tokenizer, rest are fully valid - full_mask = F.pad(mask, (0, video_tokens.shape[1]), value=True) - # Pack and run transformer + # Build masks for TransformerEncoder + lang_valid = mask # (B, L) True where valid + video_valid = torch.ones(B, video_tokens.shape[1], device=device, dtype=torch.bool) + valid_mask = torch.cat([lang_valid, video_valid], dim=1) # (B, S) + key_padding_mask = ~valid_mask # True -> masked + + tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1) # (B, S, D) + + # Causal mask (S, S): True masks out future positions + S = tokens_seq.shape[1] + causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1) + transformer_start = time.perf_counter() - tokens_packed, packed_shape = pack((lang_tokens, video_tokens), 'b * d') - attended = self.decoder(tokens_packed, mask=full_mask) - attended_lang, attended_video = unpack(attended, packed_shape, 'b * d') + attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask) transformer_time = time.perf_counter() - transformer_start + # Split back video part + L_len = lang_tokens.shape[1] + attended_video = attended_all[:, L_len:, :] + # Per-frame prediction frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D) raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff) @@ -496,11 +497,13 @@ class RLearNPolicy(PreTrainedPolicy): lang_embeds_mm, mask_mm = self._encode_language_tokens(shuffled_commands, device) lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm) - # Pack and forward - tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, video_tokens), 'b * d') - mask_mm = F.pad(mask_mm, (0, video_tokens.shape[1]), value=True) - attended_mm = self.decoder(tokens_mm, mask=mask_mm) - _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d') + # Pack and forward with masks + lang_valid_mm = mask_mm + valid_mask_mm = torch.cat([lang_valid_mm, video_valid], dim=1) + key_padding_mask_mm = ~valid_mask_mm + tokens_seq_mm = torch.cat([lang_tokens_mm, video_tokens], dim=1) + attended_all_mm = self.aggregator(tokens_seq_mm, src_key_padding_mask=key_padding_mask_mm, mask=causal_mask) + attended_video_mm = attended_all_mm[:, L_len:, :] # Process mismatch frames with single MLP mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)