Mirror of https://github.com/huggingface/lerobot.git
hl-gauss
@@ -56,6 +56,14 @@ class RLearNConfig(PreTrainedConfig):
     dropout: float = 0.10
     num_register_tokens: int = 4
 
+    # --- reward head options ---
+    use_categorical_rewards: bool = False  # classification over bins
+    num_reward_bins: int = 25
+    reward_min_value: float = 0.0  # for HL-Gauss range
+    reward_max_value: float = 1.0
+    use_hl_gauss_loss: bool = True  # if False -> plain regression
+    hl_gauss_num_bins: int = 25  # histogram resolution
+
     # Inference-time subsampling and regularization
     inference_stride: int = 1
     frame_dropout_p: float = 0.10
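Taken together, these flags select one of three reward-head modes that the hunks below wire up (a summary of this commit's branches, not code from it):

  use_categorical_rewards=True  -> per-frame logits over num_reward_bins, trained with cross-entropy
  use_categorical_rewards=False, use_hl_gauss_loss=True  -> HL-Gauss histogram loss over hl_gauss_num_bins bins spanning [reward_min_value, reward_max_value]
  use_categorical_rewards=False, use_hl_gauss_loss=False -> plain Smooth L1 regression on scalar predictions

In either non-categorical mode, if hl_gauss_pytorch is not installed the model falls back to Smooth L1 regression on a scalar logit proxy.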
@@ -21,6 +21,10 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+try:
+    from hl_gauss_pytorch import HLGaussLayer
+except Exception:
+    HLGaussLayer = None  # Optional dependency; guarded at use sites
 
 from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
 from lerobot.policies.pretrained import PreTrainedPolicy
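The guarded import makes hl_gauss_pytorch a soft dependency: when it is missing, HLGaussLayer stays None and the policy falls back to the plain regression path at the use sites below. (The package is presumably installed as pip install hl-gauss-pytorch; the distribution name is assumed from the module name and is not stated in the commit.)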
@@ -100,19 +104,39 @@ class RLearNPolicy(PreTrainedPolicy):
         )
         self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
 
-        # Per-frame predictor head
+        # Per-frame predictor pre-head
         self.frame_mlp = nn.Sequential(
             nn.LayerNorm(config.dim_model),
             nn.Linear(config.dim_model, config.dim_model),
             nn.GELU(),
             nn.Dropout(config.dropout),
         )
-        self.reward_head = nn.Sequential(
-            nn.Linear(config.dim_model, config.dim_model),
-            nn.GELU(),
-            nn.Dropout(config.dropout),
-            nn.Linear(config.dim_model, 1),
-        )
+
+        # Reward heads (mode-aware)
+        self.use_categorical = bool(config.use_categorical_rewards)
+        if self.use_categorical:
+            self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
+            self.hl_gauss_layer = None
+        else:
+            # produce embeddings for HL-Gauss (or regression)
+            self.reward_head = nn.Sequential(
+                nn.Linear(config.dim_model, config.dim_model),
+                nn.GELU(),
+                nn.Dropout(config.dropout),
+                nn.Linear(config.dim_model, config.dim_model),
+            )
+            if HLGaussLayer is not None:
+                self.hl_gauss_layer = HLGaussLayer(
+                    dim=config.dim_model,
+                    use_regression=not bool(config.use_hl_gauss_loss),
+                    hl_gauss_loss=dict(
+                        min_value=float(config.reward_min_value),
+                        max_value=float(config.reward_max_value),
+                        num_bins=int(config.hl_gauss_num_bins),
+                    ),
+                )
+            else:
+                self.hl_gauss_layer = None
 
         # Sampling and regularization knobs
         self.stride = max(1, int(config.inference_stride))
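For context, HL-Gauss ("histogram loss with a Gaussian target") trains a categorical head by spreading each scalar target into a Gaussian-smoothed histogram over the configured bins and minimizing cross-entropy against it; the expected value over the bins is decoded back to a scalar at prediction time. A minimal sketch of that target construction, for illustration only (the helper name and the sigma default are assumptions, not taken from hl_gauss_pytorch or from this commit):

import torch

def hl_gauss_target(y: torch.Tensor, num_bins: int, min_value: float, max_value: float, sigma: float | None = None) -> torch.Tensor:
    # Soft histogram over num_bins for scalar targets y (any shape); rows sum to 1.
    edges = torch.linspace(min_value, max_value, num_bins + 1, device=y.device)
    if sigma is None:
        sigma = 0.75 * (max_value - min_value) / num_bins  # assumed default: roughly one bin width
    # Probability mass that a Gaussian centred at y assigns to each bin, via its CDF at the bin edges.
    cdf = 0.5 * (1.0 + torch.erf((edges - y.unsqueeze(-1)) / (sigma * 2 ** 0.5)))
    probs = cdf[..., 1:] - cdf[..., :-1]
    return probs / probs.sum(dim=-1, keepdim=True)

# e.g. hl_gauss_target(torch.tensor([0.5]), 25, 0.0, 1.0) puts most of its mass on the bins
# around 0.5; training then minimizes cross-entropy between this distribution and the head's
# softmax over the same bins, which is what HLGaussLayer encapsulates here.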
@@ -442,18 +466,40 @@ class RLearNPolicy(PreTrainedPolicy):
 
         # Per-frame prediction
         frame_tokens = self.frame_mlp(attended_video)  # (B, T_eff, D)
-        raw_logits = self.reward_head(frame_tokens).squeeze(-1)  # (B, T_eff)
-        predicted_rewards = torch.sigmoid(raw_logits)
 
-        # Regularizers to avoid flat outputs and encourage local forward progress
-        # Encourage non-flat predictions per sample
+        # Optional masking from batch for variable-length sequences
+        video_lens = batch.get("video_lens", None)
+        video_mask = None
+        if video_lens is not None:
+            if torch.is_tensor(video_lens):
+                video_lens = video_lens.to(frame_tokens.device).long()
+            else:
+                video_lens = torch.as_tensor(video_lens, device=frame_tokens.device, dtype=torch.long)
+            video_mask = self._lens_to_mask(video_lens, frame_tokens.shape[1])
+
+        if self.use_categorical:
+            # classification over bins
+            video_frame_logits = self.reward_head(frame_tokens)  # (B,T,L)
+            raw_like_logits = video_frame_logits.max(dim=-1).values
+            predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
+        else:
+            # embeddings for HL-Gauss (or regression)
+            video_frame_embeds = self.reward_head(frame_tokens)  # (B,T,D)
+            # derive a scalar proxy for regularizers
+            raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
+            # predicted_rewards will be set after loss branch below
+
+        # Regularizers use raw_like_logits for generality
         var_min = 1e-3
-        pred = predicted_rewards
-        L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
-        # Enforce local forward progress on logits without overconstraining
+        if self.use_categorical:
+            # use the max-logit trajectory as a proxy
+            pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
+        else:
+            pred_proxy = torch.sigmoid(raw_like_logits)
+        L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
         rank_margin = 0.02
-        if raw_logits.shape[1] > 1:
-            L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
+        if raw_like_logits.shape[1] > 1:
+            L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
         else:
             L_rank = torch.zeros((), device=device)
 
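The rank term only penalizes frame-to-frame steps where the logit proxy fails to rise by the 0.02 margin: for per-frame logits [0.0, 0.5, 0.4] the consecutive differences are [0.5, -0.1], so L_rank = mean(relu(0.02 - [0.5, -0.1])) = mean([0.0, 0.12]) = 0.06, while a sequence that climbs by more than the margin everywhere contributes 0. Likewise L_flat only fires when the per-sample variance of the proxy drops below var_min = 1e-3, discouraging flat reward curves.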
@@ -481,15 +527,41 @@ class RLearNPolicy(PreTrainedPolicy):
         # Compute main loss (or just return predictions in eval)
         loss_start = time.perf_counter()
         if target is None:
-            total_loss = raw_logits.mean() * 0.0
+            total_loss = torch.tensor(0.0, device=device)
             loss = total_loss
+            predicted_rewards = pred_proxy if self.use_categorical else pred_proxy
         else:
-            target_expanded = target  # (B, T_eff)
-            eps = self.config.logit_eps
-            target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
-            # Robust Huber (Smooth L1) on logits
-            loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
-            total_loss = loss
+            if self.use_categorical:
+                # map targets in [0,1] to bins
+                num_bins = int(self.config.num_reward_bins)
+                bin_idx = (target.clamp(0, 1) * (num_bins - 1) + 1e-6).long()
+                if video_mask is not None:
+                    bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
+                loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
+                total_loss = loss_ce
+                predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
+            else:
+                # HL-Gauss or regression
+                if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
+                    loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
+                    total_loss = loss
+                    predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
+                elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
+                    pred_values = self.hl_gauss_layer(video_frame_embeds)  # (B,T)
+                    if video_mask is not None:
+                        loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
+                    else:
+                        loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
+                    total_loss = loss
+                    predicted_rewards = pred_values
+                else:
+                    # fall back to existing logit regression path on a scalar proxy
+                    target_expanded = target
+                    eps = self.config.logit_eps
+                    target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
+                    loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
+                    total_loss = loss
+                    predicted_rewards = torch.sigmoid(raw_like_logits)
 
 
         # Mismatched video-language pairs loss (only when languages actually differ)
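Two details worth noting in this hunk: per the calls above, the HLGaussLayer instance is used both ways, returning a scalar loss when embeddings are passed together with a target (and optional mask) and decoded expected values when called on embeddings alone; and the categorical branch maps targets in [0, 1] onto bins with floor(target * (num_reward_bins - 1)), so with 25 bins a target of 0.5 lands in bin 12, while masked-out frames get index -1 and are skipped via ignore_index.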
@@ -554,7 +626,7 @@ class RLearNPolicy(PreTrainedPolicy):
             print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
             has_high_targets = (target_expanded > 0.8).any().item()
             print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
-            print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
+            print(f"Logits(proxy): min={raw_like_logits.min():.3f}, max={raw_like_logits.max():.3f}, mean={raw_like_logits.mean():.3f}")
             print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
 
             # Show full arrays occasionally (25% chance within debug)
@@ -608,8 +680,8 @@ class RLearNPolicy(PreTrainedPolicy):
             "pred_mean": float(predicted_rewards.mean().item()),
             "pred_std": float(predicted_rewards.std().item()),
             # Raw logits statistics (useful for monitoring head behavior)
-            "raw_logits_mean": float(raw_logits.mean().item()),
-            "raw_logits_std": float(raw_logits.std().item()),
+            "raw_logits_mean": float(raw_like_logits.mean().item()),
+            "raw_logits_std": float(raw_like_logits.std().item()),
             # Anchor sampling statistics
             "anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
             "anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
@@ -716,6 +788,10 @@ class RLearNPolicy(PreTrainedPolicy):
         mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
         return cls_only, mask
 
+    def _lens_to_mask(self, lens: Tensor, T: int) -> Tensor:
+        rng = torch.arange(T, device=lens.device)[None, :]
+        return rng < lens[:, None]
+
     def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
         """Try to extract (episode_index, frame_index) tensors from batch or complementary data.