This commit is contained in:
Pepijn
2025-09-01 14:57:24 +02:00
parent 88116b11e1
commit cf9796b2f7
@@ -182,6 +182,69 @@ class RLearNPolicy(PreTrainedPolicy):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
    """Reject action selection; the method exists only because the base class requires it.

    Args:
        batch: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = "RLearN is a reward model and does not select actions"
    raise NotImplementedError(message)
@torch.no_grad()
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
    """Inference helper used by eval_script: return (B, T) per-frame rewards.

    Args:
        batch: Expects OBS_IMAGES shaped (B, T, C, H, W) or a list of length T
            with (B, C, H, W) entries, plus an optional OBS_LANGUAGE entry.
            OBS_LANGUAGE may be a list/tuple of strings (used as-is, one per
            sample) or a single value, which is stringified and repeated for
            every sample. When it is missing, empty strings are used.

    Returns:
        Per-frame reward estimates, one value per (sample, frame).

    Raises:
        ValueError: If the extracted frame tensor is not 5-dimensional
            (raised by the shape unpacking below).
    """
    device = next(self.parameters()).device

    # Prepare frames and language
    frames = extract_visual_sequence(batch, target_seq_len=None).to(device)  # (B, T, C, H, W)
    # Full 5-way unpack is kept on purpose: it rejects tensors of the wrong rank.
    B, T, _, _, _ = frames.shape

    commands = batch.get(OBS_LANGUAGE)
    if commands is None:
        commands = [""] * B
    elif isinstance(commands, (list, tuple)):
        # Already one command per sample; a tuple must not be stringified whole.
        commands = list(commands)
    else:
        commands = [str(commands)] * B

    # Encode
    video_frame_embeds = self._encode_video_frames(frames)  # (B, T, D_v)
    lang_embeds, mask = self._encode_language_tokens(commands, device)  # (B, L, D_l), (B, L)

    # Project and add first-frame bias
    video_tokens = self.to_video_tokens(video_frame_embeds)  # (B, T, D)
    video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
    lang_tokens = self.to_lang_tokens(lang_embeds)  # (B, L, D)

    # Build masks and run transformer. Language tokens come first in the
    # sequence, followed by the video tokens; every video position is valid.
    lang_valid = mask
    video_valid = torch.ones(B, T, device=device, dtype=torch.bool)
    # Inverted validity: True marks positions to be ignored
    # (assumes the aggregator follows the torch.nn.TransformerEncoder mask
    # convention -- TODO confirm).
    key_padding_mask = ~(torch.cat([lang_valid, video_valid], dim=1))
    tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1)  # (B, S, D)
    S = tokens_seq.shape[1]
    # True strictly above the diagonal, i.e. at positions j > i.
    causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1)
    attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask)

    # Drop the language prefix so only the per-frame outputs remain.
    L_len = lang_tokens.shape[1]
    attended_video = attended_all[:, L_len:, :]  # (B, T, D)
    frame_tokens = self.frame_mlp(attended_video)  # (B, T, D)

    if self.use_categorical:
        logits = self.reward_head(frame_tokens)  # (B, T, Bins)
        probs = torch.softmax(logits, dim=-1)
        # Expected value over bin centers in [reward_min_value, reward_max_value]
        bin_centers = torch.linspace(
            float(self.config.reward_min_value),
            float(self.config.reward_max_value),
            int(self.config.num_reward_bins),
            device=device,
        )
        return (probs * bin_centers).sum(dim=-1)  # (B, T)

    # HL-Gauss continuous or regression fallback
    head_out = self.reward_head(frame_tokens)  # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head
    if self.hl_gauss_layer is not None:
        # The layer is applied the same way whether or not
        # `hl_gauss_use_regression` is set, so the flag is not consulted here.
        return self.hl_gauss_layer(head_out)  # (B, T)

    # Scalar proxy via mean over features, then sigmoid.
    # NOTE(review): tanh bounds the mean to [-1, 1], so the sigmoid output lies
    # in roughly [0.27, 0.73] rather than spanning the full (0, 1) range.
    raw_like_logits = torch.tanh(head_out).mean(dim=-1)  # (B, T)
    return torch.sigmoid(raw_like_logits)
def _encode_video_frames(self, frames: Tensor) -> Tensor:
"""Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.