mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
fix
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -92,11 +92,7 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
num_register_tokens: int = 4 # register / memory tokens, can't hurt
|
num_register_tokens: int = 4 # register / memory tokens, can't hurt
|
||||||
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
|
||||||
|
|
||||||
# HLGauss loss parameters
|
# Simple MSE regression loss (no binning)
|
||||||
use_hl_gauss_loss: bool = True
|
|
||||||
reward_min_value: float = 0.0
|
|
||||||
reward_max_value: float = 1.0
|
|
||||||
reward_hl_gauss_loss_num_bins: int = 20
|
|
||||||
|
|
||||||
# Evaluation visualization parameters
|
# Evaluation visualization parameters
|
||||||
enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training
|
enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training
|
||||||
|
|||||||
@@ -87,13 +87,12 @@ from torch.nn.utils.rnn import pad_sequence
|
|||||||
# ReWiND dependencies
|
# ReWiND dependencies
|
||||||
try:
|
try:
|
||||||
from x_transformers import Decoder
|
from x_transformers import Decoder
|
||||||
from hl_gauss_pytorch import HLGaussLayer
|
|
||||||
import einx
|
import einx
|
||||||
from einops import rearrange, repeat, pack, unpack
|
from einops import rearrange, repeat, pack, unpack
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"ReWiND dependencies not installed. Please install: "
|
"ReWiND dependencies not installed. Please install: "
|
||||||
"pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch"
|
"pip install x-transformers einx einops x-mlps-pytorch"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||||
@@ -107,7 +106,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
- Visual encoder: frozen SigLIP2, returns per-frame embeddings.
|
- Visual encoder: frozen SigLIP2, returns per-frame embeddings.
|
||||||
- Text encoder: frozen SigLIP2, returns a language embedding.
|
- Text encoder: frozen SigLIP2, returns a language embedding.
|
||||||
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
|
||||||
- Output: per-timestep rewards via HLGauss layer (continuous only).
|
- Output: per-timestep rewards via simple linear regression head.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config_class = RLearNConfig
|
config_class = RLearNConfig
|
||||||
@@ -178,16 +177,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
depth=config.mlp_predictor_depth
|
depth=config.mlp_predictor_depth
|
||||||
)
|
)
|
||||||
|
|
||||||
# HLGauss layer or plain regression
|
# Simple MSE regression head
|
||||||
self.hl_gauss_layer = HLGaussLayer(
|
self.reward_head = nn.Linear(config.dim_model, 1)
|
||||||
dim=config.dim_model,
|
|
||||||
use_regression=not config.use_hl_gauss_loss,
|
|
||||||
hl_gauss_loss=dict(
|
|
||||||
min_value=config.reward_min_value,
|
|
||||||
max_value=config.reward_max_value,
|
|
||||||
num_bins=config.reward_hl_gauss_loss_num_bins,
|
|
||||||
) # Always provide config, HLGaussLayer needs it even for regression mode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Simple frame dropout probability
|
# Simple frame dropout probability
|
||||||
self.frame_dropout_p = config.frame_dropout_p
|
self.frame_dropout_p = config.frame_dropout_p
|
||||||
@@ -292,8 +283,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# MLP predictor
|
# MLP predictor
|
||||||
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
video_frame_embeds = self.mlp_predictor(attended_video_tokens)
|
||||||
|
|
||||||
# Get rewards via HLGauss layer (continuous rewards only)
|
# Get rewards via simple linear head
|
||||||
return self.hl_gauss_layer(video_frame_embeds).squeeze(-1) # (B, T)
|
return self.reward_head(video_frame_embeds).squeeze(-1) # (B, T)
|
||||||
|
|
||||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
# Initial version: no-op; rely on upstream processors if any
|
# Initial version: no-op; rely on upstream processors if any
|
||||||
@@ -519,15 +510,18 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# During inference, we might not want to compute loss
|
# During inference, we might not want to compute loss
|
||||||
if not self.training and target is None:
|
if not self.training and target is None:
|
||||||
# Return predictions without loss
|
# Return predictions without loss
|
||||||
rewards = self.hl_gauss_layer(video_frame_embeds)
|
rewards = self.reward_head(video_frame_embeds).squeeze(-1)
|
||||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||||
|
|
||||||
# Calculate loss using HLGauss (continuous rewards only)
|
# Calculate loss using MSE
|
||||||
loss_start = time.perf_counter()
|
loss_start = time.perf_counter()
|
||||||
assert target.dtype == torch.float, "Continuous rewards require float targets"
|
assert target.dtype == torch.float, "Continuous rewards require float targets"
|
||||||
# Create video mask for variable length support
|
|
||||||
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device)
|
# Get reward predictions
|
||||||
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask)
|
predicted_rewards = self.reward_head(video_frame_embeds).squeeze(-1) # (B, T_eff)
|
||||||
|
|
||||||
|
# MSE loss with masking for variable length sequences
|
||||||
|
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
|
||||||
|
|
||||||
# Optional: Mismatched video-language pairs loss
|
# Optional: Mismatched video-language pairs loss
|
||||||
L_mismatch = torch.zeros((), device=device)
|
L_mismatch = torch.zeros((), device=device)
|
||||||
@@ -549,8 +543,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
mismatch_embeds = self.mlp_predictor(attended_video_mm)
|
||||||
|
|
||||||
# Mismatched pairs should predict zero progress
|
# Mismatched pairs should predict zero progress
|
||||||
|
mismatch_predictions = self.reward_head(mismatch_embeds).squeeze(-1)
|
||||||
zeros_target = torch.zeros_like(target[:, :T_eff])
|
zeros_target = torch.zeros_like(target[:, :T_eff])
|
||||||
L_mismatch = self.hl_gauss_layer(mismatch_embeds, zeros_target, mask=video_mask)
|
L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
|
||||||
|
|
||||||
# Total loss
|
# Total loss
|
||||||
total_loss = loss + L_mismatch
|
total_loss = loss + L_mismatch
|
||||||
@@ -559,9 +554,9 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# DEBUG: Print targets and predictions occasionally during training
|
# DEBUG: Print targets and predictions occasionally during training
|
||||||
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Get raw MLP outputs before HLGauss
|
# Get raw MLP outputs before reward head
|
||||||
raw_outputs = video_frame_embeds
|
raw_outputs = video_frame_embeds
|
||||||
preds = self.hl_gauss_layer(video_frame_embeds).squeeze(-1)
|
preds = self.reward_head(video_frame_embeds).squeeze(-1)
|
||||||
print(f"\n=== DEBUG TRAINING ===")
|
print(f"\n=== DEBUG TRAINING ===")
|
||||||
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
|
||||||
print(f"Target mean: {target.mean():.3f}")
|
print(f"Target mean: {target.mean():.3f}")
|
||||||
|
|||||||
Reference in New Issue
Block a user