small impr
@@ -76,7 +76,6 @@ Notes
from __future__ import annotations

import math
import time
from itertools import chain

import torch
@@ -301,12 +300,10 @@ class RLearNPolicy(PreTrainedPolicy):
        Returns:
            (B, T, D_vision)
        """
        start_time = time.time()
        B, T, C, H, W = frames.shape
        flat = rearrange(frames, 'b t c h w -> (b t) c h w')

        # Process with DINOv2
        preprocess_start = time.time()
        images_list = []
        for i in range(B * T):
            img = flat[i].permute(1, 2, 0) # CHW -> HWC
@@ -315,29 +312,14 @@ class RLearNPolicy(PreTrainedPolicy):
            else:
                img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
            images_list.append(img)
        preprocess_time = time.time() - preprocess_start

        processor_start = time.time()
        processed = self.vision_processor(images=images_list, return_tensors="pt")
        pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
        processor_time = time.time() - processor_start

        encoder_start = time.time()
        vision_outputs = self.vision_encoder(pixel_values)
        encoder_time = time.time() - encoder_start

        # Extract CLS tokens
        cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
        result = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)

        total_time = time.time() - start_time
        print(f"🎬 Video encoding timing (B={B}, T={T}):")
        print(f" - Preprocess: {preprocess_time:.3f}s")
        print(f" - Processor: {processor_time:.3f}s")
        print(f" - DINOv2: {encoder_time:.3f}s")
        print(f" - Total: {total_time:.3f}s")

        return result
        return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)

    def _mask_from_lens(self, lens: Tensor) -> Tensor:
        """Create mask from sequence lengths."""
@@ -354,13 +336,10 @@ class RLearNPolicy(PreTrainedPolicy):
        Note: Progress labels (0 to 1) are generated automatically for each episode.
        No REWARD key is needed in the batch.
        """
        forward_start = time.time()

        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)

        # Extract frames and form (B, T, C, H, W)
        data_prep_start = time.time()
        frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
        B, T, C, H, W = frames.shape
        device = next(self.parameters()).device
@@ -391,36 +370,22 @@ class RLearNPolicy(PreTrainedPolicy):
            commands = [""] * B
        elif not isinstance(commands, list):
            commands = [str(commands)] * B
        data_prep_time = time.time() - data_prep_start

        # Process video frames through DINOv2
        video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) - timing inside
        video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)

        # Language embeddings
        lang_start = time.time()
        print(f"🔍 Text encoder device: {next(self.text_encoder.parameters()).device if hasattr(self.text_encoder, 'parameters') else 'Unknown'}")
        print(f"🔍 Target device: {device}")
        print(f"🔍 Commands: {len(commands)} items, first: '{commands[0][:50]}...'")

        lang_embeds = self.text_encoder.encode(
            commands,
            output_value='token_embeddings',
            convert_to_tensor=True,
            device=device
        )
        encode_time = time.time() - lang_start

        pad_start = time.time()
        lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
        lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
        mask = self._mask_from_lens(lens)
        pad_time = time.time() - pad_start

        lang_time = time.time() - lang_start
        print(f"🗣️ Language breakdown: encode={encode_time:.3f}s, pad={pad_time:.3f}s, total={lang_time:.3f}s")

        # Token preparation
        token_prep_start = time.time()
        # Register tokens
        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)

@@ -438,20 +403,15 @@ class RLearNPolicy(PreTrainedPolicy):

        # Extend mask for register and video tokens
        mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
        token_prep_time = time.time() - token_prep_start

        # Forward through x_transformers Decoder
        transformer_start = time.time()
        attended = self.decoder(tokens, mask=mask)
        transformer_time = time.time() - transformer_start

        # Unpack and get video token features
        unpack_start = time.time()
        _, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')

        # MLP predictor
        video_frame_embeds = self.mlp_predictor(attended_video_tokens)
        unpack_time = time.time() - unpack_start

        # Generate progress labels on-the-fly (ReWiND approach)
        # IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
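The two comments above carry the key subtlety: a window of T_eff frames sampled from the middle of an episode should receive labels that reflect its position in the whole episode, not a fresh 0-to-1 ramp over the window. A minimal sketch of that labeling, with an illustrative helper name that is not the repo's API:

import torch

def episode_progress_labels(frame_indices: torch.Tensor, episode_length: int) -> torch.Tensor:
    """Map absolute frame indices within an episode to progress targets in [0, 1]."""
    return frame_indices.float() / max(episode_length - 1, 1)

# A 10-frame window starting at frame 40 of a 100-frame episode:
episode_progress_labels(torch.arange(40, 50), episode_length=100)
# -> tensor([0.4040, 0.4141, ..., 0.4949]), not a 0.0-to-1.0 ramp over the window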
@@ -540,7 +500,6 @@ class RLearNPolicy(PreTrainedPolicy):
            return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}

        # Calculate loss using HLGauss or categorical
        loss_start = time.time()
        if self.categorical_rewards:
            # Categorical cross-entropy loss
            assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
@@ -555,7 +514,6 @@ class RLearNPolicy(PreTrainedPolicy):
            # Create video mask for variable length support
            video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
            loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
        loss_time = time.time() - loss_start

        # Optional: Mismatched video-language pairs loss
        L_mismatch = torch.zeros((), device=device)
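For context on the hl_gauss_layer call above: HL-Gauss turns scalar regression into classification by spreading each target over reward bins with a Gaussian and training with cross-entropy against those soft targets. A minimal sketch of the idea, not the repo's hl_gauss_layer implementation (bin count and sigma are illustrative):

import torch
import torch.nn.functional as F

def hl_gauss_targets(values: torch.Tensor, bin_edges: torch.Tensor, sigma: float) -> torch.Tensor:
    """Soft categorical targets: probability mass a Gaussian centered at each value puts in each bin."""
    cdf = 0.5 * (1 + torch.erf((bin_edges[None, :] - values[:, None]) / (sigma * 2 ** 0.5)))
    return cdf[:, 1:] - cdf[:, :-1]  # (N, num_bins), rows sum to ~1

edges = torch.linspace(0.0, 1.0, 21)                        # 20 bins over progress in [0, 1]
soft_targets = hl_gauss_targets(torch.tensor([0.1, 0.5, 0.9]), edges, sigma=0.05)
loss = F.cross_entropy(torch.randn(3, 20), soft_targets)    # logits vs soft targets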
@@ -594,29 +552,12 @@ class RLearNPolicy(PreTrainedPolicy):

        # Total loss
        total_loss = loss + L_mismatch

        # Calculate and print timing summary
        total_forward_time = time.time() - forward_start

        print(f"\n⏱️ RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):")
        print(f" 📊 Data prep: {data_prep_time:.3f}s ({data_prep_time/total_forward_time*100:.1f}%)")
        print(f" 🗣️ Language: {lang_time:.3f}s ({lang_time/total_forward_time*100:.1f}%)")
        print(f" 🔧 Token prep: {token_prep_time:.3f}s ({token_prep_time/total_forward_time*100:.1f}%)")
        print(f" 🤖 Transformer: {transformer_time:.3f}s ({transformer_time/total_forward_time*100:.1f}%)")
        print(f" 📦 Unpack+MLP: {unpack_time:.3f}s ({unpack_time/total_forward_time*100:.1f}%)")
        print(f" 🎯 Loss calc: {loss_time:.3f}s ({loss_time/total_forward_time*100:.1f}%)")
        print(f" 🏁 Total: {total_forward_time:.3f}s")

        # Log individual loss components
        loss_dict.update({
            "loss": total_loss.item(),
            "loss_main": loss.item(),
            "loss_mismatch": L_mismatch.item(),
            # Add timing metrics to loss dict for logging
            "timing/total_forward": total_forward_time,
            "timing/data_prep": data_prep_time,
            "timing/language": lang_time,
            "timing/transformer": transformer_time,
        })

        return total_loss, loss_dict
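For reference, a minimal standalone sketch of the frame-encoding path that the _encode_video_frames hunks above modify: flatten (B, T, C, H, W) frames, run them through a DINOv2 backbone, and keep the per-frame CLS token. The facebook/dinov2-base checkpoint is an assumption (the notes below mention the ~86 M Base encoder); the policy itself does this through self.vision_processor and self.vision_encoder.

import torch
from einops import rearrange
from transformers import AutoImageProcessor, AutoModel

# Assumed checkpoint; the notes below point at the DINOv2 Base (~86 M) encoder.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
encoder = AutoModel.from_pretrained("facebook/dinov2-base")

def encode_video_frames(frames: torch.Tensor) -> torch.Tensor:
    """frames: (B, T, C, H, W) floats in [0, 1] -> (B, T, D_vision) per-frame CLS embeddings."""
    B, T = frames.shape[:2]
    flat = rearrange(frames, 'b t c h w -> (b t) c h w')
    # The image processor expects HWC uint8 images
    images = [(img.permute(1, 2, 0).clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() for img in flat]
    pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        outputs = encoder(pixel_values)
    cls_tokens = outputs.last_hidden_state[:, 0]  # (B*T, D_vision)
    return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)

encode_video_frames(torch.rand(2, 4, 3, 224, 224)).shape  # torch.Size([2, 4, 768])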
@@ -77,6 +77,7 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
- Only rewind loss [x]
- Matches https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 exactly [x]
- Try DINOv2 Base (86 M) as the vision encoder, with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 as the language encoder [x]
- Benchmark the lucidrains implementation vs this one on forward-pass time []
- Test rewind (evaluate) []
- Clean up code? []
- Convert the dataset with python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent of it
@@ -88,5 +89,5 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
- Multiple captions per video; create a method to generate as much data as possible, etc. [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
- How can we improve spatially aware learning? Solve the issue of contrastive learning and position
- Extend evaluation []
- Add other datasets mentioned above []
- Add other datasets from OXE mentioned in ReWiND []
- Ablation on the size of the vision encoder, language encoder, and temporal head
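Related to the "benchmark lucidrains vs this implementation forward pass" item above, a minimal timing-harness sketch (the module and inputs are placeholders, not the repo's API):

import time
import torch

def benchmark_forward(module: torch.nn.Module, inputs: dict, n_warmup: int = 3, n_iters: int = 10) -> float:
    """Average forward-pass time in seconds, with CUDA synchronization so timings are not skewed."""
    module.eval()
    with torch.no_grad():
        for _ in range(n_warmup):
            module(**inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            module(**inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters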