diff --git a/src/lerobot/policies/rlearn/configuration_rlearn.py b/src/lerobot/policies/rlearn/configuration_rlearn.py index 8058683ea..0744d1481 100644 --- a/src/lerobot/policies/rlearn/configuration_rlearn.py +++ b/src/lerobot/policies/rlearn/configuration_rlearn.py @@ -72,6 +72,7 @@ class RLearNConfig(PreTrainedConfig): weight_decay: float = 0.01 head_lr_multiplier: float = 5.0 logit_eps: float = 1e-4 + regularizer_warmup_steps: int = 500 # Performance optimizations use_amp: bool = False diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index bd7a06b9a..c27488b24 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -111,14 +111,19 @@ class RLearNPolicy(PreTrainedPolicy): # Reward heads (mode-aware) self.use_categorical = bool(config.use_categorical_rewards) if self.use_categorical: + # Classification head over bins for categorical mode self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins)) self.hl_gauss_layer = None + self.scalar_head = None else: - # HL-Gauss expects per-bin logits; head outputs histogram-bin logits - self.reward_head = nn.Linear(config.dim_model, int(config.hl_gauss_num_bins)) + # Feature head produces feature vectors of size dim_model + self.reward_head = nn.Linear(config.dim_model, int(config.dim_model)) + # Optional scalar regression head for fallback when HL-Gauss is unavailable + self.scalar_head = nn.Linear(int(config.dim_model), 1) if HLGaussLayer is not None: + # HL-Gauss consumes D-dimensional features; histogram resolution is configured via num_bins self.hl_gauss_layer = HLGaussLayer( - dim=int(config.hl_gauss_num_bins), + dim=int(config.dim_model), use_regression=not bool(config.use_hl_gauss_loss), hl_gauss_loss=dict( min_value=float(config.reward_min_value), @@ -131,6 +136,13 @@ class RLearNPolicy(PreTrainedPolicy): self.hl_gauss_layer = None self.hl_gauss_use_regression = False + # Public alias used in debug prints + self.hl = self.hl_gauss_layer + + # Training step counter and regularizer warmup + self._step: int = 0 + self.regularizer_warmup_steps: int = int(getattr(config, "regularizer_warmup_steps", 500)) + # Sampling and regularization knobs self.stride = max(1, int(config.inference_stride)) self.frame_dropout_p = float(config.frame_dropout_p) @@ -161,7 +173,7 @@ class RLearNPolicy(PreTrainedPolicy): for name, param in self.named_parameters(): if param.requires_grad: - if "reward_head" in name: + if ("reward_head" in name) or ("scalar_head" in name): head_params.append(param) else: base_params.append(param) @@ -234,16 +246,13 @@ class RLearNPolicy(PreTrainedPolicy): values = (probs * bin_centers).sum(dim=-1) return values # (B, T) else: - # HL-Gauss continuous or regression fallback - head_out = self.reward_head(frame_tokens) # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head - if (self.hl_gauss_layer is not None) and (not getattr(self, "hl_gauss_use_regression", False)): - return self.hl_gauss_layer(head_out) # (B, T) - elif (self.hl_gauss_layer is not None) and getattr(self, "hl_gauss_use_regression", False): - return self.hl_gauss_layer(head_out) # (B, T) + # HL-Gauss with feature head or scalar regression fallback + features = self.reward_head(frame_tokens) # (B, T, D) + if self.hl_gauss_layer is not None: + return self.hl_gauss_layer(features) # (B, T) else: - # Scalar proxy via mean over features, then sigmoid to [0,1] - raw_like_logits = torch.tanh(head_out).mean(dim=-1) # (B, T) - return torch.sigmoid(raw_like_logits) + raw = self.scalar_head(features).squeeze(-1) # (B, T) + return torch.sigmoid(raw) def _encode_video_frames(self, frames: Tensor) -> Tensor: """Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings. @@ -373,6 +382,11 @@ class RLearNPolicy(PreTrainedPolicy): elif not isinstance(commands, list): commands = [str(commands)] * B + # Verify you’re on the intended path + if self.training and torch.rand(1).item() < 0.01: + D = int(self.config.dim_model) + print(f"[RLearN] mode={'HL-GAUSS' if self.hl is not None else 'REGRESSION'}; D={D}, hist_bins={getattr(self.config,'hl_gauss_num_bins',None)}") + # Process video frames through vision encoder (returns patch tokens) vision_start = time.perf_counter() video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, P, D_vision) @@ -431,24 +445,27 @@ class RLearNPolicy(PreTrainedPolicy): raw_like_logits = video_frame_logits.max(dim=-1).values predicted_rewards = torch.softmax(video_frame_logits, dim=-1) else: - # embeddings for HL-Gauss (or regression) - video_frame_embeds = self.reward_head(frame_tokens) # (B,T,Bins) - # derive a scalar proxy for regularizers - raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1) + # feature embeddings for HL-Gauss (or scalar regression) + video_frame_features = self.reward_head(frame_tokens) # (B,T,D) + # Use scalar predictions as proxy for regularizers + if self.hl_gauss_layer is not None: + raw_like_logits = self.hl_gauss_layer(video_frame_features) # (B,T) + else: + raw_like_logits = self.scalar_head(video_frame_features).squeeze(-1) # predicted_rewards will be set after loss branch below - # Regularizers use raw_like_logits for generality + # Regularizers with warmup; prevent early collapse and scale hunting var_min = 1e-3 if self.use_categorical: - # use the max-logit trajectory as a proxy pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values else: - pred_proxy = torch.sigmoid(raw_like_logits) - L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device) - rank_margin = 0.02 - if raw_like_logits.shape[1] > 1: + pred_proxy = raw_like_logits.clamp(0.0, 1.0) if self.hl_gauss_layer is not None else torch.sigmoid(raw_like_logits) + if (self.training and self._step >= self.regularizer_warmup_steps) and pred_proxy.shape[1] > 1: + L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() + rank_margin = 0.02 L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean() else: + L_flat = torch.zeros((), device=device) L_rank = torch.zeros((), device=device) # Generate progress labels on-the-fly (ReWiND approach) @@ -493,15 +510,14 @@ class RLearNPolicy(PreTrainedPolicy): else: # HL-Gauss or regression if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression): - # Ensure targets within configured range t_min = float(self.config.reward_min_value) t_max = float(self.config.reward_max_value) target_clamped = target.clamp(t_min, t_max) - loss = self.hl_gauss_layer(video_frame_embeds, target_clamped, mask=video_mask) + loss = self.hl_gauss_layer(video_frame_features, target_clamped, mask=video_mask) total_loss = loss - predicted_rewards = self.hl_gauss_layer(video_frame_embeds) + predicted_rewards = self.hl_gauss_layer(video_frame_features) elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression: - pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T) + pred_values = self.hl_gauss_layer(video_frame_features) # (B,T) if video_mask is not None: loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25) else: @@ -509,13 +525,15 @@ class RLearNPolicy(PreTrainedPolicy): total_loss = loss predicted_rewards = pred_values else: - # fall back to existing logit regression path on a scalar proxy - target_expanded = target - eps = self.config.logit_eps - target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps)) - loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25) + # Proper scalar regression fallback + pred_values_raw = self.scalar_head(video_frame_features).squeeze(-1) # (B,T) + pred_values = torch.sigmoid(pred_values_raw) + if video_mask is not None: + loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25) + else: + loss = F.smooth_l1_loss(pred_values, target, beta=0.25) total_loss = loss - predicted_rewards = torch.sigmoid(raw_like_logits) + predicted_rewards = pred_values # Mismatched video-language pairs loss (only when languages actually differ) @@ -554,7 +572,14 @@ class RLearNPolicy(PreTrainedPolicy): # Process mismatch frames with single MLP mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D) - mismatch_raw_logits = self.reward_head(mismatch_tokens).squeeze(-1) + if self.use_categorical: + mismatch_raw_logits = self.reward_head(mismatch_tokens).max(dim=-1).values + else: + mismatch_features = self.reward_head(mismatch_tokens) + if self.hl_gauss_layer is not None: + mismatch_raw_logits = self.hl_gauss_layer(mismatch_features) + else: + mismatch_raw_logits = self.scalar_head(mismatch_features).squeeze(-1) mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool) if mismatch_tensor.any(): @@ -566,7 +591,7 @@ class RLearNPolicy(PreTrainedPolicy): L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean() # Total loss - total_loss = total_loss + L_mismatch + total_loss = total_loss + L_mismatch + L_flat + L_rank loss_time = time.perf_counter() - loss_start # DEBUG: Clean logit regression monitoring with full array printing @@ -707,6 +732,10 @@ class RLearNPolicy(PreTrainedPolicy): stats[key] = [] stats['last_print_time'] = current_time + # Step counter for warmup scheduling + if self.training: + self._step += 1 + return total_loss, loss_dict def _encode_language_tokens(self, commands: list[str], device: torch.device) -> tuple[Tensor, Tensor]: