mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
fix
This commit is contained in:
@@ -92,7 +92,7 @@ try:
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ReWiND dependencies not installed. Please install: "
|
||||
"pip install x-transformers hl-gauss-pytorch einx einops x_mlps_pytorc"
|
||||
"pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch"
|
||||
) from e
|
||||
|
||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||
@@ -237,17 +237,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
frames = frames.to(device)
|
||||
|
||||
# Process video frames
|
||||
video_embeds = self._encode_video_frames(frames) # (B, T, D_vision)
|
||||
video_embeds = self._encode_video_frames(frames).to(device) # (B, T, D_vision)
|
||||
|
||||
# Language embeddings
|
||||
lang_embeds = self.text_encoder.encode(
|
||||
# Language embeddings (get lengths BEFORE padding)
|
||||
lang_embeds_list = self.text_encoder.encode(
|
||||
commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
convert_to_tensor=False,
|
||||
)
|
||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
|
||||
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
|
||||
mask = self._mask_from_lens(lens)
|
||||
|
||||
# Register tokens
|
||||
@@ -372,17 +371,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
commands = [str(commands)] * B
|
||||
|
||||
# Process video frames through DINOv2
|
||||
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
|
||||
video_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, D_vision)
|
||||
|
||||
# Language embeddings
|
||||
lang_embeds = self.text_encoder.encode(
|
||||
# Language embeddings (get lengths BEFORE padding)
|
||||
lang_embeds_list = self.text_encoder.encode(
|
||||
commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
convert_to_tensor=False,
|
||||
)
|
||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
||||
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
|
||||
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
|
||||
mask = self._mask_from_lens(lens)
|
||||
|
||||
# Token preparation
|
||||
@@ -419,8 +417,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Check if video rewinding already set the target
|
||||
if self.training and self.config.use_video_rewind and "augmented_target" in locals():
|
||||
# Use the augmented target from video rewinding
|
||||
target = augmented_target
|
||||
# Use the augmented target from video rewinding and align with temporal subsampling
|
||||
target = augmented_target[:, idx]
|
||||
else:
|
||||
# Calculate true episode progress using episode_index and frame_index from batch
|
||||
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
|
||||
@@ -523,20 +521,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
shuffled_indices = torch.randperm(B, device=device)
|
||||
shuffled_commands = [commands[i] for i in shuffled_indices]
|
||||
|
||||
# Re-encode with mismatched language
|
||||
lang_embeds_mm = self.text_encoder.encode(
|
||||
# Re-encode with mismatched language (compute lengths before padding)
|
||||
lang_embeds_mm_list = self.text_encoder.encode(
|
||||
shuffled_commands,
|
||||
output_value='token_embeddings',
|
||||
convert_to_tensor=True,
|
||||
device=device
|
||||
convert_to_tensor=False,
|
||||
)
|
||||
lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
|
||||
lens_mm = torch.tensor([le.shape[0] for le in lang_embeds_mm_list], device=device)
|
||||
lang_embeds_mm = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_mm_list], batch_first=True)
|
||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||
|
||||
# Pack and forward
|
||||
tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
|
||||
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||
mask_mm = self._mask_from_lens(lens_mm)
|
||||
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
||||
|
||||
# Mismatched pairs should predict zero progress
|
||||
@@ -555,9 +555,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Log individual loss components
|
||||
loss_dict.update({
|
||||
"loss": total_loss.item(),
|
||||
"loss_main": loss.item(),
|
||||
"loss_mismatch": L_mismatch.item(),
|
||||
"loss": float(total_loss.detach().item()),
|
||||
"loss_main": float(loss.detach().item()),
|
||||
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||
"t_eff": float(T_eff),
|
||||
"lang_len_mean": float(lens.float().mean().item()),
|
||||
})
|
||||
|
||||
return total_loss, loss_dict
|
||||
@@ -585,6 +587,7 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
||||
OBS_IMAGES, # 'observation.images'
|
||||
OBS_IMAGE, # 'observation.image'
|
||||
"observation.images.image", # nested format from some datasets
|
||||
"observation.images.front",
|
||||
]
|
||||
|
||||
frames = None
|
||||
|
||||
@@ -79,8 +79,9 @@ _ Open X-Embodiment (OXE)
|
||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||
- Try DINOv2 Base (86M) as the vision encoder, together with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 for text [x]
|
||||
- Test rewind (evaluate) [x]
|
||||
- Overfit on one episode []
|
||||
- Cleanup code? []
|
||||
- benchmark lucidrains vs this implementation forward pass, debug speed []
|
||||
- benchmark siglip 2 vs this implementation forward pass, debug speed []
|
||||
- Convert the dataset with `python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot`, then train on 1 percent of it
|
||||
- Then on 10 percent
|
||||
- Ablation: DINOv2 vs DINOv3, both Base (86M)
|
||||
|
||||
Reference in New Issue
Block a user