mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
hl-gauss
This commit is contained in:
@@ -56,6 +56,14 @@ class RLearNConfig(PreTrainedConfig):
|
|||||||
dropout: float = 0.10
|
dropout: float = 0.10
|
||||||
num_register_tokens: int = 4
|
num_register_tokens: int = 4
|
||||||
|
|
||||||
|
# --- reward head options ---
|
||||||
|
use_categorical_rewards: bool = False # classification over bins
|
||||||
|
num_reward_bins: int = 25
|
||||||
|
reward_min_value: float = 0.0 # for HL-Gauss range
|
||||||
|
reward_max_value: float = 1.0
|
||||||
|
use_hl_gauss_loss: bool = True # if False -> plain regression
|
||||||
|
hl_gauss_num_bins: int = 25 # histogram resolution
|
||||||
|
|
||||||
# Inference-time subsampling and regularization
|
# Inference-time subsampling and regularization
|
||||||
inference_stride: int = 1
|
inference_stride: int = 1
|
||||||
frame_dropout_p: float = 0.10
|
frame_dropout_p: float = 0.10
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
try:
|
||||||
|
from hl_gauss_pytorch import HLGaussLayer
|
||||||
|
except Exception:
|
||||||
|
HLGaussLayer = None # Optional dependency; guarded at use sites
|
||||||
|
|
||||||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
@@ -100,19 +104,39 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
)
|
)
|
||||||
self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
|
self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
|
||||||
|
|
||||||
# Per-frame predictor head
|
# Per-frame predictor pre-head
|
||||||
self.frame_mlp = nn.Sequential(
|
self.frame_mlp = nn.Sequential(
|
||||||
nn.LayerNorm(config.dim_model),
|
nn.LayerNorm(config.dim_model),
|
||||||
nn.Linear(config.dim_model, config.dim_model),
|
nn.Linear(config.dim_model, config.dim_model),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.Dropout(config.dropout),
|
nn.Dropout(config.dropout),
|
||||||
)
|
)
|
||||||
self.reward_head = nn.Sequential(
|
|
||||||
nn.Linear(config.dim_model, config.dim_model),
|
# Reward heads (mode-aware)
|
||||||
nn.GELU(),
|
self.use_categorical = bool(config.use_categorical_rewards)
|
||||||
nn.Dropout(config.dropout),
|
if self.use_categorical:
|
||||||
nn.Linear(config.dim_model, 1),
|
self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
|
||||||
)
|
self.hl_gauss_layer = None
|
||||||
|
else:
|
||||||
|
# produce embeddings for HL-Gauss (or regression)
|
||||||
|
self.reward_head = nn.Sequential(
|
||||||
|
nn.Linear(config.dim_model, config.dim_model),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(config.dropout),
|
||||||
|
nn.Linear(config.dim_model, config.dim_model),
|
||||||
|
)
|
||||||
|
if HLGaussLayer is not None:
|
||||||
|
self.hl_gauss_layer = HLGaussLayer(
|
||||||
|
dim=config.dim_model,
|
||||||
|
use_regression=not bool(config.use_hl_gauss_loss),
|
||||||
|
hl_gauss_loss=dict(
|
||||||
|
min_value=float(config.reward_min_value),
|
||||||
|
max_value=float(config.reward_max_value),
|
||||||
|
num_bins=int(config.hl_gauss_num_bins),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.hl_gauss_layer = None
|
||||||
|
|
||||||
# Sampling and regularization knobs
|
# Sampling and regularization knobs
|
||||||
self.stride = max(1, int(config.inference_stride))
|
self.stride = max(1, int(config.inference_stride))
|
||||||
@@ -442,18 +466,40 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Per-frame prediction
|
# Per-frame prediction
|
||||||
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
|
frame_tokens = self.frame_mlp(attended_video) # (B, T_eff, D)
|
||||||
raw_logits = self.reward_head(frame_tokens).squeeze(-1) # (B, T_eff)
|
|
||||||
predicted_rewards = torch.sigmoid(raw_logits)
|
|
||||||
|
|
||||||
# Regularizers to avoid flat outputs and encourage local forward progress
|
# Optional masking from batch for variable-length sequences
|
||||||
# Encourage non-flat predictions per sample
|
video_lens = batch.get("video_lens", None)
|
||||||
|
video_mask = None
|
||||||
|
if video_lens is not None:
|
||||||
|
if torch.is_tensor(video_lens):
|
||||||
|
video_lens = video_lens.to(frame_tokens.device).long()
|
||||||
|
else:
|
||||||
|
video_lens = torch.as_tensor(video_lens, device=frame_tokens.device, dtype=torch.long)
|
||||||
|
video_mask = self._lens_to_mask(video_lens, frame_tokens.shape[1])
|
||||||
|
|
||||||
|
if self.use_categorical:
|
||||||
|
# classification over bins
|
||||||
|
video_frame_logits = self.reward_head(frame_tokens) # (B,T,L)
|
||||||
|
raw_like_logits = video_frame_logits.max(dim=-1).values
|
||||||
|
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||||
|
else:
|
||||||
|
# embeddings for HL-Gauss (or regression)
|
||||||
|
video_frame_embeds = self.reward_head(frame_tokens) # (B,T,D)
|
||||||
|
# derive a scalar proxy for regularizers
|
||||||
|
raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
|
||||||
|
# predicted_rewards will be set after loss branch below
|
||||||
|
|
||||||
|
# Regularizers use raw_like_logits for generality
|
||||||
var_min = 1e-3
|
var_min = 1e-3
|
||||||
pred = predicted_rewards
|
if self.use_categorical:
|
||||||
L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
|
# use the max-logit trajectory as a proxy
|
||||||
# Enforce local forward progress on logits without overconstraining
|
pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
|
||||||
|
else:
|
||||||
|
pred_proxy = torch.sigmoid(raw_like_logits)
|
||||||
|
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
|
||||||
rank_margin = 0.02
|
rank_margin = 0.02
|
||||||
if raw_logits.shape[1] > 1:
|
if raw_like_logits.shape[1] > 1:
|
||||||
L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
|
L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
|
||||||
else:
|
else:
|
||||||
L_rank = torch.zeros((), device=device)
|
L_rank = torch.zeros((), device=device)
|
||||||
|
|
||||||
@@ -481,15 +527,41 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
# Compute main loss (or just return predictions in eval)
|
# Compute main loss (or just return predictions in eval)
|
||||||
loss_start = time.perf_counter()
|
loss_start = time.perf_counter()
|
||||||
if target is None:
|
if target is None:
|
||||||
total_loss = raw_logits.mean() * 0.0
|
total_loss = torch.tensor(0.0, device=device)
|
||||||
loss = total_loss
|
loss = total_loss
|
||||||
|
predicted_rewards = pred_proxy if self.use_categorical else pred_proxy
|
||||||
else:
|
else:
|
||||||
target_expanded = target # (B, T_eff)
|
if self.use_categorical:
|
||||||
eps = self.config.logit_eps
|
# map targets in [0,1] to bins
|
||||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
num_bins = int(self.config.num_reward_bins)
|
||||||
# Robust Huber (Smooth L1) on logits
|
bin_idx = (target.clamp(0, 1) * (num_bins - 1) + 1e-6).long()
|
||||||
loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
|
if video_mask is not None:
|
||||||
total_loss = loss
|
bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
|
||||||
|
loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
|
||||||
|
total_loss = loss_ce
|
||||||
|
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||||
|
else:
|
||||||
|
# HL-Gauss or regression
|
||||||
|
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
|
||||||
|
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
|
||||||
|
total_loss = loss
|
||||||
|
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
|
||||||
|
elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
|
||||||
|
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
|
||||||
|
if video_mask is not None:
|
||||||
|
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
||||||
|
else:
|
||||||
|
loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
|
||||||
|
total_loss = loss
|
||||||
|
predicted_rewards = pred_values
|
||||||
|
else:
|
||||||
|
# fall back to existing logit regression path on a scalar proxy
|
||||||
|
target_expanded = target
|
||||||
|
eps = self.config.logit_eps
|
||||||
|
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||||
|
loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
|
||||||
|
total_loss = loss
|
||||||
|
predicted_rewards = torch.sigmoid(raw_like_logits)
|
||||||
|
|
||||||
|
|
||||||
# Mismatched video-language pairs loss (only when languages actually differ)
|
# Mismatched video-language pairs loss (only when languages actually differ)
|
||||||
@@ -554,7 +626,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
|
print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
|
||||||
has_high_targets = (target_expanded > 0.8).any().item()
|
has_high_targets = (target_expanded > 0.8).any().item()
|
||||||
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
|
print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
|
||||||
print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
|
print(f"Logits(proxy): min={raw_like_logits.min():.3f}, max={raw_like_logits.max():.3f}, mean={raw_like_logits.mean():.3f}")
|
||||||
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
|
print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
|
||||||
|
|
||||||
# Show full arrays occasionally (25% chance within debug)
|
# Show full arrays occasionally (25% chance within debug)
|
||||||
@@ -608,8 +680,8 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
"pred_mean": float(predicted_rewards.mean().item()),
|
"pred_mean": float(predicted_rewards.mean().item()),
|
||||||
"pred_std": float(predicted_rewards.std().item()),
|
"pred_std": float(predicted_rewards.std().item()),
|
||||||
# Raw logits statistics (useful for monitoring head behavior)
|
# Raw logits statistics (useful for monitoring head behavior)
|
||||||
"raw_logits_mean": float(raw_logits.mean().item()),
|
"raw_logits_mean": float(raw_like_logits.mean().item()),
|
||||||
"raw_logits_std": float(raw_logits.std().item()),
|
"raw_logits_std": float(raw_like_logits.std().item()),
|
||||||
# Anchor sampling statistics
|
# Anchor sampling statistics
|
||||||
"anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
|
"anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
|
||||||
"anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
|
"anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
|
||||||
@@ -716,6 +788,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
|
mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
|
||||||
return cls_only, mask
|
return cls_only, mask
|
||||||
|
|
||||||
|
def _lens_to_mask(self, lens: Tensor, T: int) -> Tensor:
|
||||||
|
rng = torch.arange(T, device=lens.device)[None, :]
|
||||||
|
return rng < lens[:, None]
|
||||||
|
|
||||||
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
|
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
|
||||||
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data.
|
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data.
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user