From c9243c29b02304566b78a095cac8ba9d4832e5ff Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sun, 31 Aug 2025 16:34:46 +0200 Subject: [PATCH] cleanup --- .../policies/rlearn/modeling_rlearn.py | 73 +++++++++++++++---- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index ae66096fd..ca3053dae 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -184,13 +184,22 @@ class RLearNPolicy(PreTrainedPolicy): # Layer normalization before reward head to stabilize MLP outputs self.pre_reward_norm = nn.LayerNorm(config.dim_model) - # Regression head (logit mode only) - self.reward_head = nn.Linear(config.dim_model, 1) + # Temporal-aware regression head (logit mode only) + # Concatenates frame embedding with normalized temporal position + self.reward_head = nn.Sequential( + nn.Linear(config.dim_model + 1, config.dim_model), # +1 for temporal position + nn.ReLU(), + nn.Linear(config.dim_model, 1) + ) - # Initialize head for logit regression + # Initialize temporal-aware head for logit regression with torch.no_grad(): - self.reward_head.weight.normal_(0.0, config.head_weight_init_std) - self.reward_head.bias.fill_(0.0) + # First layer: embedding + position -> embedding + nn.init.normal_(self.reward_head[0].weight, 0.0, config.head_weight_init_std) + nn.init.zeros_(self.reward_head[0].bias) + # Output layer: embedding -> logit + nn.init.normal_(self.reward_head[2].weight, 0.0, config.head_weight_init_std) + nn.init.zeros_(self.reward_head[2].bias) # Simple frame dropout probability self.frame_dropout_p = config.frame_dropout_p @@ -309,9 +318,19 @@ class RLearNPolicy(PreTrainedPolicy): # MLP predictor video_frame_embeds = self.mlp_predictor(frame_tokens) - # Get rewards via logit regression head + # Get rewards via temporal-aware logit regression head normalized_embeds = self.pre_reward_norm(video_frame_embeds) - raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T) + + # Add temporal position information + B, T_pred = normalized_embeds.shape[:2] + temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device) + temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1) + + # Concatenate embeddings with temporal position + temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T, D+1) + + # Forward through temporal-aware head + raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T) return torch.sigmoid(raw_logits) # Apply sigmoid for final predictions def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: @@ -480,17 +499,35 @@ class RLearNPolicy(PreTrainedPolicy): # During inference, we might not want to compute loss if not self.training and target is None: - # Return predictions without loss + # Return predictions without loss using temporal-aware head normalized_embeds = self.pre_reward_norm(video_frame_embeds) - rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1) + + # Add temporal position information + B_inf, T_inf = normalized_embeds.shape[:2] + temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device) + temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1) + + # Concatenate and forward through temporal-aware head + temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) + raw_logits = self.reward_head(temporal_input).squeeze(-1) + rewards = torch.sigmoid(raw_logits) return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()} # Calculate loss using logit regression loss_start = time.perf_counter() - # Get model outputs + # Get model outputs with temporal-aware head normalized_embeds = self.pre_reward_norm(video_frame_embeds) - raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff) + + # Add temporal position information + temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device) + temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1) + + # Concatenate embeddings with temporal position + temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1) + + # Forward through temporal-aware head + raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff) # Logit regression: transform targets to logit space and compute MSE on logits eps = self.config.logit_eps @@ -537,9 +574,19 @@ class RLearNPolicy(PreTrainedPolicy): mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D) mismatch_embeds = self.mlp_predictor(mismatch_tokens) - # Predict near-zero progress for mismatched pairs + # Predict near-zero progress for mismatched pairs with temporal awareness normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds) - mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1) + + # Add temporal position information for mismatch computation + T_mismatch = normalized_mismatch_embeds.shape[1] + temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device) + temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1) + + # Concatenate mismatch embeddings with temporal position + temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1) + + # Forward through temporal-aware head + mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1) # Create mask tensor for loss calculation mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)