mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 03:07:16 +00:00
fix
This commit is contained in:
@@ -72,6 +72,7 @@ class RLearNConfig(PreTrainedConfig):
|
||||
weight_decay: float = 0.01
|
||||
head_lr_multiplier: float = 5.0
|
||||
logit_eps: float = 1e-4
|
||||
regularizer_warmup_steps: int = 500
|
||||
|
||||
# Performance optimizations
|
||||
use_amp: bool = False
|
||||
|
||||
@@ -111,14 +111,19 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Reward heads (mode-aware)
|
||||
self.use_categorical = bool(config.use_categorical_rewards)
|
||||
if self.use_categorical:
|
||||
# Classification head over bins for categorical mode
|
||||
self.reward_head = nn.Linear(config.dim_model, int(config.num_reward_bins))
|
||||
self.hl_gauss_layer = None
|
||||
self.scalar_head = None
|
||||
else:
|
||||
# HL-Gauss expects per-bin logits; head outputs histogram-bin logits
|
||||
self.reward_head = nn.Linear(config.dim_model, int(config.hl_gauss_num_bins))
|
||||
# Feature head produces feature vectors of size dim_model
|
||||
self.reward_head = nn.Linear(config.dim_model, int(config.dim_model))
|
||||
# Optional scalar regression head for fallback when HL-Gauss is unavailable
|
||||
self.scalar_head = nn.Linear(int(config.dim_model), 1)
|
||||
if HLGaussLayer is not None:
|
||||
# HL-Gauss consumes D-dimensional features; histogram resolution is configured via num_bins
|
||||
self.hl_gauss_layer = HLGaussLayer(
|
||||
dim=int(config.hl_gauss_num_bins),
|
||||
dim=int(config.dim_model),
|
||||
use_regression=not bool(config.use_hl_gauss_loss),
|
||||
hl_gauss_loss=dict(
|
||||
min_value=float(config.reward_min_value),
|
||||
@@ -131,6 +136,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
self.hl_gauss_layer = None
|
||||
self.hl_gauss_use_regression = False
|
||||
|
||||
# Public alias used in debug prints
|
||||
self.hl = self.hl_gauss_layer
|
||||
|
||||
# Training step counter and regularizer warmup
|
||||
self._step: int = 0
|
||||
self.regularizer_warmup_steps: int = int(getattr(config, "regularizer_warmup_steps", 500))
|
||||
|
||||
# Sampling and regularization knobs
|
||||
self.stride = max(1, int(config.inference_stride))
|
||||
self.frame_dropout_p = float(config.frame_dropout_p)
|
||||
@@ -161,7 +173,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if param.requires_grad:
|
||||
if "reward_head" in name:
|
||||
if ("reward_head" in name) or ("scalar_head" in name):
|
||||
head_params.append(param)
|
||||
else:
|
||||
base_params.append(param)
|
||||
@@ -234,16 +246,13 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
values = (probs * bin_centers).sum(dim=-1)
|
||||
return values # (B, T)
|
||||
else:
|
||||
# HL-Gauss continuous or regression fallback
|
||||
head_out = self.reward_head(frame_tokens) # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head
|
||||
if (self.hl_gauss_layer is not None) and (not getattr(self, "hl_gauss_use_regression", False)):
|
||||
return self.hl_gauss_layer(head_out) # (B, T)
|
||||
elif (self.hl_gauss_layer is not None) and getattr(self, "hl_gauss_use_regression", False):
|
||||
return self.hl_gauss_layer(head_out) # (B, T)
|
||||
# HL-Gauss with feature head or scalar regression fallback
|
||||
features = self.reward_head(frame_tokens) # (B, T, D)
|
||||
if self.hl_gauss_layer is not None:
|
||||
return self.hl_gauss_layer(features) # (B, T)
|
||||
else:
|
||||
# Scalar proxy via mean over features, then sigmoid to [0,1]
|
||||
raw_like_logits = torch.tanh(head_out).mean(dim=-1) # (B, T)
|
||||
return torch.sigmoid(raw_like_logits)
|
||||
raw = self.scalar_head(features).squeeze(-1) # (B, T)
|
||||
return torch.sigmoid(raw)
|
||||
|
||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||
"""Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
|
||||
@@ -373,6 +382,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
elif not isinstance(commands, list):
|
||||
commands = [str(commands)] * B
|
||||
|
||||
# Verify you’re on the intended path
|
||||
if self.training and torch.rand(1).item() < 0.01:
|
||||
D = int(self.config.dim_model)
|
||||
print(f"[RLearN] mode={'HL-GAUSS' if self.hl is not None else 'REGRESSION'}; D={D}, hist_bins={getattr(self.config,'hl_gauss_num_bins',None)}")
|
||||
|
||||
# Process video frames through vision encoder (returns patch tokens)
|
||||
vision_start = time.perf_counter()
|
||||
video_patch_embeds = self._encode_video_frames(frames).to(device) # (B, T_eff, P, D_vision)
|
||||
@@ -431,24 +445,27 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
raw_like_logits = video_frame_logits.max(dim=-1).values
|
||||
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||
else:
|
||||
# embeddings for HL-Gauss (or regression)
|
||||
video_frame_embeds = self.reward_head(frame_tokens) # (B,T,Bins)
|
||||
# derive a scalar proxy for regularizers
|
||||
raw_like_logits = torch.tanh(video_frame_embeds).mean(dim=-1)
|
||||
# feature embeddings for HL-Gauss (or scalar regression)
|
||||
video_frame_features = self.reward_head(frame_tokens) # (B,T,D)
|
||||
# Use scalar predictions as proxy for regularizers
|
||||
if self.hl_gauss_layer is not None:
|
||||
raw_like_logits = self.hl_gauss_layer(video_frame_features) # (B,T)
|
||||
else:
|
||||
raw_like_logits = self.scalar_head(video_frame_features).squeeze(-1)
|
||||
# predicted_rewards will be set after loss branch below
|
||||
|
||||
# Regularizers use raw_like_logits for generality
|
||||
# Regularizers with warmup; prevent early collapse and scale hunting
|
||||
var_min = 1e-3
|
||||
if self.use_categorical:
|
||||
# use the max-logit trajectory as a proxy
|
||||
pred_proxy = torch.softmax(video_frame_logits, dim=-1).max(dim=-1).values
|
||||
else:
|
||||
pred_proxy = torch.sigmoid(raw_like_logits)
|
||||
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean() if pred_proxy.shape[1] > 1 else torch.zeros((), device=device)
|
||||
rank_margin = 0.02
|
||||
if raw_like_logits.shape[1] > 1:
|
||||
pred_proxy = raw_like_logits.clamp(0.0, 1.0) if self.hl_gauss_layer is not None else torch.sigmoid(raw_like_logits)
|
||||
if (self.training and self._step >= self.regularizer_warmup_steps) and pred_proxy.shape[1] > 1:
|
||||
L_flat = F.relu(var_min - pred_proxy.var(dim=1, unbiased=False)).mean()
|
||||
rank_margin = 0.02
|
||||
L_rank = F.relu(rank_margin - (raw_like_logits[:, 1:] - raw_like_logits[:, :-1])).mean()
|
||||
else:
|
||||
L_flat = torch.zeros((), device=device)
|
||||
L_rank = torch.zeros((), device=device)
|
||||
|
||||
# Generate progress labels on-the-fly (ReWiND approach)
|
||||
@@ -493,15 +510,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
else:
|
||||
# HL-Gauss or regression
|
||||
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression):
|
||||
# Ensure targets within configured range
|
||||
t_min = float(self.config.reward_min_value)
|
||||
t_max = float(self.config.reward_max_value)
|
||||
target_clamped = target.clamp(t_min, t_max)
|
||||
loss = self.hl_gauss_layer(video_frame_embeds, target_clamped, mask=video_mask)
|
||||
loss = self.hl_gauss_layer(video_frame_features, target_clamped, mask=video_mask)
|
||||
total_loss = loss
|
||||
predicted_rewards = self.hl_gauss_layer(video_frame_embeds)
|
||||
predicted_rewards = self.hl_gauss_layer(video_frame_features)
|
||||
elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression:
|
||||
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
|
||||
pred_values = self.hl_gauss_layer(video_frame_features) # (B,T)
|
||||
if video_mask is not None:
|
||||
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
||||
else:
|
||||
@@ -509,13 +525,15 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
total_loss = loss
|
||||
predicted_rewards = pred_values
|
||||
else:
|
||||
# fall back to existing logit regression path on a scalar proxy
|
||||
target_expanded = target
|
||||
eps = self.config.logit_eps
|
||||
target_logits = torch.logit(target_expanded.clamp(eps, 1 - eps))
|
||||
loss = F.smooth_l1_loss(raw_like_logits, target_logits, beta=0.25)
|
||||
# Proper scalar regression fallback
|
||||
pred_values_raw = self.scalar_head(video_frame_features).squeeze(-1) # (B,T)
|
||||
pred_values = torch.sigmoid(pred_values_raw)
|
||||
if video_mask is not None:
|
||||
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
||||
else:
|
||||
loss = F.smooth_l1_loss(pred_values, target, beta=0.25)
|
||||
total_loss = loss
|
||||
predicted_rewards = torch.sigmoid(raw_like_logits)
|
||||
predicted_rewards = pred_values
|
||||
|
||||
|
||||
# Mismatched video-language pairs loss (only when languages actually differ)
|
||||
@@ -554,7 +572,14 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# Process mismatch frames with single MLP
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
mismatch_raw_logits = self.reward_head(mismatch_tokens).squeeze(-1)
|
||||
if self.use_categorical:
|
||||
mismatch_raw_logits = self.reward_head(mismatch_tokens).max(dim=-1).values
|
||||
else:
|
||||
mismatch_features = self.reward_head(mismatch_tokens)
|
||||
if self.hl_gauss_layer is not None:
|
||||
mismatch_raw_logits = self.hl_gauss_layer(mismatch_features)
|
||||
else:
|
||||
mismatch_raw_logits = self.scalar_head(mismatch_features).squeeze(-1)
|
||||
|
||||
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
|
||||
if mismatch_tensor.any():
|
||||
@@ -566,7 +591,7 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
L_mismatch = mismatch_loss_per_sample[mismatch_tensor].mean()
|
||||
|
||||
# Total loss
|
||||
total_loss = total_loss + L_mismatch
|
||||
total_loss = total_loss + L_mismatch + L_flat + L_rank
|
||||
loss_time = time.perf_counter() - loss_start
|
||||
|
||||
# DEBUG: Clean logit regression monitoring with full array printing
|
||||
@@ -707,6 +732,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
stats[key] = []
|
||||
stats['last_print_time'] = current_time
|
||||
|
||||
# Step counter for warmup scheduling
|
||||
if self.training:
|
||||
self._step += 1
|
||||
|
||||
return total_loss, loss_dict
|
||||
|
||||
def _encode_language_tokens(self, commands: list[str], device: torch.device) -> tuple[Tensor, Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user