diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index 1644637cd..5ab83f4f3 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -92,7 +92,7 @@ try:
 except ImportError as e:
     raise ImportError(
         "ReWiND dependencies not installed. Please install: "
-        "pip install x-transformers hl-gauss-pytorch einx einops x_mlps_pytorc"
+        "pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch"
     ) from e
 
 from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
@@ -237,17 +237,16 @@ class RLearNPolicy(PreTrainedPolicy):
         frames = frames.to(device)
 
         # Process video frames
-        video_embeds = self._encode_video_frames(frames)  # (B, T, D_vision)
+        video_embeds = self._encode_video_frames(frames).to(device)  # (B, T, D_vision)
 
-        # Language embeddings
-        lang_embeds = self.text_encoder.encode(
+        # Language embeddings (get lengths BEFORE padding)
+        lang_embeds_list = self.text_encoder.encode(
             commands,
             output_value='token_embeddings',
-            convert_to_tensor=True,
-            device=device
+            convert_to_tensor=False,
         )
-        lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
-        lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
+        lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
+        lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
         mask = self._mask_from_lens(lens)
 
         # Register tokens
@@ -372,17 +371,16 @@ class RLearNPolicy(PreTrainedPolicy):
             commands = [str(commands)] * B
 
         # Process video frames through DINOv2
-        video_embeds = self._encode_video_frames(frames)  # (B, T_eff, D_vision)
+        video_embeds = self._encode_video_frames(frames).to(device)  # (B, T_eff, D_vision)
 
-        # Language embeddings
-        lang_embeds = self.text_encoder.encode(
+        # Language embeddings (get lengths BEFORE padding)
+        lang_embeds_list = self.text_encoder.encode(
             commands,
             output_value='token_embeddings',
-            convert_to_tensor=True,
-            device=device
+            convert_to_tensor=False,
         )
-        lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
-        lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
+        lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
+        lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
         mask = self._mask_from_lens(lens)
 
         # Token preparation
@@ -419,8 +417,8 @@ class RLearNPolicy(PreTrainedPolicy):
 
         # Check if video rewinding already set the target
         if self.training and self.config.use_video_rewind and "augmented_target" in locals():
-            # Use the augmented target from video rewinding
-            target = augmented_target
+            # Use the augmented target from video rewinding and align with temporal subsampling
+            target = augmented_target[:, idx]
         else:
             # Calculate true episode progress using episode_index and frame_index from batch
             if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
@@ -523,20 +521,22 @@ class RLearNPolicy(PreTrainedPolicy):
            shuffled_indices = torch.randperm(B, device=device)
            shuffled_commands = [commands[i] for i in shuffled_indices]
 
-            # Re-encode with mismatched language
-            lang_embeds_mm = self.text_encoder.encode(
+            # Re-encode with mismatched language (compute lengths before padding)
+            lang_embeds_mm_list = self.text_encoder.encode(
                 shuffled_commands,
                 output_value='token_embeddings',
-                convert_to_tensor=True,
-                device=device
+                convert_to_tensor=False,
             )
-            lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
+            lens_mm = torch.tensor([le.shape[0] for le in lang_embeds_mm_list], device=device)
+            lang_embeds_mm = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_mm_list], batch_first=True)
             lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
 
             # Pack and forward
-            tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
-            attended_mm = self.decoder(tokens_mm, mask=mask)
-            _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
+            tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
+            mask_mm = self._mask_from_lens(lens_mm)
+            mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
+            attended_mm = self.decoder(tokens_mm, mask=mask_mm)
+            _, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
             mismatch_embeds = self.mlp_predictor(attended_video_mm)
 
             # Mismatched pairs should predict zero progress
@@ -555,9 +555,11 @@ class RLearNPolicy(PreTrainedPolicy):
 
         # Log individual loss components
         loss_dict.update({
-            "loss": total_loss.item(),
-            "loss_main": loss.item(),
-            "loss_mismatch": L_mismatch.item(),
+            "loss": float(total_loss.detach().item()),
+            "loss_main": float(loss.detach().item()),
+            "loss_mismatch": float(L_mismatch.detach().item()),
+            "t_eff": float(T_eff),
+            "lang_len_mean": float(lens.float().mean().item()),
         })
 
         return total_loss, loss_dict
@@ -585,6 +587,7 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
        OBS_IMAGES,  # 'observation.images'
        OBS_IMAGE,  # 'observation.image'
        "observation.images.image",  # nested format from some datasets
+        "observation.images.front",
    ]
 
    frames = None
diff --git a/src/lerobot/policies/rlearn/rlearn_plan.md b/src/lerobot/policies/rlearn/rlearn_plan.md
index b9757c8ab..e35d3b7b8 100644
--- a/src/lerobot/policies/rlearn/rlearn_plan.md
+++ b/src/lerobot/policies/rlearn/rlearn_plan.md
@@ -79,8 +79,9 @@ _ Open X-Embodiment (OXE)
 - Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
 - Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
 - Test rewind (evaluate) [x]
+- Overfit on one episode []
 - Cleanup code? []
-- benchmark lucidrains vs this implementation forward pass, debug speed []
+- benchmark siglip 2 vs this implementation forward pass, debug speed []
 - Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
 - Then on 10 percent
 - Ablation dino v2 vs dino v3 base 86 M