This commit is contained in:
Pepijn
2025-09-01 13:09:00 +02:00
parent 9027c7866f
commit 4f51f7153c
2 changed files with 110 additions and 26 deletions
@@ -56,6 +56,14 @@ class RLearNConfig(PreTrainedConfig):
dropout: float = 0.10
num_register_tokens: int = 4
# --- reward head options ---
use_categorical_rewards: bool = False # classification over bins
num_reward_bins: int = 25
reward_min_value: float = 0.0 # for HL-Gauss range
reward_max_value: float = 1.0
use_hl_gauss_loss: bool = True # if False -> plain regression
hl_gauss_num_bins: int = 25 # histogram resolution
# Inference-time subsampling and regularization
inference_stride: int = 1
frame_dropout_p: float = 0.10
+102 -26
View File
@@ -21,6 +21,10 @@ import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
try:
from hl_gauss_pytorch import HLGaussLayer
except Exception:
HLGaussLayer = None # Optional dependency; guarded at use sites
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
from lerobot.policies.pretrained import PreTrainedPolicy
@@ -100,19 +104,39 @@ class RLearNPolicy(PreTrainedPolicy):
)
self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
# Per-frame predictor head
# Per-frame predictor pre-head
self.frame_mlp = nn.Sequential(
nn.LayerNorm(config.dim_model),
nn.Linear(config.dim_model, config.dim_model),
nn.GELU(),
nn.Dropout(config.dropout),
)
self.reward_head = nn.Sequential(
nn.Linear(config.dim_model, config.dim_model),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.dim_model, 1),
)
# Reward heads (mode-aware)
self.use_categorical = bool(config.use_categorical_rewards)
if self.use_categorical:
self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
self.hl_gauss_layer = None
else:
# produce embeddings for HL-Gauss (or regression)
self.reward_head = nn.Sequential(
nn.Linear(config.dim_model, config.dim_model),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.dim_model, config.dim_model),
)
if HLGaussLayer is not None:
self.hl_gauss_layer = HLGaussLayer(
dim=config.dim_model,
use_regression=not bool(config.use_hl_gauss_loss),
hl_gauss_loss=dict(
min_value=float(config.reward_min_value),
max_value=float(config.reward_max_value),
num_bins=int(config.hl_gauss_num_bins),
),
)
else:
self.hl_gauss_layer = None
# Sampling and regularization knobs
self.stride = max(1, int(config.inference_stride))
@@ -442,18 +466,40 @@ class RLearNPolicy(PreTrainedPolicy):
# Per-frame prediction
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
predicted_rewards = torch.sigmoid(raw_logits)
# Regularizers to avoid flat outputs and encourage local forward progress
# Encourage non-flat predictions per sample
# Optional masking from batch for variable-length sequences
video_lens = batch.get("video_lens", None)
video_mask = None
if video_lens is not None:
if torch.is_tensor(video_lens):
video_lens = video_lens.to(frame_tokens.device).long()
else:
video_lens = torch.as_tensor(video_lens, device=frame_tokens.device, dtype=torch.long)
video_mask = self._lens_to_mask(video_lens, frame_tokens.shape[1])
if self.use_categorical:
# classification over bins
video_frame_logits = self.reward_head(frame_tokens) # (B,T,L)
raw_like_logits = video_frame_logits.max(dim=-1).values
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
else:
# embeddings for HL-Gauss (or regression)
video_frame_embeds = self.reward_head(frame_tokens) # (B,T,D)
# derive a scalar proxy for regularizers
raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
# predicted_rewards will be set after loss branch below
# Regularizers use raw_like_logits for generality
var_min = 1e-3
pred = predicted_rewards
L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
# Enforce local forward progress on logits without overconstraining
if self.use_categorical:
# use the max-logit trajectory as a proxy
pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
else:
pred_proxy = torch.sigmoid(raw_like_logits)
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
rank_margin = 0.02
if raw_logits.shape[1] > 1:
L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
if raw_like_logits.shape[1] > 1:
L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
else:
L_rank = torch.zeros((), device=device)
@@ -481,15 +527,41 @@ class RLearNPolicy(PreTrainedPolicy):
# Compute main loss (or just return predictions in eval)
loss_start = time.perf_counter()
if target is None:
total_loss = raw_logits.mean() * 0.0
total_loss = torch.tensor(0.0, device=device)
loss = total_loss
predicted_rewards = pred_proxy if self.use_categorical else pred_proxy
else:
target_expanded = target # (B, T_eff)
eps = self.config.logit_eps
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
# Robust Huber (Smooth L1) on logits
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
total_loss = loss
if self.use_categorical:
# map targets in [0,1] to bins
num_bins = int(self.config.num_reward_bins)
bin_idx = (target.clamp(0, 1) * (num_bins - 1) + 1e-6).long()
if video_mask is not None:
bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
total_loss = loss_ce
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
else:
# HL-Gauss or regression
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
total_loss = loss
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
if video_mask is not None:
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
else:
loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
total_loss = loss
predicted_rewards = pred_values
else:
# fall back to existing logit regression path on a scalar proxy
target_expanded = target
eps = self.config.logit_eps
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
total_loss = loss
predicted_rewards = torch.sigmoid(raw_like_logits)
# Mismatched video-language pairs loss (only when languages actually differ)
@@ -554,7 +626,7 @@ class RLearNPolicy(PreTrainedPolicy):
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
has_high_targets = (target_expanded > 0.8).any().item()
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
print(f"Logits(proxy): min={raw_like_logits.min():.3f}, max={raw_like_logits.max():.3f}, mean={raw_like_logits.mean():.3f}")
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
# Show full arrays occasionally (25% chance within debug)
@@ -608,8 +680,8 @@ class RLearNPolicy(PreTrainedPolicy):
"pred_mean": float(predicted_rewards.mean().item()),
"pred_std": float(predicted_rewards.std().item()),
# Raw logits statistics (useful for monitoring head behavior)
"raw_logits_mean": float(raw_logits.mean().item()),
"raw_logits_std": float(raw_logits.std().item()),
"raw_logits_mean": float(raw_like_logits.mean().item()),
"raw_logits_std": float(raw_like_logits.std().item()),
# Anchor sampling statistics
"anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
"anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
@@ -716,6 +788,10 @@ class RLearNPolicy(PreTrainedPolicy):
mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
return cls_only, mask
def _lens_to_mask(self, lens: Tensor, T: int) -> Tensor:
rng = torch.arange(T, device=lens.device)[None, :]
return rng < lens[:, None]
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data.