This commit is contained in:
Pepijn
2025-08-30 11:37:16 +02:00
parent 3f616f0ebe
commit f5c39d6292
2 changed files with 33 additions and 29 deletions
+31 -28
View File
@@ -92,7 +92,7 @@ try:
except ImportError as e:
raise ImportError(
"ReWiND dependencies not installed. Please install: "
"pip install x-transformers hl-gauss-pytorch einx einops x_mlps_pytorc"
"pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch"
) from e
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
@@ -237,17 +237,16 @@ class RLearNPolicy(PreTrainedPolicy):
frames = frames.to(device)
# Process video frames
video_embeds = self._encode_video_frames(frames) # (B, T, D_vision)
video_embeds = self._encode_video_frames(frames).to(device) # (B, T, D_vision)
# Language embeddings
lang_embeds = self.text_encoder.encode(
# Language embeddings (get lengths BEFORE padding)
lang_embeds_list = self.text_encoder.encode(
commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
convert_to_tensor=False,
)
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
mask = self._mask_from_lens(lens)
# Register tokens
@@ -372,17 +371,16 @@ class RLearNPolicy(PreTrainedPolicy):
commands = [str(commands)] * B
# Process video frames through DINOv2
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
video_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, D_vision)
# Language embeddings
lang_embeds = self.text_encoder.encode(
# Language embeddings (get lengths BEFORE padding)
lang_embeds_list = self.text_encoder.encode(
commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
convert_to_tensor=False,
)
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
mask = self._mask_from_lens(lens)
# Token preparation
@@ -419,8 +417,8 @@ class RLearNPolicy(PreTrainedPolicy):
# Check if video rewinding already set the target
if self.training and self.config.use_video_rewind and "augmented_target" in locals():
# Use the augmented target from video rewinding
target = augmented_target
# Use the augmented target from video rewinding and align with temporal subsampling
target = augmented_target[:, idx]
else:
# Calculate true episode progress using episode_index and frame_index from batch
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
@@ -523,20 +521,22 @@ class RLearNPolicy(PreTrainedPolicy):
shuffled_indices = torch.randperm(B, device=device)
shuffled_commands = [commands[i] for i in shuffled_indices]
# Re-encode with mismatched language
lang_embeds_mm = self.text_encoder.encode(
# Re-encode with mismatched language (compute lengths before padding)
lang_embeds_mm_list = self.text_encoder.encode(
shuffled_commands,
output_value='token_embeddings',
convert_to_tensor=True,
device=device
convert_to_tensor=False,
)
lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
lens_mm = torch.tensor([le.shape[0] for le in lang_embeds_mm_list], device=device)
lang_embeds_mm = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_mm_list], batch_first=True)
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
# Pack and forward
tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
attended_mm = self.decoder(tokens_mm, mask=mask)
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
mask_mm = self._mask_from_lens(lens_mm)
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
mismatch_embeds = self.mlp_predictor(attended_video_mm)
# Mismatched pairs should predict zero progress
@@ -555,9 +555,11 @@ class RLearNPolicy(PreTrainedPolicy):
# Log individual loss components
loss_dict.update({
"loss": total_loss.item(),
"loss_main": loss.item(),
"loss_mismatch": L_mismatch.item(),
"loss": float(total_loss.detach().item()),
"loss_main": float(loss.detach().item()),
"loss_mismatch": float(L_mismatch.detach().item()),
"t_eff": float(T_eff),
"lang_len_mean": float(lens.float().mean().item()),
})
return total_loss, loss_dict
@@ -585,6 +587,7 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
OBS_IMAGES, # 'observation.images'
OBS_IMAGE, # 'observation.image'
"observation.images.image", # nested format from some datasets
"observation.images.front",
]
frames = None
+2 -1
View File
@@ -79,8 +79,9 @@ _ Open X-Embodiment (OXE)
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
- Try DINOv2 Base (86M) as the vision encoder, with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 as the text encoder [x]
- Test rewind (evaluate) [x]
- Overfit on one episode []
- Cleanup code? []
- benchmark lucidrains vs this implementation forward pass, debug speed []
- benchmark siglip 2 vs this implementation forward pass, debug speed []
- Convert the dataset with `python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot` and train on 1 percent of it
- Then on 10 percent
- Ablation: DINOv2 vs DINOv3 Base (86M)