This commit is contained in:
Pepijn
2025-09-01 13:11:53 +02:00
parent 4f51f7153c
commit da861139a3
@@ -135,8 +135,10 @@ class RLearNPolicy(PreTrainedPolicy):
num_bins=int(config.hl_gauss_num_bins),
),
)
self.hl_gauss_use_regression = not bool(config.use_hl_gauss_loss)
else:
self.hl_gauss_layer = None
self.hl_gauss_use_regression = False
# Sampling and regularization knobs
self.stride = max(1, int(config.inference_stride))
@@ -542,11 +544,11 @@ class RLearNPolicy(PreTrainedPolicy):
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
else:
# HL-Gauss or regression
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression):
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
total_loss = loss
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression:
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
if video_mask is not None:
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)