small impr

This commit is contained in:
Pepijn
2025-08-29 09:05:53 +02:00
parent 04d55e4670
commit 9698e74e88
2 changed files with 4 additions and 62 deletions
+2 -61
View File
@@ -76,7 +76,6 @@ Notes
from __future__ import annotations
import math
import time
from itertools import chain
import torch
@@ -301,12 +300,10 @@ class RLearNPolicy(PreTrainedPolicy):
Returns:
(B, T, D_vision)
"""
start_time = time.time()
B, T, C, H, W = frames.shape
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
# Process with DINOv2
preprocess_start = time.time()
images_list = []
for i in range(B * T):
img = flat[i].permute(1, 2, 0) # CHW -> HWC
@@ -315,29 +312,14 @@ class RLearNPolicy(PreTrainedPolicy):
else:
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
images_list.append(img)
preprocess_time = time.time() - preprocess_start
processor_start = time.time()
processed = self.vision_processor(images=images_list, return_tensors="pt")
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
processor_time = time.time() - processor_start
encoder_start = time.time()
vision_outputs = self.vision_encoder(pixel_values)
encoder_time = time.time() - encoder_start
# Extract CLS tokens
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
result = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
total_time = time.time() - start_time
print(f"🎬 Video encoding timing (B={B}, T={T}):")
print(f" - Preprocess: {preprocess_time:.3f}s")
print(f" - Processor: {processor_time:.3f}s")
print(f" - DINOv2: {encoder_time:.3f}s")
print(f" - Total: {total_time:.3f}s")
return result
return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
def _mask_from_lens(self, lens: Tensor) -> Tensor:
"""Create mask from sequence lengths."""
@@ -354,13 +336,10 @@ class RLearNPolicy(PreTrainedPolicy):
Note: Progress labels (0 to 1) are generated automatically for each episode.
No REWARD key is needed in the batch.
"""
forward_start = time.time()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# Extract frames and form (B, T, C, H, W)
data_prep_start = time.time()
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape
device = next(self.parameters()).device
@@ -391,36 +370,22 @@ class RLearNPolicy(PreTrainedPolicy):
commands = [""] * B
elif not isinstance(commands, list):
commands = [str(commands)] * B
data_prep_time = time.time() - data_prep_start
# Process video frames through DINOv2
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) - timing inside
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
# Language embeddings
lang_start = time.time()
print(f"🔍 Text encoder device: {next(self.text_encoder.parameters()).device if hasattr(self.text_encoder, 'parameters') else 'Unknown'}")
print(f"🔍 Target device: {device}")
print(f"🔍 Commands: {len(commands)} items, first: '{commands[0][:50]}...'")
lang_embeds = self.text_encoder.encode(
commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
)
encode_time = time.time() - lang_start
pad_start = time.time()
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
mask = self._mask_from_lens(lens)
pad_time = time.time() - pad_start
lang_time = time.time() - lang_start
print(f"🗣️ Language breakdown: encode={encode_time:.3f}s, pad={pad_time:.3f}s, total={lang_time:.3f}s")
# Token preparation
token_prep_start = time.time()
# Register tokens
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
@@ -438,20 +403,15 @@ class RLearNPolicy(PreTrainedPolicy):
# Extend mask for register and video tokens
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
token_prep_time = time.time() - token_prep_start
# Forward through x_transformers Decoder
transformer_start = time.time()
attended = self.decoder(tokens, mask=mask)
transformer_time = time.time() - transformer_start
# Unpack and get video token features
unpack_start = time.time()
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
# MLP predictor
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
unpack_time = time.time() - unpack_start
# Generate progress labels on-the-fly (ReWiND approach)
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
@@ -540,7 +500,6 @@ class RLearNPolicy(PreTrainedPolicy):
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
# Calculate loss using HLGauss or categorical
loss_start = time.time()
if self.categorical_rewards:
# Categorical cross-entropy loss
assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
@@ -555,7 +514,6 @@ class RLearNPolicy(PreTrainedPolicy):
# Create video mask for variable length support
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
loss_time = time.time() - loss_start
# Optional: Mismatched video-language pairs loss
L_mismatch = torch.zeros((), device=device)
@@ -594,29 +552,12 @@ class RLearNPolicy(PreTrainedPolicy):
# Total loss
total_loss = loss + L_mismatch
# Calculate and print timing summary
total_forward_time = time.time() - forward_start
print(f"\n⏱️ RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):")
print(f" 📊 Data prep: {data_prep_time:.3f}s ({data_prep_time/total_forward_time*100:.1f}%)")
print(f" 🗣️ Language: {lang_time:.3f}s ({lang_time/total_forward_time*100:.1f}%)")
print(f" 🔧 Token prep: {token_prep_time:.3f}s ({token_prep_time/total_forward_time*100:.1f}%)")
print(f" 🤖 Transformer: {transformer_time:.3f}s ({transformer_time/total_forward_time*100:.1f}%)")
print(f" 📦 Unpack+MLP: {unpack_time:.3f}s ({unpack_time/total_forward_time*100:.1f}%)")
print(f" 🎯 Loss calc: {loss_time:.3f}s ({loss_time/total_forward_time*100:.1f}%)")
print(f" 🏁 Total: {total_forward_time:.3f}s")
# Log individual loss components
loss_dict.update({
"loss": total_loss.item(),
"loss_main": loss.item(),
"loss_mismatch": L_mismatch.item(),
# Add timing metrics to loss dict for logging
"timing/total_forward": total_forward_time,
"timing/data_prep": data_prep_time,
"timing/language": lang_time,
"timing/transformer": transformer_time,
})
return total_loss, loss_dict
+2 -1
View File
@@ -77,6 +77,7 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
- Only rewind loss [x]
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
- benchmark lucidrains vs this implementation forward pass []
- Test rewind (evaluate) []
- Cleanup code? []
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
@@ -88,5 +89,5 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
- How can we improve spatial aware learning? solve issue of Contrastive learning and position
- Extend evaluation []
- Add other datasets mentioned above []
- Add other datasets from OXE metioned in rewind []
- Ablation for size vision encoder, language encoder, temporal head