mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix
This commit is contained in:
@@ -34,7 +34,7 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders
|
- This follows the ReWiND paper architecture. It uses frozen vision/text encoders
|
||||||
(DINO v3 for vision, sentence-transformers for language) and trains a
|
(DINO v3 for vision, sentence-transformers for language) and trains a
|
||||||
lightweight temporal aggregator + head.
|
lightweight temporal aggregator + head.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -61,15 +61,17 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
|
use_tanh_head: bool = False # when True, bound outputs in [-1, 1]
|
||||||
|
|
||||||
# Training
|
# Training
|
||||||
learning_rate: float = 1e-4
|
learning_rate: float = 3e-5
|
||||||
weight_decay: float = 0.01
|
weight_decay: float = 0.01
|
||||||
|
|
||||||
# ReWiND-specific parameters
|
# ReWiND-specific parameters
|
||||||
use_video_rewind: bool = True # Enable video rewinding augmentation
|
use_video_rewind: bool = False # Enable video rewinding augmentation
|
||||||
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
|
rewind_prob: float = 0.8 # Probability of applying rewind to each sample (paper: ~80%)
|
||||||
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
|
rewind_last3_prob: float = 0.1 # Of the rewinds, 10% only rewind the last 3 frames
|
||||||
use_mismatch_loss: bool = True # Enable mismatched language-video loss
|
use_mismatch_loss: bool = False # Enable mismatched language-video loss
|
||||||
mismatch_prob: float = 0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
|
mismatch_prob: float = (
|
||||||
|
0.2 # Probability to include a mismatched video-language forward pass (paper: ~20%)
|
||||||
|
)
|
||||||
|
|
||||||
# Loss hyperparameters (simplified for ReWiND)
|
# Loss hyperparameters (simplified for ReWiND)
|
||||||
# The main loss is just MSE between predicted and target progress
|
# The main loss is just MSE between predicted and target progress
|
||||||
@@ -85,7 +87,7 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
# Architectural knobs to better mirror ReWiND
|
# Architectural knobs to better mirror ReWiND
|
||||||
num_register_tokens: int = 4
|
num_register_tokens: int = 4
|
||||||
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
||||||
|
|
||||||
# HLGauss loss parameters
|
# HLGauss loss parameters
|
||||||
use_hl_gauss_loss: bool = True
|
use_hl_gauss_loss: bool = True
|
||||||
reward_min_value: float = 0.0
|
reward_min_value: float = 0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user