add lower out of bound sampling

This commit is contained in:
Pepijn
2025-08-31 20:38:45 +02:00
parent a1a3fa435d
commit eff5b90542
+9 -71
View File
@@ -14,77 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
- Automatically generates linear progress labels (0 to 1) for each episode
- No need for pre-annotated rewards in the dataset
- Applies video rewinding augmentation to create synthetic failure trajectories
Inputs
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
- language: list[str] of length B (goal/instruction)
High-level Architecture
images (B,T,C,H,W)
|
| per-frame encode
v
+------------------------------+
| Vision Encoder (frozen) | e.g. SigLIP2 (base)
+------------------------------+
|s
| pooled per-frame embeddings (BT, H_v)
v
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
+ Positional Encoding [0..T)
+ Optional first-frame bias
|
| language (B, str)
| |
| v
| +------------------------------+
| | Text Encoder (frozen) | e.g. SigLIP2
| +------------------------------+
| |
| | pooled text embedding (B, H_t)
| v
| Linear proj -> (B, D)
| |
+-----------------v----------------------+
|
+--------------------------v---------------------------+
| Temporal Causal Transformer (n_layers, n_heads) |
| - self-attention over time with causal mask |
| - cross-attention to a single language token |
+--------------------------+---------------------------+
|
LayerNorm + Linear Head (D -> 1)
|
v
Output
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
Notes
- Uses SigLIP2 for both vision and text encoding.
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
- Stride/frame dropout applied during training can subsample timesteps.
"""
from __future__ import annotations from __future__ import annotations
import math
import numpy as np import numpy as np
from contextlib import nullcontext
from itertools import chain
from operator import truediv
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor, nn from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
# ReWiND dependencies # ReWiND dependencies
try: try:
@@ -103,9 +39,9 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy): class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11. """Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
- Visual encoder: frozen SigLIP2, returns per-frame embeddings. - Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
- Text encoder: frozen SigLIP2, returns a language embedding. - Text encoder: frozen SigLIP2, returns a language embedding.
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video]. - Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
- Output: per-timestep rewards via simple linear regression head. - Output: per-timestep rewards via simple linear regression head.
@@ -347,7 +283,7 @@ class RLearNPolicy(PreTrainedPolicy):
return batch return batch
def _encode_video_frames(self, frames: Tensor) -> Tensor: def _encode_video_frames(self, frames: Tensor) -> Tensor:
"""Encode video frames through SigLIP2 to get per-frame embeddings. """Encode video frames through DinoV3 to get per-frame embeddings.
Args: Args:
frames: (B, T, C, H, W) frames: (B, T, C, H, W)
@@ -1031,16 +967,18 @@ class RLearNPolicy(PreTrainedPolicy):
for i in range(T): for i in range(T):
delta = -(T - 1 - i) * effective_stride delta = -(T - 1 - i) * effective_stride
w_idx = anchor_in_window + delta w_idx = anchor_in_window + delta
# Lower-bound OOB: clamp to 0 (repeat first frame)
if w_idx < 0: if w_idx < 0:
w_idx = -w_idx w_idx = 0
had_oob = True had_oob = True
# Upper-bound OOB (shouldn't happen when sampling past): clamp to last
elif w_idx >= available_T: elif w_idx >= available_T:
w_idx = 2 * (available_T - 1) - w_idx w_idx = available_T - 1
print(f" ⚠️ OOB: {w_idx} >= {available_T}, this should not happen!")
had_oob = True had_oob = True
w_idx = max(0, min(w_idx, available_T - 1))
window_indices.append(w_idx) window_indices.append(w_idx)
# Map window index back to episode-relative absolute frame index # Map window index back to episode-relative absolute frame index and clamp to 0..ep_length-1
abs_idx = cur_frame_idx + (w_idx - (available_T - 1)) abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
abs_idx = int(max(0, min(abs_idx, ep_length - 1))) abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
frame_indices_for_progress.append(abs_idx) frame_indices_for_progress.append(abs_idx)