mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix
This commit is contained in:
@@ -92,7 +92,7 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"ReWiND dependencies not installed. Please install: "
|
"ReWiND dependencies not installed. Please install: "
|
||||||
"pip install x-transformers hl-gauss-pytorch einx einops x_mlps_pytorc"
|
"pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||||
@@ -237,17 +237,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
frames = frames.to(device)
|
frames = frames.to(device)
|
||||||
|
|
||||||
# Process video frames
|
# Process video frames
|
||||||
video_embeds = self._encode_video_frames(frames) # (B, T, D_vision)
|
video_embeds = self._encode_video_frames(frames).to(device) # (B, T, D_vision)
|
||||||
|
|
||||||
# Language embeddings
|
# Language embeddings (get lengths BEFORE padding)
|
||||||
lang_embeds = self.text_encoder.encode(
|
lang_embeds_list = self.text_encoder.encode(
|
||||||
commands,
|
commands,
|
||||||
output_value='token_embeddings',
|
output_value='token_embeddings',
|
||||||
convert_to_tensor=True,
|
convert_to_tensor=False,
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
|
||||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
|
||||||
mask = self._mask_from_lens(lens)
|
mask = self._mask_from_lens(lens)
|
||||||
|
|
||||||
# Register tokens
|
# Register tokens
|
||||||
@@ -372,17 +371,16 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
commands = [str(commands)] * B
|
commands = [str(commands)] * B
|
||||||
|
|
||||||
# Process video frames through DINOv2
|
# Process video frames through DINOv2
|
||||||
video_embeds = self._encode_video_frames(frames) # (B, T_eff, D_vision)
|
video_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, D_vision)
|
||||||
|
|
||||||
# Language embeddings
|
# Language embeddings (get lengths BEFORE padding)
|
||||||
lang_embeds = self.text_encoder.encode(
|
lang_embeds_list = self.text_encoder.encode(
|
||||||
commands,
|
commands,
|
||||||
output_value='token_embeddings',
|
output_value='token_embeddings',
|
||||||
convert_to_tensor=True,
|
convert_to_tensor=False,
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
lang_embeds = pad_sequence(lang_embeds, batch_first=True).to(device)
|
lens = torch.tensor([le.shape[0] for le in lang_embeds_list], device=device)
|
||||||
lens = torch.tensor([le.shape[0] for le in lang_embeds], device=device)
|
lang_embeds = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_list], batch_first=True)
|
||||||
mask = self._mask_from_lens(lens)
|
mask = self._mask_from_lens(lens)
|
||||||
|
|
||||||
# Token preparation
|
# Token preparation
|
||||||
@@ -419,8 +417,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Check if video rewinding already set the target
|
# Check if video rewinding already set the target
|
||||||
if self.training and self.config.use_video_rewind and "augmented_target" in locals():
|
if self.training and self.config.use_video_rewind and "augmented_target" in locals():
|
||||||
# Use the augmented target from video rewinding
|
# Use the augmented target from video rewinding and align with temporal subsampling
|
||||||
target = augmented_target
|
target = augmented_target[:, idx]
|
||||||
else:
|
else:
|
||||||
# Calculate true episode progress using episode_index and frame_index from batch
|
# Calculate true episode progress using episode_index and frame_index from batch
|
||||||
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
|
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
|
||||||
@@ -523,20 +521,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
shuffled_indices = torch.randperm(B, device=device)
|
shuffled_indices = torch.randperm(B, device=device)
|
||||||
shuffled_commands = [commands[i] for i in shuffled_indices]
|
shuffled_commands = [commands[i] for i in shuffled_indices]
|
||||||
|
|
||||||
# Re-encode with mismatched language
|
# Re-encode with mismatched language (compute lengths before padding)
|
||||||
lang_embeds_mm = self.text_encoder.encode(
|
lang_embeds_mm_list = self.text_encoder.encode(
|
||||||
shuffled_commands,
|
shuffled_commands,
|
||||||
output_value='token_embeddings',
|
output_value='token_embeddings',
|
||||||
convert_to_tensor=True,
|
convert_to_tensor=False,
|
||||||
device=device
|
|
||||||
)
|
)
|
||||||
lang_embeds_mm = pad_sequence(lang_embeds_mm, batch_first=True).to(device)
|
lens_mm = torch.tensor([le.shape[0] for le in lang_embeds_mm_list], device=device)
|
||||||
|
lang_embeds_mm = pad_sequence([torch.as_tensor(le, device=device) for le in lang_embeds_mm_list], batch_first=True)
|
||||||
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
lang_tokens_mm = self.to_lang_tokens(lang_embeds_mm)
|
||||||
|
|
||||||
# Pack and forward
|
# Pack and forward
|
||||||
tokens_mm, _ = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
tokens_mm, lang_video_packed_shape_mm = pack((lang_tokens_mm, register_tokens, video_tokens), 'b * d')
|
||||||
attended_mm = self.decoder(tokens_mm, mask=mask)
|
mask_mm = self._mask_from_lens(lens_mm)
|
||||||
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape, 'b * d')
|
mask_mm = F.pad(mask_mm, (0, register_tokens.shape[1] + video_tokens.shape[1]), value=True)
|
||||||
|
attended_mm = self.decoder(tokens_mm, mask=mask_mm)
|
||||||
|
_, _, attended_video_mm = unpack(attended_mm, lang_video_packed_shape_mm, 'b * d')
|
||||||
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
||||||
|
|
||||||
# Mismatched pairs should predict zero progress
|
# Mismatched pairs should predict zero progress
|
||||||
@@ -555,9 +555,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Log individual loss components
|
# Log individual loss components
|
||||||
loss_dict.update({
|
loss_dict.update({
|
||||||
"loss": total_loss.item(),
|
"loss": float(total_loss.detach().item()),
|
||||||
"loss_main": loss.item(),
|
"loss_main": float(loss.detach().item()),
|
||||||
"loss_mismatch": L_mismatch.item(),
|
"loss_mismatch": float(L_mismatch.detach().item()),
|
||||||
|
"t_eff": float(T_eff),
|
||||||
|
"lang_len_mean": float(lens.float().mean().item()),
|
||||||
})
|
})
|
||||||
|
|
||||||
return total_loss, loss_dict
|
return total_loss, loss_dict
|
||||||
@@ -585,6 +587,7 @@ def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None
|
|||||||
OBS_IMAGES, # 'observation.images'
|
OBS_IMAGES, # 'observation.images'
|
||||||
OBS_IMAGE, # 'observation.image'
|
OBS_IMAGE, # 'observation.image'
|
||||||
"observation.images.image", # nested format from some datasets
|
"observation.images.image", # nested format from some datasets
|
||||||
|
"observation.images.front",
|
||||||
]
|
]
|
||||||
|
|
||||||
frames = None
|
frames = None
|
||||||
|
|||||||
@@ -79,8 +79,9 @@ _ Open X-Embodiment (OXE)
|
|||||||
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
- Exactly similar to: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11 [x]
|
||||||
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
- Try DINO v2 as encoder Base 86 M: with https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2 [x]
|
||||||
- Test rewind (evaluate) [x]
|
- Test rewind (evaluate) [x]
|
||||||
|
- Overfit on one episode []
|
||||||
- Cleanup code? []
|
- Cleanup code? []
|
||||||
- benchmark lucidrains vs this implementation forward pass, debug speed []
|
- benchmark siglip 2 vs this implementation forward pass, debug speed []
|
||||||
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
- Convert python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id=IPEC-COMMUNITY/bc_z_lerobot and train on 1 percent
|
||||||
- Then on 10 percent
|
- Then on 10 percent
|
||||||
- Ablation dino v2 vs dino v3 base 86 M
|
- Ablation dino v2 vs dino v3 base 86 M
|
||||||
|
|||||||
Reference in New Issue
Block a user