mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
fix eval
This commit is contained in:
@@ -182,6 +182,69 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:  # Required by base class
    """Refuse to produce an action.

    RLearN is a reward model, so this method exists only because the base
    class requires it; calling it always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "RLearN is a reward model and does not select actions"
    raise NotImplementedError(reason)
|
||||
|
||||
@torch.no_grad()
def predict_rewards(self, batch: dict[str, Tensor]) -> Tensor:
    """Inference helper used by eval_script: returns (B, T) per-frame rewards.

    Expects batch with OBS_IMAGES shaped (B, T, C, H, W) or list of length T with (B, C, H, W),
    and optional OBS_LANGUAGE as a list/tuple of B strings or a single string.
    A missing instruction becomes an empty string for every sample; a single
    instruction is repeated for every sample.

    Args:
        batch: Observation dictionary holding the frames and, optionally, the
            language instruction(s).

    Returns:
        Tensor of per-frame rewards, one value per (sample, frame).

    Raises:
        ValueError: if OBS_LANGUAGE is a list/tuple whose length differs from
            the batch size of the frames.
    """
    device = next(self.parameters()).device

    # Prepare frames and language
    frames = extract_visual_sequence(batch, target_seq_len=None).to(device)  # (B, T, C, H, W)
    # Only B and T are used below; unpacking all five dims still rejects non-5-D input.
    B, T, _, _, _ = frames.shape

    commands = batch.get(OBS_LANGUAGE, None)
    if commands is None:
        commands = [""] * B
    elif isinstance(commands, (list, tuple)):
        # A tuple of instructions is a per-sample sequence too; it must not be
        # stringified as a whole and repeated.
        commands = list(commands)
        if len(commands) != B:
            raise ValueError(
                f"Expected {B} language commands (one per sample), got {len(commands)}"
            )
    else:
        commands = [str(commands)] * B

    # Encode
    video_frame_embeds = self._encode_video_frames(frames)  # (B, T, D_v)
    lang_embeds, mask = self._encode_language_tokens(commands, device)  # (B, L, D_l), (B, L)

    # Project and add first-frame bias
    video_tokens = self.to_video_tokens(video_frame_embeds)  # (B, T, D)
    video_tokens[:, :1, :] = video_tokens[:, :1, :] + self.first_frame_pos
    lang_tokens = self.to_lang_tokens(lang_embeds)  # (B, L, D)

    # Build masks and run transformer
    lang_valid = mask
    video_valid = torch.ones(B, T, device=device, dtype=torch.bool)
    # The padding mask flags the positions to ignore, hence the inversion of "valid".
    key_padding_mask = ~(torch.cat([lang_valid, video_valid], dim=1))
    tokens_seq = torch.cat([lang_tokens, video_tokens], dim=1)  # (B, S, D)
    S = tokens_seq.shape[1]
    # Strictly upper-triangular mask: each position may not attend to later positions.
    causal_mask = torch.triu(torch.ones(S, S, device=device, dtype=torch.bool), diagonal=1)

    attended_all = self.aggregator(tokens_seq, src_key_padding_mask=key_padding_mask, mask=causal_mask)
    # Language tokens come first in the sequence; keep only the video positions.
    L_len = lang_tokens.shape[1]
    attended_video = attended_all[:, L_len:, :]  # (B, T, D)

    frame_tokens = self.frame_mlp(attended_video)  # (B, T, D)

    if self.use_categorical:
        logits = self.reward_head(frame_tokens)  # (B, T, Bins)
        probs = torch.softmax(logits, dim=-1)
        # Expected value over bin centers in [reward_min_value, reward_max_value]
        bin_centers = torch.linspace(
            float(self.config.reward_min_value),
            float(self.config.reward_max_value),
            int(self.config.num_reward_bins),
            device=device,
        )
        return (probs * bin_centers).sum(dim=-1)  # (B, T)

    # HL-Gauss continuous or regression fallback
    head_out = self.reward_head(frame_tokens)  # (B, T, Bins) for HL-Gauss or (B,T,*) for regression head
    if self.hl_gauss_layer is not None:
        # The layer is applied the same way whether or not
        # `hl_gauss_use_regression` is set, so no flag check is needed here.
        return self.hl_gauss_layer(head_out)  # (B, T)

    # Scalar proxy via mean over features, then sigmoid to [0,1]
    raw_like_logits = torch.tanh(head_out).mean(dim=-1)  # (B, T)
    return torch.sigmoid(raw_like_logits)
|
||||
|
||||
def _encode_video_frames(self, frames: Tensor) -> Tensor:
|
||||
"""Encode video frames through SigLIP2 vision tower and return per-frame CLS embeddings.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user