From da861139a301ab59420c1178913f13390d9faeff Mon Sep 17 00:00:00 2001 From: Pepijn Date: Mon, 1 Sep 2025 13:11:53 +0200 Subject: [PATCH] hl-gauss --- src/lerobot/policies/rlearn/modeling_rlearn.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 819be8832..f1a8dc4dd 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -135,8 +135,10 @@ class RLearNPolicy(PreTrainedPolicy): num_bins=int(config.hl_gauss_num_bins), ), ) + self.hl_gauss_use_regression = not bool(config.use_hl_gauss_loss) else: self.hl_gauss_layer = None + self.hl_gauss_use_regression = False # Sampling and regularization knobs self.stride = max(1, int(config.inference_stride)) @@ -542,11 +544,11 @@ class RLearNPolicy(PreTrainedPolicy): predicted_rewards = torch.softmax(video_frame_logits, dim=-1) else: # HL-Gauss or regression - if (self.hl_gauss_layer is not None) and (not self.hl_gauss_layer.use_regression): + if (self.hl_gauss_layer is not None) and (not self.hl_gauss_use_regression): loss = self.hl_gauss_layer(video_frame_embeds, target, mask=video_mask) total_loss = loss predicted_rewards = self.hl_gauss_layer(video_frame_embeds).detach() - elif (self.hl_gauss_layer is not None) and self.hl_gauss_layer.use_regression: + elif (self.hl_gauss_layer is not None) and self.hl_gauss_use_regression: pred_values = self.hl_gauss_layer(video_frame_embeds) # (B,T) if video_mask is not None: loss = F.smooth_l1_loss(pred_values[video_mask], target[video_mask], beta=0.25)