This commit is contained in:
Pepijn
2025-08-30 23:58:58 +02:00
parent 1797dea3d5
commit f8d42cc038
3 changed files with 49 additions and 77 deletions
File diff suppressed because one or more lines are too long
@@ -92,11 +92,7 @@ class RLearNConfig(PreTrainedConfig):
num_register_tokens: int = 4 # register / memory tokens, can't hurt num_register_tokens: int = 4 # register / memory tokens, can't hurt
mlp_predictor_depth: int = 3 # depth of the per-frame MLP head mlp_predictor_depth: int = 3 # depth of the per-frame MLP head
# HLGauss loss parameters # Simple MSE regression loss (no binning)
use_hl_gauss_loss: bool = True
reward_min_value: float = 0.0
reward_max_value: float = 1.0
reward_hl_gauss_loss_num_bins: int = 20
# Evaluation visualization parameters # Evaluation visualization parameters
enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training enable_eval_visualizations: bool = False # Enable reward evaluation visualizations during training
+18 -23
View File
@@ -87,13 +87,12 @@ from torch.nn.utils.rnn import pad_sequence
# ReWiND dependencies # ReWiND dependencies
try: try:
from x_transformers import Decoder from x_transformers import Decoder
from hl_gauss_pytorch import HLGaussLayer
import einx import einx
from einops import rearrange, repeat, pack, unpack from einops import rearrange, repeat, pack, unpack
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"ReWiND dependencies not installed. Please install: " "ReWiND dependencies not installed. Please install: "
"pip install x-transformers hl-gauss-pytorch einx einops x-mlps-pytorch" "pip install x-transformers einx einops x-mlps-pytorch"
) from e ) from e
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
@@ -107,7 +106,7 @@ class RLearNPolicy(PreTrainedPolicy):
- Visual encoder: frozen SigLIP2, returns per-frame embeddings. - Visual encoder: frozen SigLIP2, returns per-frame embeddings.
- Text encoder: frozen SigLIP2, returns a language embedding. - Text encoder: frozen SigLIP2, returns a language embedding.
- Temporal module: x_transformers Decoder with packed tokens [lang | register | video]. - Temporal module: x_transformers Decoder with packed tokens [lang | register | video].
- Output: per-timestep rewards via HLGauss layer (continuous only). - Output: per-timestep rewards via simple linear regression head.
""" """
config_class = RLearNConfig config_class = RLearNConfig
@@ -178,16 +177,8 @@ class RLearNPolicy(PreTrainedPolicy):
depth=config.mlp_predictor_depth depth=config.mlp_predictor_depth
) )
# HLGauss layer or plain regression # Simple MSE regression head
self.hl_gauss_layer = HLGaussLayer( self.reward_head = nn.Linear(config.dim_model, 1)
dim=config.dim_model,
use_regression=not config.use_hl_gauss_loss,
hl_gauss_loss=dict(
min_value=config.reward_min_value,
max_value=config.reward_max_value,
num_bins=config.reward_hl_gauss_loss_num_bins,
) # Always provide config, HLGaussLayer needs it even for regression mode
)
# Simple frame dropout probability # Simple frame dropout probability
self.frame_dropout_p = config.frame_dropout_p self.frame_dropout_p = config.frame_dropout_p
@@ -292,8 +283,8 @@ class RLearNPolicy(PreTrainedPolicy):
# MLP predictor # MLP predictor
video_frame_embeds = self.mlp_predictor(attended_video_tokens) video_frame_embeds = self.mlp_predictor(attended_video_tokens)
# Get rewards via HLGauss layer (continuous rewards only) # Get rewards via simple linear head
return self.hl_gauss_layer(video_frame_embeds).squeeze(-1) # (B, T) return self.reward_head(video_frame_embeds).squeeze(-1) # (B, T)
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# Initial version: no-op; rely on upstream processors if any # Initial version: no-op; rely on upstream processors if any
@@ -519,15 +510,18 @@ class RLearNPolicy(PreTrainedPolicy):
# During inference, we might not want to compute loss # During inference, we might not want to compute loss
if not self.training and target is None: if not self.training and target is None:
# Return predictions without loss # Return predictions without loss
rewards = self.hl_gauss_layer(video_frame_embeds) rewards = self.reward_head(video_frame_embeds).squeeze(-1)
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
# Calculate loss using HLGauss (continuous rewards only) # Calculate loss using MSE
loss_start = time.perf_counter() loss_start = time.perf_counter()
assert target.dtype == torch.float, "Continuous rewards require float targets" assert target.dtype == torch.float, "Continuous rewards require float targets"
# Create video mask for variable length support
video_mask = torch.ones(B, T_eff, dtype=torch.bool, device=device) # Get reward predictions
loss = self.hl_gauss_layer(video_frame_embeds, target[:, :T_eff], mask=video_mask) predicted_rewards = self.reward_head(video_frame_embeds).squeeze(-1) # (B, T_eff)
# MSE loss averaged over all (B, T_eff) frame predictions; no mask is applied, so every frame counts equally
loss = F.mse_loss(predicted_rewards, target[:, :T_eff], reduction='mean')
# Optional: Mismatched video-language pairs loss # Optional: Mismatched video-language pairs loss
L_mismatch = torch.zeros((), device=device) L_mismatch = torch.zeros((), device=device)
@@ -549,8 +543,9 @@ class RLearNPolicy(PreTrainedPolicy):
mismatch_embeds = self.mlp_predictor(attended_video_mm) mismatch_embeds = self.mlp_predictor(attended_video_mm)
# Mismatched pairs should predict zero progress # Mismatched pairs should predict zero progress
mismatch_predictions = self.reward_head(mismatch_embeds).squeeze(-1)
zeros_target = torch.zeros_like(target[:, :T_eff]) zeros_target = torch.zeros_like(target[:, :T_eff])
L_mismatch = self.hl_gauss_layer(mismatch_embeds, zeros_target, mask=video_mask) L_mismatch = F.mse_loss(mismatch_predictions, zeros_target, reduction='mean')
# Total loss # Total loss
total_loss = loss + L_mismatch total_loss = loss + L_mismatch
@@ -559,9 +554,9 @@ class RLearNPolicy(PreTrainedPolicy):
# DEBUG: Print targets and predictions occasionally during training # DEBUG: Print targets and predictions occasionally during training
if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print if self.training and torch.rand(1).item() < 0.02: # ~2% chance to debug print
with torch.no_grad(): with torch.no_grad():
# Get raw MLP outputs before HLGauss # Get raw MLP outputs before reward head
raw_outputs = video_frame_embeds raw_outputs = video_frame_embeds
preds = self.hl_gauss_layer(video_frame_embeds).squeeze(-1) preds = self.reward_head(video_frame_embeds).squeeze(-1)
print(f"\n=== DEBUG TRAINING ===") print(f"\n=== DEBUG TRAINING ===")
print(f"Target range: [{target.min():.3f}, {target.max():.3f}]") print(f"Target range: [{target.min():.3f}, {target.max():.3f}]")
print(f"Target mean: {target.mean():.3f}") print(f"Target mean: {target.mean():.3f}")