This commit is contained in:
Pepijn
2025-08-27 17:16:31 +02:00
parent 2a901f8134
commit 34ca077d78
4 changed files with 81 additions and 36 deletions
File diff suppressed because one or more lines are too long
@@ -105,7 +105,8 @@ class RLearNConfig(PreTrainedConfig):
@property @property
def observation_delta_indices(self) -> list | None: def observation_delta_indices(self) -> list | None:
# Use temporal sequences: past frames from -(max_seq_len-1) to current (0) # Use temporal sequences: past frames from -(max_seq_len-1) to current (0)
# This gives us max_seq_len frames total, e.g. [-15, -14, ..., -1, 0] for max_seq_len=16 # This gives us max_seq_len frames total, e.g. [-3, -2, -1, 0] for max_seq_len=4
# The dataset will handle padding/repeating frames for episodes shorter than this
return list(range(1 - self.max_seq_len, 1)) return list(range(1 - self.max_seq_len, 1))
@property @property
+59 -16
View File
@@ -202,8 +202,8 @@ class RLearNPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
# Extract frames and form (B, T, C, H, W) # Extract frames and form (B, T, C, H, W), padding if needed
frames = extract_visual_sequence(batch) frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape B, T, C, H, W = frames.shape
# Apply stride (no dropout during eval) # Apply stride (no dropout during eval)
@@ -284,8 +284,8 @@ class RLearNPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch) batch = self.normalize_targets(batch)
# Extract frames and form (B, T, C, H, W) # Extract frames and form (B, T, C, H, W), padding if needed
frames = extract_visual_sequence(batch) frames = extract_visual_sequence(batch, target_seq_len=self.config.max_seq_len)
B, T, C, H, W = frames.shape B, T, C, H, W = frames.shape
# Apply stride and frame dropout during training # Apply stride and frame dropout during training
@@ -374,6 +374,19 @@ class RLearNPolicy(PreTrainedPolicy):
# Align target with sampled timesteps # Align target with sampled timesteps
if target.dim() == 1: if target.dim() == 1:
target = target.unsqueeze(1) # (B, 1) target = target.unsqueeze(1) # (B, 1)
# Handle target padding to match frame sequence if needed
if target.shape[1] < self.config.max_seq_len:
# Pad targets by repeating the first value (assuming it's the earliest)
padding_needed = self.config.max_seq_len - target.shape[1]
first_target = target[:, :1] # (B, 1)
padding = first_target.expand(target.shape[0], padding_needed)
target = torch.cat([padding, target], dim=1) # Prepend padding
import logging
logging.debug(f"Padded targets from {target.shape[1] - padding_needed} to {self.config.max_seq_len}")
# Now safely index with idx
target = target[:, idx] target = target[:, idx]
# Composite loss # Composite loss
@@ -602,7 +615,17 @@ def generate_causal_mask(T: int, device=None) -> Tensor:
return mask return mask
def extract_visual_sequence(batch: dict[str, Tensor]) -> Tensor: def extract_visual_sequence(batch: dict[str, Tensor], target_seq_len: int = None) -> Tensor:
"""Extract visual sequence from batch and ensure it has the expected temporal length.
Args:
batch: Input batch containing image data
target_seq_len: Expected sequence length. If provided and the actual sequence is shorter,
it will be padded by repeating the first frame.
Returns:
Tensor of shape (B, T, C, H, W)
"""
# Accept various image key formats from datasets # Accept various image key formats from datasets
# With delta_indices, the dataset provides temporal sequences automatically # With delta_indices, the dataset provides temporal sequences automatically
@@ -613,6 +636,7 @@ def extract_visual_sequence(batch: dict[str, Tensor]) -> Tensor:
"observation.images.image", # nested format from some datasets "observation.images.image", # nested format from some datasets
] ]
frames = None
for key in possible_keys: for key in possible_keys:
if key in batch: if key in batch:
image_val = batch[key] image_val = batch[key]
@@ -620,28 +644,47 @@ def extract_visual_sequence(batch: dict[str, Tensor]) -> Tensor:
if isinstance(image_val, list) and len(image_val) > 0: if isinstance(image_val, list) and len(image_val) > 0:
# List of (B, C, H, W) -> stack over time # List of (B, C, H, W) -> stack over time
# This happens when dataset provides temporal sequence as list # This happens when dataset provides temporal sequence as list
return torch.stack(image_val, dim=1) frames = torch.stack(image_val, dim=1)
break
elif torch.is_tensor(image_val): elif torch.is_tensor(image_val):
# Tensor of shape (B, T, C, H, W) or (B, C, H, W) # Tensor of shape (B, T, C, H, W) or (B, C, H, W)
if image_val.dim() == 5: if image_val.dim() == 5:
# Already has time dimension - this is what we expect with delta_indices # Already has time dimension - this is what we expect with delta_indices
return image_val frames = image_val
break
elif image_val.dim() == 4: elif image_val.dim() == 4:
# Add time dimension (single frame) - fallback for datasets without temporal sequences # Add time dimension (single frame) - fallback for datasets without temporal sequences
return image_val.unsqueeze(1) frames = image_val.unsqueeze(1)
break
else: else:
raise ValueError( raise ValueError(
f"'{key}' must be a Tensor of shape (B,T,C,H,W) or (B,C,H,W), got shape {image_val.shape}" f"'{key}' must be a Tensor of shape (B,T,C,H,W) or (B,C,H,W), got shape {image_val.shape}"
) )
# If no image key found, provide helpful error with available keys if frames is None:
available_keys = list(batch.keys()) # If no image key found, provide helpful error with available keys
image_like_keys = [k for k in available_keys if "image" in k.lower()] available_keys = list(batch.keys())
raise ValueError( image_like_keys = [k for k in available_keys if "image" in k.lower()]
f"Could not find image data in batch. Looked for keys: {possible_keys}. " raise ValueError(
f"Available keys with 'image': {image_like_keys}. " f"Could not find image data in batch. Looked for keys: {possible_keys}. "
f"All keys: {available_keys}" f"Available keys with 'image': {image_like_keys}. "
) f"All keys: {available_keys}"
)
# Pad sequence if needed
if target_seq_len is not None:
B, T, C, H, W = frames.shape
if T < target_seq_len:
# Pad by repeating the first frame (assumes first frame in sequence is the earliest)
padding_needed = target_seq_len - T
first_frame = frames[:, :1] # (B, 1, C, H, W)
padding = first_frame.expand(B, padding_needed, C, H, W)
frames = torch.cat([padding, frames], dim=1) # Prepend padding
import logging
logging.debug(f"Padded sequence from {T} to {target_seq_len} frames by repeating first frame")
return frames
def encode_language( def encode_language(
+5 -4
View File
@@ -124,13 +124,14 @@ Default weights: $\lambda_{\text{prog}}=1.0$, $\lambda_{\text{spatial-nce}}=0.5$
- Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visalization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x] - Implement eval score or metric that is robust and can deal with generalization/is a good metric to try different architectures. And use it in an eval jupyter notebook with visalization of the live reward next to the video for part of the dataset: VOC score and score with correct and incorrect language captions [x]
- Do first training [x] - Do first training [x]
- Try different losses [] - Try different losses []
- Only vlc loss then eval []
- Only rewind loss then eval [] - Only rewind loss then eval []
- Only vlc loss then eval []
- Vlc + rewind loss then eval [] - Vlc + rewind loss then eval []
- Cleanup code - Convert 1% of bc-z []
- Switch to DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? - Cleanup code []
- Try DINO v3 as encoder Base 86 M: https://huggingface.co/facebook/dinov3-vitb16-pretrain-lvd1689m with HuggingFaceTB/SmolLM2-135M-Instruct ? []
- Add more artificial text to dataset generated by vlm (google gemini) [] - Add more artificial text to dataset generated by vlm (google gemini) []
- See google gemini vlm caption from Leandro [] https://gemini.google.com/app/7e332ffaf32580f2 - See google gemini vlm caption [] https://gemini.google.com/app/7e332ffaf32580f2
- Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446 - Multiple captions per video, creat method to generate as much data as possible etc [] https://arxiv.org/abs/2508.13446
- How can we improve spatial aware learning? co generating captions for each frame with language decoder? - How can we improve spatial aware learning? co generating captions for each frame with language decoder?
- Add droid [] - Add droid []