mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
small impr
This commit is contained in:
@@ -76,7 +76,6 @@ Notes
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -301,12 +300,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
(B, T, D_vision)
|
(B, T, D_vision)
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
|
||||||
B, T, C, H, W = frames.shape
|
B, T, C, H, W = frames.shape
|
||||||
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
|
flat = rearrange(frames, 'b t c h w -> (b t) c h w')
|
||||||
|
|
||||||
# Process with DINOv2
|
# Process with DINOv2
|
||||||
preprocess_start = time.time()
|
|
||||||
images_list = []
|
images_list = []
|
||||||
for i in range(B * T):
|
for i in range(B * T):
|
||||||
img = flat[i].permute(1, 2, 0) # CHW -> HWC
|
img = flat[i].permute(1, 2, 0) # CHW -> HWC
|
||||||
@@ -315,29 +312,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
img = (img.clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
|
||||||
images_list.append(img)
|
images_list.append(img)
|
||||||
preprocess_time = time.time() - preprocess_start
|
|
||||||
|
|
||||||
processor_start = time.time()
|
|
||||||
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
processed = self.vision_processor(images=images_list, return_tensors="pt")
|
||||||
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
pixel_values = processed["pixel_values"].to(next(self.vision_encoder.parameters()).device)
|
||||||
processor_time = time.time() - processor_start
|
|
||||||
|
|
||||||
encoder_start = time.time()
|
|
||||||
vision_outputs = self.vision_encoder(pixel_values)
|
vision_outputs = self.vision_encoder(pixel_values)
|
||||||
encoder_time = time.time() - encoder_start
|
|
||||||
|
|
||||||
# Extract CLS tokens
|
# Extract CLS tokens
|
||||||
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
cls_tokens = vision_outputs.last_hidden_state[:, 0] # (BT, D_vision)
|
||||||
result = rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
|
return rearrange(cls_tokens, '(b t) d -> b t d', b=B, t=T)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
print(f"🎬 Video encoding timing (B={B}, T={T}):")
|
|
||||||
print(f" - Preprocess: {preprocess_time:.3f}s")
|
|
||||||
print(f" - Processor: {processor_time:.3f}s")
|
|
||||||
print(f" - DINOv2: {encoder_time:.3f}s")
|
|
||||||
print(f" - Total: {total_time:.3f}s")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _mask_from_lens(self, lens: Tensor) -> Tensor:
|
def _mask_from_lens(self, lens: Tensor) -> Tensor:
|
||||||
"""Create mask from sequence lengths."""
|
"""Create mask from sequence lengths."""
|
||||||
@@ -354,13 +336,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
Note: Progress labels (0 to 1) are generated automatically for each episode.
|
Note: Progress labels (0 to 1) are generated automatically for each episode.
|
||||||
No REWARD key is needed in the batch.
|
No REWARD key is needed in the batch.
|
||||||
"""
|
"""
|
||||||
forward_start = time.time()
|
|
||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch = self.normalize_targets(batch)
|
batch = self.normalize_targets(batch)
|
||||||
|
|
||||||
# Extract frames and form (B, T, C, H, W)
|
# Extract frames and form (B, T, C, H, W)
|
||||||
data_prep_start = time.time()
|
|
||||||
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
|
||||||
B, T, C, H, W = frames.shape
|
B, T, C, H, W = frames.shape
|
||||||
device = next(self.parameters()).device
|
device = next(self.parameters()).device
|
||||||
@@ -391,36 +370,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
commands = [""] * B
|
commands = [""] * B
|
||||||
elif not isinstance(commands, list):
|
elif not isinstance(commands, list):
|
||||||
commands = [str(commands)] * B
|
commands = [str(commands)] * B
|
||||||
data_prep_time = time.time() - data_prep_start
|
|
||||||
|
|
||||||
# Process video frames through DINOv2
|
# Process video frames through DINOv2
|
||||||
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision) - timing inside
|
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
|
||||||
|
|
||||||
# Language embeddings
|
# Language embeddings
|
||||||
lang_start = time.time()
|
|
||||||
print(f"🔍 Text encoder device: {next(self.text_encoder.parameters()).device if hasattr(self.text_encoder, 'parameters') else 'Unknown'}")
|
|
||||||
print(f"🔍 Target device: {device}")
|
|
||||||
print(f"🔍 Commands: {len(commands)} items, first: '{commands[0][:50]}...'")
|
|
||||||
|
|
||||||
lang_embeds = self.text_encoder.encode(
|
lang_embeds = self.text_encoder.encode(
|
||||||
commands,
|
commands,
|
||||||
output_value='token_embeddings',
|
output_value='token_embeddings',
|
||||||
convert_to_tensor=True,
|
convert_to_tensor=True,
|
||||||
device=device
|
device=device
|
||||||
)
|
)
|
||||||
encode_time = time.time() - lang_start
|
|
||||||
|
|
||||||
pad_start = time.time()
|
|
||||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
||||||
mask = self._mask_from_lens(lens)
|
mask = self._mask_from_lens(lens)
|
||||||
pad_time = time.time() - pad_start
|
|
||||||
|
|
||||||
lang_time = time.time() - lang_start
|
|
||||||
print(f"🗣️ Language breakdown: encode={encode_time:.3f}s, pad={pad_time:.3f}s, total={lang_time:.3f}s")
|
|
||||||
|
|
||||||
# Token preparation
|
# Token preparation
|
||||||
token_prep_start = time.time()
|
|
||||||
# Register tokens
|
# Register tokens
|
||||||
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
register_tokens = repeat(self.register_tokens, 'n d -> b n d', b=B)
|
||||||
|
|
||||||
@@ -438,20 +403,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Extend mask for register and video tokens
|
# Extend mask for register and video tokens
|
||||||
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
mask = F.pad(mask, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||||
token_prep_time = time.time() - token_prep_start
|
|
||||||
|
|
||||||
# Forward through x_transformers Decoder
|
# Forward through x_transformers Decoder
|
||||||
transformer_start = time.time()
|
|
||||||
attended = self.decoder(tokens, mask=mask)
|
attended = self.decoder(tokens, mask=mask)
|
||||||
transformer_time = time.time() - transformer_start
|
|
||||||
|
|
||||||
# Unpack and get video token features
|
# Unpack and get video token features
|
||||||
unpack_start = time.time()
|
|
||||||
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
|
_, _, attended_video_tokens = unpack(attended, lang_video_packed_shape, 'b * d')
|
||||||
|
|
||||||
# MLP predictor
|
# MLP predictor
|
||||||
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
||||||
unpack_time = time.time() - unpack_start
|
|
||||||
|
|
||||||
# Generate progress labels on-the-fly (ReWiND approach)
|
# Generate progress labels on-the-fly (ReWiND approach)
|
||||||
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
# IMPORTANT: Progress should be 0-1 across the ENTIRE EPISODE, not just the temporal window
|
||||||
@@ -540,7 +500,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||||
|
|
||||||
# Calculate loss using HLGauss or categorical
|
# Calculate loss using HLGauss or categorical
|
||||||
loss_start = time.time()
|
|
||||||
if self.categorical_rewards:
|
if self.categorical_rewards:
|
||||||
# Categorical cross-entropy loss
|
# Categorical cross-entropy loss
|
||||||
assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
|
assert target.dtype in (torch.long, torch.int), "Categorical rewards require integer targets"
|
||||||
@@ -555,7 +514,6 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# Create video mask for variable length support
|
# Create video mask for variable length support
|
||||||
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
|
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
|
||||||
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
|
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
|
||||||
loss_time = time.time() - loss_start
|
|
||||||
|
|
||||||
# Optional: Mismatched video-language pairs loss
|
# Optional: Mismatched video-language pairs loss
|
||||||
L_mismatch = torch.zeros((), device=device)
|
L_mismatch = torch.zeros((), device=device)
|
||||||
@@ -594,29 +552,12 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Total loss
|
# Total loss
|
||||||
total_loss = loss + L_mismatch
|
total_loss = loss + L_mismatch
|
||||||
|
|
||||||
# Calculate and print timing summary
|
|
||||||
total_forward_time = time.time() - forward_start
|
|
||||||
|
|
||||||
print(f"\n⏱️ RLearN Forward Pass Timing (B={B}, T_eff={T_eff}):")
|
|
||||||
print(f" 📊 Data prep: {data_prep_time:.3f}s ({data_prep_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 🗣️ Language: {lang_time:.3f}s ({lang_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 🔧 Token prep: {token_prep_time:.3f}s ({token_prep_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 🤖 Transformer: {transformer_time:.3f}s ({transformer_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 📦 Unpack+MLP: {unpack_time:.3f}s ({unpack_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 🎯 Loss calc: {loss_time:.3f}s ({loss_time/total_forward_time*100:.1f}%)")
|
|
||||||
print(f" 🏁 Total: {total_forward_time:.3f}s")
|
|
||||||
|
|
||||||
# Log individual loss components
|
# Log individual loss components
|
||||||
loss_dict.update({
|
loss_dict.update({
|
||||||
"loss": total_loss.item(),
|
"loss": total_loss.item(),
|
||||||
"loss_main": loss.item(),
|
"loss_main": loss.item(),
|
||||||
"loss_mismatch": L_mismatch.item(),
|
"loss_mismatch": L_mismatch.item(),
|
||||||
# Add timing metrics to loss dict for logging
|
|
||||||
"timing/total_forward": total_forward_time,
|
|
||||||
"timing/data_prep": data_prep_time,
|
|
||||||
"timing/language": lang_time,
|
|
||||||
"timing/transformer": transformer_time,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
return total_loss, loss_dict
|
return total_loss, loss_dict
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
|||||||
- Only rewind loss [x]
|
- Only rewind loss [x]
|
||||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||||
|
- benchmark lucidrains vs this implementation forward pass []
|
||||||
- Test rewind (evaluate) []
|
- Test rewind (evaluate) []
|
||||||
- Cleanup code? []
|
- Cleanup code? []
|
||||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||||
@@ -88,5 +89,5 @@ _ HOWTO100M: https://www.di.ens.fr/willow/research/howto100m/
|
|||||||
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446, https://arxiv.org/pdf/2412.04453
|
||||||
- How can we improve spatial aware learning? solve issue of Contrastive learning and position
|
- How can we improve spatial aware learning? solve issue of Contrastive learning and position
|
||||||
- Extend evaluation []
|
- Extend evaluation []
|
||||||
- Add other datasets mentioned above []
|
- Add other datasets from OXE metioned in rewind []
|
||||||
- Ablation for size vision encoder, language encoder, temporal head
|
- Ablation for size vision encoder, language encoder, temporal head
|
||||||
|
|||||||
Reference in New Issue
Block a user