mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
hl-gauss
This commit is contained in:
@@ -135,8 +135,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
num_bins=int(config.hl_gauss_num_bins),
|
||||
),
|
||||
)
|
||||
self.hl_gauss_use_regression = not bool(config.use_hl_gauss_loss)
|
||||
else:
|
||||
self.hl_gauss_layer = None
|
||||
self.hl_gauss_use_regression = False
|
||||
|
||||
# Sampling and regularization knobs
|
||||
self.stride = max(1, int(config.inference_stride))
|
||||
@@ -542,11 +544,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||
else:
|
||||
# HL-Gauss or regression
|
||||
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
|
||||
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression):
|
||||
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
|
||||
total_loss = loss
|
||||
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
|
||||
elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
|
||||
elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression:
|
||||
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
|
||||
if video_mask is not None:
|
||||
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
||||
|
||||
Reference in New Issue
Block a user