add lower out-of-bound sampling

Pepijn
2025-08-31 20:38:45 +02:00
parent a1a3fa435d
commit eff5b90542
+9 -71
@@ -14,77 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
- Automatically generates linear progress labels (0 to 1) for each episode
- No need for pre-annotated rewards in the dataset
- Applies video rewinding augmentation to create synthetic failure trajectories
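A minimal sketch of these two ideas (the function names, the rewind-point choice, and the label handling are illustrative assumptions, not this repo's API):

```python
import torch

def linear_progress_labels(T: int) -> torch.Tensor:
    # Progress ramps linearly from 0 at the first frame to 1 at the last.
    return torch.linspace(0.0, 1.0, T)

def rewind_augment(frames: torch.Tensor, rewind_len: int):
    """Append a reversed suffix so the clip 'undoes' part of its progress,
    yielding a synthetic failure trajectory. Sketch only; assumes rewind_len < T."""
    T = frames.shape[0]                                   # frames: (T, C, H, W)
    suffix = frames[T - rewind_len - 1 : -1].flip(0)      # frames replayed backwards
    labels = linear_progress_labels(T)
    rew_labels = labels[T - rewind_len - 1 : -1].flip(0)  # progress decreases again
    return torch.cat([frames, suffix]), torch.cat([labels, rew_labels])
```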
Inputs
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
- language: list[str] of length B (goal/instruction)
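For concreteness, a batch matching these shapes could be built like this (sizes arbitrary):

```python
import torch

images = torch.randn(2, 8, 3, 224, 224)                  # (B=2, T=8, C, H, W)
language = ["pick up the red block", "open the drawer"]  # list[str], len == B
```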
High-level Architecture
images (B,T,C,H,W)
|
| per-frame encode
v
+------------------------------+
| Vision Encoder (frozen) | e.g. SigLIP2 (base)
+------------------------------+
|
| pooled per-frame embeddings (BT, H_v)
v
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
+ Positional Encoding [0..T)
+ Optional first-frame bias
|
| language (B, str)
| |
| v
| +------------------------------+
| | Text Encoder (frozen) | e.g. SigLIP2
| +------------------------------+
| |
| | pooled text embedding (B, H_t)
| v
| Linear proj -> (B, D)
| |
+-----------------v----------------------+
|
+--------------------------v---------------------------+
| Temporal Causal Transformer (n_layers, n_heads) |
| - self-attention over time with causal mask |
| - cross-attention to a single language token |
+--------------------------+---------------------------+
|
LayerNorm + Linear Head (D -> 1)
|
v
Output
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
Notes
- Uses SigLIP2 for both vision and text encoding.
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
- Stride/frame dropout applied during training can subsample timesteps.
"""
from __future__ import annotations
import math
import numpy as np
from contextlib import nullcontext
from itertools import chain
from operator import truediv
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
# ReWiND dependencies
try:
@@ -103,9 +39,9 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
class RLearNPolicy(PreTrainedPolicy):
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
"""Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
- Visual encoder: frozen SigLIP2, returns per-frame embeddings.
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
- Text encoder: frozen SigLIP2, returns a language embedding.
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
- Output: per-timestep rewards via simple linear regression head.
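The packed-token layout mentioned above might look like this (a sketch; the register-token count and dims are assumptions, and plain tensors stand in for the x_transformers Decoder's inputs):

```python
import torch
from torch import nn

D, N_REG = 512, 4  # embedding dim and register-token count (assumed)
registers = nn.Parameter(torch.randn(N_REG, D))

def pack_tokens(lang_tok: torch.Tensor, video_toks: torch.Tensor) -> torch.Tensor:
    """Concatenate [lang | register | video] along the sequence axis.
    lang_tok: (B, 1, D), video_toks: (B, T, D) -> (B, 1 + N_REG + T, D)."""
    B = lang_tok.shape[0]
    reg = registers.unsqueeze(0).expand(B, -1, -1)
    return torch.cat([lang_tok, reg, video_toks], dim=1)

# After the decoder runs over the packed sequence, only the trailing T positions
# (the video tokens) would be sliced out and fed to the reward head.
```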
@@ -347,7 +283,7 @@ class RLearNPolicy(PreTrainedPolicy):
return batch
def _encode_video_frames(self, frames: Tensor) -> Tensor:
"""Encode video frames through SigLIP2 to get per-frame embeddings.
"""Encode video frames through DinoV3 to get per-frame embeddings.
Args:
frames: (B, T, C, H, W)
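A common way to implement this is to fold time into the batch axis before the image encoder; a sketch, with `encoder` standing in for the frozen backbone:

```python
import torch

def encode_video_frames(frames: torch.Tensor, encoder) -> torch.Tensor:
    # frames: (B, T, C, H, W) -> flatten to (B*T, C, H, W) for the image encoder
    B, T, C, H, W = frames.shape
    flat = frames.reshape(B * T, C, H, W)
    with torch.no_grad():        # backbone is frozen
        emb = encoder(flat)      # (B*T, H_v) pooled per-frame embeddings
    return emb.reshape(B, T, -1) # (B, T, H_v)
```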
@@ -1031,16 +967,18 @@ class RLearNPolicy(PreTrainedPolicy):
 for i in range(T):
     delta = -(T - 1 - i) * effective_stride
     w_idx = anchor_in_window + delta
+    # Lower-bound OOB: clamp to 0 (repeat first frame)
     if w_idx < 0:
-        w_idx = -w_idx
+        w_idx = 0
         had_oob = True
+    # Upper-bound OOB (shouldn't happen when sampling past): clamp to last
     elif w_idx >= available_T:
-        w_idx = 2 * (available_T - 1) - w_idx
+        print(f" ⚠️ OOB: {w_idx} >= {available_T}, this should not happen!")
+        w_idx = available_T - 1
         had_oob = True
-    w_idx = max(0, min(w_idx, available_T - 1))
     window_indices.append(w_idx)
-    # Map window index back to episode-relative absolute frame index
+    # Map window index back to episode-relative absolute frame index and clamp to 0..ep_length-1
     abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
+    abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
     frame_indices_for_progress.append(abs_idx)
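To see the new lower-bound behaviour in isolation, here is a self-contained sketch of the sampling loop above (variable roles assumed from context):

```python
def sample_window_indices(anchor_in_window: int, T: int, effective_stride: int,
                          available_T: int) -> list[int]:
    """Walk backwards from the anchor; indices that fall before the window
    start are clamped to 0, i.e. the first frame is repeated."""
    out = []
    for i in range(T):
        w_idx = anchor_in_window - (T - 1 - i) * effective_stride
        out.append(max(0, min(w_idx, available_T - 1)))
    return out

# Early in an episode the window underflows and the first frame repeats:
print(sample_window_indices(anchor_in_window=2, T=6, effective_stride=1, available_T=16))
# -> [0, 0, 0, 0, 1, 2]
```

Compared with the previous reflection (`w_idx = -w_idx`), clamping keeps the sampled frames monotone in time and simply repeats the first frame, which is consistent with near-zero progress at the start of an episode.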