mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
hl-gauss
This commit is contained in:
@@ -135,8 +135,10 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
num_bins=int(config.hl_gauss_num_bins),
|
num_bins=int(config.hl_gauss_num_bins),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self.hl_gauss_use_regression = not bool(config.use_hl_gauss_loss)
|
||||||
else:
|
else:
|
||||||
self.hl_gauss_layer = None
|
self.hl_gauss_layer = None
|
||||||
|
self.hl_gauss_use_regression = False
|
||||||
|
|
||||||
# Sampling and regularization knobs
|
# Sampling and regularization knobs
|
||||||
self.stride = max(1, int(config.inference_stride))
|
self.stride = max(1, int(config.inference_stride))
|
||||||
@@ -542,11 +544,11 @@ class RLearNPolicy(PreTrainedPolicy):
|
|||||||
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
predicted_rewards = torch.softmax(video_frame_logits, dim=-1)
|
||||||
else:
|
else:
|
||||||
# HL-Gauss or regression
|
# HL-Gauss or regression
|
||||||
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression):
|
if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression):
|
||||||
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
|
loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask)
|
||||||
total_loss = loss
|
total_loss = loss
|
||||||
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
|
predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach()
|
||||||
elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression:
|
elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression:
|
||||||
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
|
pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T)
|
||||||
if video_mask is not None:
|
if video_mask is not None:
|
||||||
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)
|
||||||
|
|||||||
Reference in New Issue
Block a user