This commit is contained in:
Pepijn
2025-09-01 15:41:24 +02:00
parent cf9796b2f7
commit 4e671ef080
2 changed files with 65 additions and 35 deletions
@@ -72,6 +72,7 @@ class RLearNConfig(PreTrainedConfig):
weight_decay: float = 0.01
head_lr_multiplier: float = 5.0
logit_eps: float = 1e-4
regularizer_warmup_steps: int = 500
# Performance optimizations
use_amp: bool = False
+64 -35
View File
@@ -111,14 +111,19 @@ class RLearNPolicy(PreTrainedPolicy):
# Reward heads (mode-aware)
self.use_categorical = bool(config.use_categorical_rewards)
if self.use_categorical:
# Classification head over bins for categorical mode
self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
self.hl_gauss_layer = None
self.scalar_head = None
else:
# HL-Gauss expects per-bin logits; head outputs histogram-bin logits
self.reward_head = nn.Linear(config.dim_model, int(config.hl_gauss_num_bins))
# Feature head produces feature vectors of size dim_model
self.reward_head = nn.Linear(config.dim_model, int(config.dim_model))
# Optional scalar regression head for fallback when HL-Gauss is unavailable
self.scalar_head = nn.Linear(int(config.dim_model), 1)
if HLGaussLayer is not None:
# HL-Gauss consumes D-dimensional features; histogram resolution is configured via num_bins
self.hl_gauss_layer = HLGaussLayer(
dim=int(config.hl_gauss_num_bins),
dim=int(config.dim_model),
use_regression=not bool(config.use_hl_gauss_loss),
hl_gauss_loss=dict(
min_value=float(config.reward_min_value),
@@ -131,6 +136,13 @@ class RLearNPolicy(PreTrainedPolicy):
self.hl_gauss_layer = None
self.hl_gauss_use_regression = False
# Public alias used in debug prints
self.hl = self.hl_gauss_layer
# Training step counter and regularizer warmup
self._step: int = 0
self.regularizer_warmup_steps: int = int(getattr(config, "regularizer_warmup_steps", 500))
# Sampling and regularization knobs
self.stride = max(1, int(config.inference_stride))
self.frame_dropout_p = float(config.frame_dropout_p)
@@ -161,7 +173,7 @@ class RLearNPolicy(PreTrainedPolicy):
for name, param in self.named_parameters():
if param.requires_grad:
if "reward_head" in name:
if ("reward_head" in name) or ("scalar_head" in name):
head_params.append(param)
else:
base_params.append(param)
@@ -234,16 +246,13 @@ class RLearNPolicy(PreTrainedPolicy):
values = (probs * bin_centers).sum(dim=-1)
return values # (B, T)
else:
# HL-Gauss continuous or regression fallback
head_out = self.reward_head(frame_tokens) # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head
if (self.hl_gauss_layer is not None) and (not getattr(self, "hl_gauss_use_regression", False)):
return self.hl_gauss_layer(head_out) # (B, T)
elif (self.hl_gauss_layer is not None) and getattr(self, "hl_gauss_use_regression", False):
return self.hl_gauss_layer(head_out) # (B, T)
# HL-Gauss with feature head or scalar regression fallback
features = self.reward_head(frame_tokens) # (B, T, D)
if self.hl_gauss_layer is not None:
return self.hl_gauss_layer(features) # (B, T)
else:
# Scalar proxy via mean over features, then sigmoid to [0,1]
raw_like_logits = torch.tanh(head_out).mean(dim=-1) # (B, T)
return torch.sigmoid(raw_like_logits)
raw = self.scalar_head(features).squeeze(-1) # (B, T)
return torch.sigmoid(raw)
def _encode_video_frames(self, frames: Tensor) -> Tensor:
"""Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
@@ -373,6 +382,11 @@ class RLearNPolicy(PreTrainedPolicy):
elif not isinstance(commands, list):
commands = [str(commands)] * B
# Verify you're on the intended path
if self.training and torch.rand(1).item() < 0.01:
D = int(self.config.dim_model)
print(f"[RLearN] mode={'HL-GAUSS' if self.hl is not None else 'REGRESSION'}; D={D}, hist_bins={getattr(self.config,'hl_gauss_num_bins',None)}")
# Process video frames through vision encoder (returns patch tokens)
vision_start = time.perf_counter()
video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, P, D_vision)
@@ -431,24 +445,27 @@ class RLearNPolicy(PreTrainedPolicy):
raw_like_logits = video_frame_logits.max(dim=-1).values
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
else:
# embeddings for HL-Gauss (or regression)
video_frame_embeds = self.reward_head(frame_tokens) # (B,T,Bins)
# derive a scalar proxy for regularizers
raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
# feature embeddings for HL-Gauss (or scalar regression)
video_frame_features = self.reward_head(frame_tokens) # (B,T,D)
# Use scalar predictions as proxy for regularizers
if self.hl_gauss_layer is not None:
raw_like_logits = self.hl_gauss_layer(video_frame_features) # (B,T)
else:
raw_like_logits = self.scalar_head(video_frame_features).squeeze(-1)
# predicted_rewards will be set after loss branch below
# Regularizers use raw_like_logits for generality
# Regularizers with warmup; prevent early collapse and scale hunting
var_min = 1e-3
if self.use_categorical:
# use the max-logit trajectory as a proxy
pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
else:
pred_proxy = torch.sigmoid(raw_like_logits)
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
rank_margin = 0.02
if raw_like_logits.shape[1] > 1:
pred_proxy = raw_like_logits.clamp(0.0, 1.0) if self.hl_gauss_layer is not None else torch.sigmoid(raw_like_logits)
if (self.training and self._step >= self.regularizer_warmup_steps) and pred_proxy.shape[1] > 1:
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean()
rank_margin = 0.02
L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
else:
L_flat = torch.zeros((), device=device)
L_rank = torch.zeros((), device=device)
# Generate progress labels on-the-fly (ReWiND approach)
@@ -493,15 +510,14 @@ class RLearNPolicy(PreTrainedPolicy):
else:
# HL-Gauss or regression
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression):
# Ensure targets within configured range
t_min = float(self.config.reward_min_value)
t_max = float(self.config.reward_max_value)
target_clamped = target.clamp(t_min, t_max)
loss = self.hl_gauss_layer(video_frame_embeds, target_clamped, mask=video_mask)
loss = self.hl_gauss_layer(video_frame_features, target_clamped, mask=video_mask)
total_loss = loss
predicted_rewards = self.hl_gauss_layer(video_frame_embeds)
predicted_rewards = self.hl_gauss_layer(video_frame_features)
elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression:
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
pred_values = self.hl_gauss_layer(video_frame_features) # (B,T)
if video_mask is not None:
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
else:
@@ -509,13 +525,15 @@ class RLearNPolicy(PreTrainedPolicy):
total_loss = loss
predicted_rewards = pred_values
else:
# fall back to existing logit regression path on a scalar proxy
target_expanded = target
eps = self.config.logit_eps
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
# Proper scalar regression fallback
pred_values_raw = self.scalar_head(video_frame_features).squeeze(-1) # (B,T)
pred_values = torch.sigmoid(pred_values_raw)
if video_mask is not None:
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
else:
loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
total_loss = loss
predicted_rewards = torch.sigmoid(raw_like_logits)
predicted_rewards = pred_values
# Mismatched video-language pairs loss (only when languages actually differ)
@@ -554,7 +572,14 @@ class RLearNPolicy(PreTrainedPolicy):
# Process mismatch frames with single MLP
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
mismatch_raw_logits = self.reward_head(mismatch_tokens).squeeze(-1)
if self.use_categorical:
mismatch_raw_logits = self.reward_head(mismatch_tokens).max(dim=-1).values
else:
mismatch_features = self.reward_head(mismatch_tokens)
if self.hl_gauss_layer is not None:
mismatch_raw_logits = self.hl_gauss_layer(mismatch_features)
else:
mismatch_raw_logits = self.scalar_head(mismatch_features).squeeze(-1)
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
if mismatch_tensor.any():
@@ -566,7 +591,7 @@ class RLearNPolicy(PreTrainedPolicy):
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
# Total loss
total_loss = total_loss + L_mismatch
total_loss = total_loss + L_mismatch + L_flat + L_rank
loss_time = time.perf_counter() - loss_start
# DEBUG: Clean logit regression monitoring with full array printing
@@ -707,6 +732,10 @@ class RLearNPolicy(PreTrainedPolicy):
stats[key] = []
stats['last_print_time'] = current_time
# Step counter for warmup scheduling
if self.training:
self._step += 1
return total_loss, loss_dict
def _encode_language_tokens(self, commands: list[str], device: torch.device) -> tuple[Tensor, Tensor]: