From cf9796b2f74e70a478639a99aa61695186ca871a Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Mon, 1 Sep 2025 14:57:24 +0200
Subject: [PATCH] fix eval

---
 .../policies/rlearn/modeling_rlearn.py | 73 +++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py
index bf567a625..bd7a06b9a 100644
--- a/src/lerobot/policies/rlearn/modeling_rlearn.py
+++ b/src/lerobot/policies/rlearn/modeling_rlearn.py
@@ -182,6 +182,79 @@ class RLearNPolicy(PreTrainedPolicy):
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:  # Required by base class
         raise NotImplementedError("RLearN is a reward model and does not select actions")
 
+    @torch.no_grad()
+    def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
+        """Inference helper used by eval_script: returns (B, T) per-frame rewards.
+
+        Expects batch with OBS_IMAGES shaped (B, T, C, H, W) or list of length T with (B, C, H, W),
+        and optional OBS_LANGUAGE as a list/tuple of B strings or a single string
+        (a single string is repeated for every sample; a missing key means empty prompts).
+
+        Raises:
+            ValueError: if OBS_LANGUAGE is a list/tuple whose length differs from B.
+        """
+        device = next(self.parameters()).device
+
+        # Prepare frames and language
+        frames = extract_visual_sequence(batch, target_seq_len=None).to(device)  # (B, T, C, H, W)
+        B, T, _, _, _ = frames.shape
+
+        commands = batch.get(OBS_LANGUAGE, None)
+        if commands is None:
+            commands = [""] * B
+        elif isinstance(commands, (list, tuple)):
+            # Accept tuples too; stringifying a whole tuple would yield one bogus prompt.
+            commands = list(commands)
+        else:
+            commands = [str(commands)] * B
+        if len(commands) != B:
+            raise ValueError(f"Expected {B} language commands, got {len(commands)}")
+
+        # Encode
+        video_frame_embeds = self._encode_video_frames(frames)  # (B, T, D_v)
+        lang_embeds, mask = self._encode_language_tokens(commands, device)  # (B, L, D_l), (B, L)
+
+        # Project and add first-frame bias
+        video_tokens = self.to_video_tokens(video_frame_embeds)  # (B, T, D)
+        video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
+        lang_tokens = self.to_lang_tokens(lang_embeds)  # (B, L, D)
+
+        # Build masks and run transformer
+        # `~` below is a logical NOT only on boolean tensors; the cast is a no-op if mask is already bool.
+        lang_valid = mask.to(torch.bool)
+        video_valid = torch.ones(B, T, device=device, dtype=torch.bool)
+        key_padding_mask = ~(torch.cat([lang_valid, video_valid], dim=1))
+        tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1)  # (B, S, D)
+        S = tokens_seq.shape[1]
+        causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1)
+
+        attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask)
+        L_len = lang_tokens.shape[1]
+        attended_video = attended_all[:, L_len:, :]  # (B, T, D)
+
+        frame_tokens = self.frame_mlp(attended_video)  # (B, T, D)
+
+        if self.use_categorical:
+            logits = self.reward_head(frame_tokens)  # (B, T, Bins)
+            probs = torch.softmax(logits, dim=-1)
+            # Expected value over bin centers in [reward_min_value, reward_max_value]
+            bin_centers = torch.linspace(
+                float(self.config.reward_min_value), float(self.config.reward_max_value), int(self.config.num_reward_bins), device=device
+            )
+            values = (probs * bin_centers).sum(dim=-1)
+            return values  # (B, T)
+
+        # HL-Gauss continuous or regression fallback
+        head_out = self.reward_head(frame_tokens)  # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head
+        if self.hl_gauss_layer is not None:
+            # Binned and regression HL-Gauss modes decode through the same layer call.
+            return self.hl_gauss_layer(head_out)  # (B, T)
+
+        # Scalar proxy: the mean of tanh over features lies in [-1, 1], so the sigmoid
+        # output is confined to roughly [0.27, 0.73] rather than the full [0, 1].
+        raw_like_logits = torch.tanh(head_out).mean(dim=-1)  # (B, T)
+        return torch.sigmoid(raw_like_logits)
+
     def _encode_video_frames(self, frames: Tensor) -> Tensor:
         """Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
 