From 4f51f7153c6337b14fabac62c156f83aead3b9a0 Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 1 Sep 2025 13:09:00 +0200
Subject: [PATCH] hl-gauss

---
 .../policies/rlearn/configuration_rlearn.py |   8 ++
 .../policies/rlearn/modeling_rlearn.py      | 128 ++++++++++++++----
 2 files changed, 110 insertions(+), 26 deletions(-)

diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py
index a3ba8b5cc..690e50899 100644
--- a/src/lerobot/policies/rlearn/configuration_rlearn.py
+++ b/src/lerobot/policies/rlearn/configuration_rlearn.py
@@ -56,6 +56,14 @@ class RLearNConfig(PreTrainedConfig):
     dropout: float = 0.10
     num_register_tokens: int = 4
 
+    # --- reward head options ---
+    use_categorical_rewards: bool = False  # classification over bins
+    num_reward_bins: int = 25
+    reward_min_value: float = 0.0  # lower bound of the HL-Gauss range
+    reward_max_value: float = 1.0  # upper bound of the HL-Gauss range
+    use_hl_gauss_loss: bool = True  # if False -> plain regression
+    hl_gauss_num_bins: int = 25  # histogram resolution
+
     # Inference-time subsampling and regularization
     inference_stride: int = 1
     frame_dropout_p: float = 0.10
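For reference, a minimal sketch (not part of the patch) of the soft-binning an HL-Gauss loss performs over the range configured above: each scalar target is smeared into a probability histogram by integrating a Gaussian centered on the target across `hl_gauss_num_bins` bins spanning `[reward_min_value, reward_max_value]`, and the head is trained with cross-entropy against that histogram. The `sigma` default below (0.75 of one bin width) is an illustrative assumption, not necessarily the value `hl_gauss_pytorch` uses:

```python
import torch

def hl_gauss_target(target: torch.Tensor, min_value: float = 0.0, max_value: float = 1.0,
                    num_bins: int = 25, sigma: float | None = None) -> torch.Tensor:
    """Project scalar targets of shape (...) onto (..., num_bins) bin probabilities."""
    if sigma is None:
        sigma = 0.75 * (max_value - min_value) / num_bins  # assumed smoothing width
    edges = torch.linspace(min_value, max_value, num_bins + 1, device=target.device)
    # Gaussian CDF at every bin edge (erf form; constant offsets cancel in the difference)
    cdf = torch.special.erf((edges - target.unsqueeze(-1)) / (sigma * 2**0.5))
    probs = (cdf[..., 1:] - cdf[..., :-1]) / 2  # probability mass falling in each bin
    return probs / probs.sum(dim=-1, keepdim=True)  # renormalize mass clipped at the edges

soft = hl_gauss_target(torch.tensor([0.0, 0.5, 1.0]))  # (3, 25); each row sums to 1
```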
diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index a986ae818..819be8832 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -21,6 +21,10 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
+try:
+    from hl_gauss_pytorch import HLGaussLayer
+except Exception:
+    HLGaussLayer = None  # Optional dependency; guarded at use sites
 
 from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE, REWARD
 from lerobot.policies.pretrained import PreTrainedPolicy
@@ -100,19 +104,39 @@ class RLearNPolicy(PreTrainedPolicy):
         )
         self.aggregator = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
 
-        # Per-frame predictor head
+        # Per-frame predictor pre-head
         self.frame_mlp = nn.Sequential(
             nn.LayerNorm(config.dim_model),
             nn.Linear(config.dim_model, config.dim_model),
             nn.GELU(),
             nn.Dropout(config.dropout),
         )
-        self.reward_head = nn.Sequential(
-            nn.Linear(config.dim_model, config.dim_model),
-            nn.GELU(),
-            nn.Dropout(config.dropout),
-            nn.Linear(config.dim_model, 1),
-        )
+
+        # Reward heads (mode-aware)
+        self.use_categorical = bool(config.use_categorical_rewards)
+        if self.use_categorical:
+            self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
+            self.hl_gauss_layer = None
+        else:
+            # Produce embeddings for HL-Gauss (or plain regression)
+            self.reward_head = nn.Sequential(
+                nn.Linear(config.dim_model, config.dim_model),
+                nn.GELU(),
+                nn.Dropout(config.dropout),
+                nn.Linear(config.dim_model, config.dim_model),
+            )
+            if HLGaussLayer is not None:
+                self.hl_gauss_layer = HLGaussLayer(
+                    dim=config.dim_model,
+                    use_regression=not bool(config.use_hl_gauss_loss),
+                    hl_gauss_loss=dict(
+                        min_value=float(config.reward_min_value),
+                        max_value=float(config.reward_max_value),
+                        num_bins=int(config.hl_gauss_num_bins),
+                    ),
+                )
+            else:
+                self.hl_gauss_layer = None
 
         # Sampling and regularization knobs
         self.stride = max(1, int(config.inference_stride))
@@ -442,18 +466,40 @@ class RLearNPolicy(PreTrainedPolicy):
         # Per-frame prediction
         frame_tokens = self.frame_mlp(attended_video)  # (B, T_eff, D)
-        raw_logits = self.reward_head(frame_tokens).squeeze(-1)  # (B, T_eff)
-        predicted_rewards = torch.sigmoid(raw_logits)
-        # Regularizers to avoid flat outputs and encourage local forward progress
-        # Encourage non-flat predictions per sample
+
+        # Optional masking from batch for variable-length sequences
+        video_lens = batch.get("video_lens", None)
+        video_mask = None
+        if video_lens is not None:
+            if torch.is_tensor(video_lens):
+                video_lens = video_lens.to(frame_tokens.device).long()
+            else:
+                video_lens = torch.as_tensor(video_lens, device=frame_tokens.device, dtype=torch.long)
+            video_mask = self._lens_to_mask(video_lens, frame_tokens.shape[1])
+
+        if self.use_categorical:
+            # Classification over bins
+            video_frame_logits = self.reward_head(frame_tokens)  # (B, T_eff, num_bins)
+            raw_like_logits = video_frame_logits.max(dim=-1).values
+            predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
+        else:
+            # Embeddings for HL-Gauss (or regression)
+            video_frame_embeds = self.reward_head(frame_tokens)  # (B, T_eff, D)
+            # Derive a scalar proxy for the regularizers
+            raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
+            # predicted_rewards is assigned after the loss branch below
+
+        # Regularizers use raw_like_logits for generality: encourage non-flat
+        # predictions per sample and local forward progress
         var_min = 1e-3
-        pred = predicted_rewards
-        L_flat = F.relu(var_min - pred.var(dim=1, unbiased=False)).mean() if pred.shape[1] > 1 else torch.zeros((), device=device)
-        # Enforce local forward progress on logits without overconstraining
+        if self.use_categorical:
+            # Use the per-frame max probability as a proxy trajectory
+            pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
+        else:
+            pred_proxy = torch.sigmoid(raw_like_logits)
+        L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
         rank_margin = 0.02
-        if raw_logits.shape[1] > 1:
-            L_rank = F.relu(rank_margin - (raw_logits[:, 1:] - raw_logits[:, :-1])).mean()
+        if raw_like_logits.shape[1] > 1:
+            L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
         else:
             L_rank = torch.zeros((), device=device)
@@ -481,15 +527,41 @@ class RLearNPolicy(PreTrainedPolicy):
         # Compute main loss (or just return predictions in eval)
         loss_start = time.perf_counter()
         if target is None:
-            total_loss = raw_logits.mean() * 0.0
+            total_loss = torch.tensor(0.0, device=device)
             loss = total_loss
+            predicted_rewards = pred_proxy
         else:
-            target_expanded = target  # (B, T_eff)
-            eps = self.config.logit_eps
-            target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
-            # Robust Huber (Smooth L1) on logits
-            loss = F.smooth_l1_loss(raw_logits, target_logits, beta=0.25)
-            total_loss = loss
+            if self.use_categorical:
+                # Map targets in [0, 1] to bin indices
+                num_bins = int(self.config.num_reward_bins)
+                bin_idx = (target.clamp(0, 1) * (num_bins - 1) + 1e-6).long()
+                if video_mask is not None:
+                    bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))
+                loss_ce = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
+                total_loss = loss_ce
+                predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
+            else:
+                # HL-Gauss or plain regression
+                if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
+                    loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
+                    total_loss = loss
+                    predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
+                elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
+                    pred_values = self.hl_gauss_layer(video_frame_embeds)  # (B, T_eff)
+                    if video_mask is not None:
+                        loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
+                    else:
+                        loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
+                    total_loss = loss
+                    predicted_rewards = pred_values
+                else:
+                    # Fall back to the existing logit-regression path on the scalar proxy
+                    target_expanded = target
+                    eps = self.config.logit_eps
+                    target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
+                    loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
+                    total_loss = loss
+                    predicted_rewards = torch.sigmoid(raw_like_logits)
 
         # Mismatched video-language pairs loss (only when languages actually differ)
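As a sanity check on the categorical branch above, a self-contained toy run of the same binning and masked cross-entropy (the shapes and the `permute` match the patch; the tensors are made up):

```python
import torch
import torch.nn.functional as F

B, T, num_bins = 2, 5, 25
video_frame_logits = torch.randn(B, T, num_bins)           # head output (B, T, num_bins)
target = torch.rand(B, T)                                   # per-frame rewards in [0, 1]
video_mask = torch.tensor([[1, 1, 1, 0, 0],                 # sample 0 has 3 valid frames
                           [1, 1, 1, 1, 1]], dtype=torch.bool)

bin_idx = (target.clamp(0, 1) * (num_bins - 1) + 1e-6).long()             # (B, T)
bin_idx = torch.where(video_mask, bin_idx, torch.full_like(bin_idx, -1))  # padding -> ignored
# cross_entropy expects class logits on dim 1, hence the permute to (B, num_bins, T)
loss = F.cross_entropy(video_frame_logits.permute(0, 2, 1), bin_idx, ignore_index=-1)
```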
@@ -554,7 +626,7 @@ class RLearNPolicy(PreTrainedPolicy):
-            print(f"Target: min={target_expanded.min():.3f}, max={target_expanded.max():.3f}, mean={target_expanded.mean():.3f}")
-            has_high_targets = (target_expanded > 0.8).any().item()
+            print(f"Target: min={target.min():.3f}, max={target.max():.3f}, mean={target.mean():.3f}")
+            has_high_targets = (target > 0.8).any().item()
             print(f"✓ Has targets >0.8: {has_high_targets} | T_eff: {T_eff}")
-            print(f"Logits: min={raw_logits.min():.3f}, max={raw_logits.max():.3f}, mean={raw_logits.mean():.3f}")
+            print(f"Logits(proxy): min={raw_like_logits.min():.3f}, max={raw_like_logits.max():.3f}, mean={raw_like_logits.mean():.3f}")
             print(f"Preds: min={predicted_rewards.min():.3f}, max={predicted_rewards.max():.3f}, mean={predicted_rewards.mean():.3f}")
 
             # Show full arrays occasionally (25% chance within debug)
@@ -608,8 +680,8 @@ class RLearNPolicy(PreTrainedPolicy):
             "pred_mean": float(predicted_rewards.mean().item()),
             "pred_std": float(predicted_rewards.std().item()),
             # Raw logits statistics (useful for monitoring head behavior)
-            "raw_logits_mean": float(raw_logits.mean().item()),
-            "raw_logits_std": float(raw_logits.std().item()),
+            "raw_logits_mean": float(raw_like_logits.mean().item()),
+            "raw_logits_std": float(raw_like_logits.std().item()),
             # Anchor sampling statistics
             "anchor_mean": float(anchor_stats.get('anchor_mean', 0.0)),
             "anchor_std": float(anchor_stats.get('anchor_std', 0.0)),
@@ -716,6 +788,10 @@ class RLearNPolicy(PreTrainedPolicy):
         mask = torch.ones(cls_only.shape[:2], device=device, dtype=torch.bool)
         return cls_only, mask
 
+    def _lens_to_mask(self, lens: Tensor, T: int) -> Tensor:
+        rng = torch.arange(T, device=lens.device)[None, :]
+        return rng < lens[:, None]
+
     def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
         """Try to extract (episode_index, frame_index) tensors from batch or complementary data.
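And a quick illustration of what the new `_lens_to_mask` helper produces (standalone copy of the patch's logic; the example lengths are made up):

```python
import torch

def lens_to_mask(lens: torch.Tensor, T: int) -> torch.Tensor:
    # True for frame indices < per-sample length, False for padding
    rng = torch.arange(T, device=lens.device)[None, :]  # (1, T)
    return rng < lens[:, None]                          # (B, T)

print(lens_to_mask(torch.tensor([3, 5]), T=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```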