diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 29bafbdcb..1644637cd 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -76,7 +76,6 @@ Notes from __future__ import annotations import math -import time from itertools import chain import torch @@ -301,12 +300,10 @@ class RLearNPolicy(PreTrainedPolicy): Returns: (B, T, D_vision) """ - start_time = time.time() B, T, C, H, W = frames.shape flat = rearrange(frames, 'b t c h w -> (b t) c h w') # Process with DINOv2 - preprocess_start = time.time() images_list = [] for i in range(B * T): img = flat[i].permute(1, 2, 0) # CHW -> HWC @@ -315,29 +312,14 @@ class RLearNPolicy(PreTrainedPolicy): else: img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() images_list.append(img) - preprocess_time = time.time() - preprocess_start - processor_start = time.time() processed = self.vision_processor(images=images_list, return_tensors="pt") pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device) - processor_time = time.time() - processor_start - - encoder_start = time.time() vision_outputs = self.vision_encoder(pixel_values) - encoder_time = time.time() - encoder_start # Extract CLS tokens cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision) - result = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T) - - total_time = time.time() - start_time - print(f"šŸŽ¬ Video encoding timing (B={B}, T={T}):") - print(f" - Preprocess: {preprocess_time:.3f}s") - print(f" - Processor: {processor_time:.3f}s") - print(f" - DINOv2: {encoder_time:.3f}s") - print(f" - Total: {total_time:.3f}s") - - return result + return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T) def _mask_from_lens(self, lens: Tensor) -> Tensor: """Create mask from sequence lengths.""" @@ -354,13 +336,10 @@ class RLearNPolicy(PreTrainedPolicy): Note: Progress labels (0 to 1) are generated automatically for each episode. No REWARD key is needed in the batch. """ - forward_start = time.time() - batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) # Extract frames and form (B, T, C, H, W) - data_prep_start = time.time() frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len) B, T, C, H, W = frames.shape device = next(self.parameters()).device @@ -391,36 +370,22 @@ class RLearNPolicy(PreTrainedPolicy): commands = [""] * B elif not isinstance(commands, list): commands = [str(commands)] * B - data_prep_time = time.time() - data_prep_start # Process video frames through DINOv2 - video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) - timing inside + video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) # Language embeddings - lang_start = time.time() - print(f"šŸ” Text encoder device: {next(self.text_encoder.parameters()).device if hasattr(self.text_encoder, 'parameters') else 'Unknown'}") - print(f"šŸ” Target device: {device}") - print(f"šŸ” Commands: {len(commands)} items, first: '{commands[0][:50]}...'") - lang_embeds = self.text_encoder.encode( commands, output_value='token_embeddings', convert_to_tensor=True, device=device ) - encode_time = time.time() - lang_start - - pad_start = time.time() lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device) lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device) mask = self._mask_from_lens(lens) - pad_time = time.time() - pad_start - - lang_time = time.time() - lang_start - print(f"šŸ—£ļø Language breakdown: encode={encode_time:.3f}s, pad={pad_time:.3f}s, total={lang_time:.3f}s") # Token preparation - token_prep_start = time.time() # Register tokens register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B) @@ -438,20 +403,15 @@ class RLearNPolicy(PreTrainedPolicy): # Extend mask for register and video tokens mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True) - token_prep_time = time.time() - token_prep_start # Forward through x_transformers Decoder - transformer_start = time.time() attended = self.decoder(tokens, mask=mask) - transformer_time = time.time() - transformer_start # Unpack and get video token features - unpack_start = time.time() _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d') # MLP predictor video_frame_embeds = self.mlp_predictor(attended_video_tokens) - unpack_time = time.time() - unpack_start # Generate progress labels on-the-fly (ReWiND approach) # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window @@ -540,7 +500,6 @@ class RLearNPolicy(PreTrainedPolicy): return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} # Calculate loss using HLGauss or categorical - loss_start = time.time() if self.categorical_rewards: # Categorical cross-entropy loss assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets" @@ -555,7 +514,6 @@ class RLearNPolicy(PreTrainedPolicy): # Create video mask for variable length support video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device) loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask) - loss_time = time.time() - loss_start # Optional: Mismatched video-language pairs loss L_mismatch = torch.zeros((), device=device) @@ -594,29 +552,12 @@ class RLearNPolicy(PreTrainedPolicy): # Total loss total_loss = loss + L_mismatch - - # Calculate and print timing summary - total_forward_time = time.time() - forward_start - - print(f"\nā±ļø RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):") - print(f" šŸ“Š Data prep: {data_prep_time:.3f}s ({data_prep_time/total_forward_time*100:.1f}%)") - print(f" šŸ—£ļø Language: {lang_time:.3f}s ({lang_time/total_forward_time*100:.1f}%)") - print(f" šŸ”§ Token prep: {token_prep_time:.3f}s ({token_prep_time/total_forward_time*100:.1f}%)") - print(f" šŸ¤– Transformer: {transformer_time:.3f}s ({transformer_time/total_forward_time*100:.1f}%)") - print(f" šŸ“¦ Unpack+MLP: {unpack_time:.3f}s ({unpack_time/total_forward_time*100:.1f}%)") - print(f" šŸŽÆ Loss calc: {loss_time:.3f}s ({loss_time/total_forward_time*100:.1f}%)") - print(f" šŸ Total: {total_forward_time:.3f}s") # Log individual loss components loss_dict.update({ "loss": total_loss.item(), "loss_main": loss.item(), "loss_mismatch": L_mismatch.item(), - # Add timing metrics to loss dict for logging - "timing/total_forward": total_forward_time, - "timing/data_prep": data_prep_time, - "timing/language": lang_time, - "timing/transformer": transformer_time, }) return total_loss, loss_dict diff --git a/src/lerobot/policies/rlearn/rlearn_plan.md b/src/lerobot/policies/rlearn/rlearn_plan.md index ab2e806c6..1e64b32b5 100644 --- a/src/lerobot/policies/rlearn/rlearn_plan.md +++ b/src/lerobot/policies/rlearn/rlearn_plan.md @@ -77,6 +77,7 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/ - Only rewind loss [x] - Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x] - Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x] + - benchmark lucidrains vs this implementation forward pass [] - Test rewind (evaluate) [] - Cleanup code? [] - Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent @@ -88,5 +89,5 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/ - Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453 - How can we improve spatial aware learning? solve issue of Contrastive learning and position - Extend evaluation [] -- Add other datasets mentioned above [] +- Add other datasets from OXE metioned in rewind [] - Ablation for size vision encoder, language encoder, temporal head