mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
add lower out of bound sampling
This commit is contained in:
@@ -14,77 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
||||
|
||||
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
|
||||
- Automatically generates linear progress labels (0 to 1) for each episode
|
||||
- No need for pre-annotated rewards in the dataset
|
||||
- Applies video rewinding augmentation to create synthetic failure trajectories
|
||||
|
||||
Inputs
|
||||
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
|
||||
- language: list[str] of length B (goal/instruction)
|
||||
|
||||
High-level Architecture
|
||||
|
||||
images (B,T,C,H,W)
|
||||
|
|
||||
| per-frame encode
|
||||
v
|
||||
+------------------------------+
|
||||
| Vision Encoder (frozen) | e.g. SigLIP2 (base)
|
||||
+------------------------------+
|
||||
|s
|
||||
| pooled per-frame embeddings (BT, H_v)
|
||||
v
|
||||
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
|
||||
+ Positional Encoding [0..T)
|
||||
+ Optional first-frame bias
|
||||
|
|
||||
| language (B, str)
|
||||
| |
|
||||
| v
|
||||
| +------------------------------+
|
||||
| | Text Encoder (frozen) | e.g. SigLIP2
|
||||
| +------------------------------+
|
||||
| |
|
||||
| | pooled text embedding (B, H_t)
|
||||
| v
|
||||
| Linear proj -> (B, D)
|
||||
| |
|
||||
+-----------------v----------------------+
|
||||
|
|
||||
+--------------------------v---------------------------+
|
||||
| Temporal Causal Transformer (n_layers, n_heads) |
|
||||
| - self-attention over time with causal mask |
|
||||
| - cross-attention to a single language token |
|
||||
+--------------------------+---------------------------+
|
||||
|
|
||||
LayerNorm + Linear Head (D -> 1)
|
||||
|
|
||||
v
|
||||
Output
|
||||
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
||||
|
||||
Notes
|
||||
- Uses SigLIP2 for both vision and text encoding.
|
||||
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
||||
- Stride/frame dropout applied during training can subsample timesteps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from itertools import chain
|
||||
from operator import truediv
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# ReWiND dependencies
|
||||
try:
|
||||
@@ -103,9 +39,9 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
||||
|
||||
|
||||
class RLearNPolicy(PreTrainedPolicy):
|
||||
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
||||
"""Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
||||
|
||||
- Visual encoder: frozen SigLIP2, returns per-frame embeddings.
|
||||
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
|
||||
- Text encoder: frozen SigLIP2, returns a language embedding.
|
||||
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
||||
- Output: per-timestep rewards via simple linear regression head.
|
||||
@@ -347,7 +283,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
return batch
|
||||
|
||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||
"""Encode video frames through SigLIP2 to get per-frame embeddings.
|
||||
"""Encode video frames through DinoV3 to get per-frame embeddings.
|
||||
|
||||
Args:
|
||||
frames: (B, T, C, H, W)
|
||||
@@ -1031,16 +967,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
for i in range(T):
|
||||
delta = -(T - 1 - i) * effective_stride
|
||||
w_idx = anchor_in_window + delta
|
||||
# Lower-bound OOB: clamp to 0 (repeat first frame)
|
||||
if w_idx < 0:
|
||||
w_idx = -w_idx
|
||||
w_idx = 0
|
||||
had_oob = True
|
||||
# Upper-bound OOB (shouldn't happen when sampling past): clamp to last
|
||||
elif w_idx >= available_T:
|
||||
w_idx = 2 * (available_T - 1) - w_idx
|
||||
w_idx = available_T - 1
|
||||
print(f" ⚠️ OOB: {w_idx} >= {available_T}, this should not happen!")
|
||||
had_oob = True
|
||||
w_idx = max(0, min(w_idx, available_T - 1))
|
||||
window_indices.append(w_idx)
|
||||
|
||||
# Map window index back to episode-relative absolute frame index
|
||||
# Map window index back to episode-relative absolute frame index and clamp to 0..ep_length-1
|
||||
abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
|
||||
abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
|
||||
frame_indices_for_progress.append(abs_idx)
|
||||
|
||||
Reference in New Issue
Block a user