mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
add lower out of bound sampling
This commit is contained in:
@@ -14,77 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
|
||||||
RLearN: Video-Language Conditioned Reward Model (ReWiND Implementation)
|
|
||||||
|
|
||||||
This implementation follows the ReWiND paper approach (arXiv:2505.10911v1):
|
|
||||||
- Automatically generates linear progress labels (0 to 1) for each episode
|
|
||||||
- No need for pre-annotated rewards in the dataset
|
|
||||||
- Applies video rewinding augmentation to create synthetic failure trajectories
|
|
||||||
|
|
||||||
Inputs
|
|
||||||
- images: (B, T, C, H, W) sequence of frames (or single frame with T=1)
|
|
||||||
- language: list[str] of length B (goal/instruction)
|
|
||||||
|
|
||||||
High-level Architecture
|
|
||||||
|
|
||||||
images (B,T,C,H,W)
|
|
||||||
|
|
|
||||||
| per-frame encode
|
|
||||||
v
|
|
||||||
+------------------------------+
|
|
||||||
| Vision Encoder (frozen) | e.g. SigLIP2 (base)
|
|
||||||
+------------------------------+
|
|
||||||
|s
|
|
||||||
| pooled per-frame embeddings (BT, H_v)
|
|
||||||
v
|
|
||||||
reshape -> (B, T, H_v) -- Linear proj --> (B, T, D)
|
|
||||||
+ Positional Encoding [0..T)
|
|
||||||
+ Optional first-frame bias
|
|
||||||
|
|
|
||||||
| language (B, str)
|
|
||||||
| |
|
|
||||||
| v
|
|
||||||
| +------------------------------+
|
|
||||||
| | Text Encoder (frozen) | e.g. SigLIP2
|
|
||||||
| +------------------------------+
|
|
||||||
| |
|
|
||||||
| | pooled text embedding (B, H_t)
|
|
||||||
| v
|
|
||||||
| Linear proj -> (B, D)
|
|
||||||
| |
|
|
||||||
+-----------------v----------------------+
|
|
||||||
|
|
|
||||||
+--------------------------v---------------------------+
|
|
||||||
| Temporal Causal Transformer (n_layers, n_heads) |
|
|
||||||
| - self-attention over time with causal mask |
|
|
||||||
| - cross-attention to a single language token |
|
|
||||||
+--------------------------+---------------------------+
|
|
||||||
|
|
|
||||||
LayerNorm + Linear Head (D -> 1)
|
|
||||||
|
|
|
||||||
v
|
|
||||||
Output
|
|
||||||
- reward_logits: (B, T', 1) with T' ≤ T (affected by stride and frame dropout)
|
|
||||||
|
|
||||||
Notes
|
|
||||||
- Uses SigLIP2 for both vision and text encoding.
|
|
||||||
- Backbones (vision/text) are frozen by default; only projections, temporal module, and head are trainable.
|
|
||||||
- Stride/frame dropout applied during training can subsample timesteps.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from contextlib import nullcontext
|
|
||||||
from itertools import chain
|
|
||||||
from operator import truediv
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
|
|
||||||
# ReWiND dependencies
|
# ReWiND dependencies
|
||||||
try:
|
try:
|
||||||
@@ -103,9 +39,9 @@ from lerobot.policies.rlearn.configuration_rlearn import RLearNConfig
|
|||||||
|
|
||||||
|
|
||||||
class RLearNPolicy(PreTrainedPolicy):
|
class RLearNPolicy(PreTrainedPolicy):
|
||||||
"""Video-language conditioned reward model following ReWiND architecture exactly: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
"""Video-language conditioned reward model following ReWiND architecture: https://github.com/lucidrains/rewind-reward-pytorch/blob/main/rewind_reward_pytorch/rewind_reward.py#L11.
|
||||||
|
|
||||||
- Visual encoder: frozen SigLIP2, returns per-frame embeddings.
|
- Visual encoder: frozen DinoV3 encoder, returns per-frame embeddings.
|
||||||
- Text encoder: frozen SigLIP2, returns a language embedding.
|
- Text encoder: frozen SigLIP2, returns a language embedding.
|
||||||
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
||||||
- Output: per-timestep rewards via simple linear regression head.
|
- Output: per-timestep rewards via simple linear regression head.
|
||||||
@@ -347,7 +283,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||||
"""Encode video frames through SigLIP2 to get per-frame embeddings.
|
"""Encode video frames through DinoV3 to get per-frame embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frames: (B, T, C, H, W)
|
frames: (B, T, C, H, W)
|
||||||
@@ -1031,16 +967,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
for i in range(T):
|
for i in range(T):
|
||||||
delta = -(T - 1 - i) * effective_stride
|
delta = -(T - 1 - i) * effective_stride
|
||||||
w_idx = anchor_in_window + delta
|
w_idx = anchor_in_window + delta
|
||||||
|
# Lower-bound OOB: clamp to 0 (repeat first frame)
|
||||||
if w_idx < 0:
|
if w_idx < 0:
|
||||||
w_idx = -w_idx
|
w_idx = 0
|
||||||
had_oob = True
|
had_oob = True
|
||||||
|
# Upper-bound OOB (shouldn't happen when sampling past): clamp to last
|
||||||
elif w_idx >= available_T:
|
elif w_idx >= available_T:
|
||||||
w_idx = 2 * (available_T - 1) - w_idx
|
w_idx = available_T - 1
|
||||||
|
print(f" ⚠️ OOB: {w_idx} >= {available_T}, this should not happen!")
|
||||||
had_oob = True
|
had_oob = True
|
||||||
w_idx = max(0, min(w_idx, available_T - 1))
|
|
||||||
window_indices.append(w_idx)
|
window_indices.append(w_idx)
|
||||||
|
|
||||||
# Map window index back to episode-relative absolute frame index
|
# Map window index back to episode-relative absolute frame index and clamp to 0..ep_length-1
|
||||||
abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
|
abs_idx = cur_frame_idx + (w_idx - (available_T - 1))
|
||||||
abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
|
abs_idx = int(max(0, min(abs_idx, ep_length - 1)))
|
||||||
frame_indices_for_progress.append(abs_idx)
|
frame_indices_for_progress.append(abs_idx)
|
||||||
|
|||||||
Reference in New Issue
Block a user