mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
cleanup
This commit is contained in:
@@ -184,13 +184,22 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# Layer normalization before reward head to stabilize MLP outputs
|
||||
self.pre_reward_norm = nn.LayerNorm(config.dim_model)
|
||||
|
||||
# Regression head (logit mode only)
|
||||
self.reward_head = nn.Linear(config.dim_model, 1)
|
||||
# Temporal-aware regression head (logit mode only)
|
||||
# Concatenates frame embedding with normalized temporal position
|
||||
self.reward_head = nn.Sequential(
|
||||
nn.Linear(config.dim_model + 1, config.dim_model), # +1 for temporal position
|
||||
nn.ReLU(),
|
||||
nn.Linear(config.dim_model, 1)
|
||||
)
|
||||
|
||||
# Initialize head for logit regression
|
||||
# Initialize temporal-aware head for logit regression
|
||||
with torch.no_grad():
|
||||
self.reward_head.weight.normal_(0.0, config.head_weight_init_std)
|
||||
self.reward_head.bias.fill_(0.0)
|
||||
# First layer: embedding + position -> embedding
|
||||
nn.init.normal_(self.reward_head[0].weight, 0.0, config.head_weight_init_std)
|
||||
nn.init.zeros_(self.reward_head[0].bias)
|
||||
# Output layer: embedding -> logit
|
||||
nn.init.normal_(self.reward_head[2].weight, 0.0, config.head_weight_init_std)
|
||||
nn.init.zeros_(self.reward_head[2].bias)
|
||||
|
||||
# Simple frame dropout probability
|
||||
self.frame_dropout_p = config.frame_dropout_p
|
||||
@@ -309,9 +318,19 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
# MLP predictor
|
||||
video_frame_embeds = self.mlp_predictor(frame_tokens)
|
||||
|
||||
# Get rewards via logit regression head
|
||||
# Get rewards via temporal-aware logit regression head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T)
|
||||
|
||||
# Add temporal position information
|
||||
B, T_pred = normalized_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_pred, device=normalized_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_pred, 1) # (B, T, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T)
|
||||
return torch.sigmoid(raw_logits) # Apply sigmoid for final predictions
|
||||
|
||||
def normalize_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
@@ -480,17 +499,35 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
|
||||
# During inference, we might not want to compute loss
|
||||
if not self.training and target is None:
|
||||
# Return predictions without loss
|
||||
# Return predictions without loss using temporal-aware head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
rewards = self.sigmoid(self.reward_head(normalized_embeds)).squeeze(-1)
|
||||
|
||||
# Add temporal position information
|
||||
B_inf, T_inf = normalized_embeds.shape[:2]
|
||||
temporal_pos = torch.linspace(0, 1, T_inf, device=normalized_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B_inf, T_inf, 1)
|
||||
|
||||
# Concatenate and forward through temporal-aware head
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1)
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1)
|
||||
rewards = torch.sigmoid(raw_logits)
|
||||
return rewards.mean() * 0.0, {"rewards_mean": rewards.mean().item()}
|
||||
|
||||
# Calculate loss using logit regression
|
||||
loss_start = time.perf_counter()
|
||||
|
||||
# Get model outputs
|
||||
# Get model outputs with temporal-aware head
|
||||
normalized_embeds = self.pre_reward_norm(video_frame_embeds)
|
||||
raw_logits = self.reward_head(normalized_embeds).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# Add temporal position information
|
||||
temporal_pos = torch.linspace(0, 1, T_eff, device=normalized_embeds.device)
|
||||
temporal_pos = temporal_pos.unsqueeze(0).unsqueeze(-1).expand(B, T_eff, 1) # (B, T_eff, 1)
|
||||
|
||||
# Concatenate embeddings with temporal position
|
||||
temporal_input = torch.cat([normalized_embeds, temporal_pos], dim=-1) # (B, T_eff, D+1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
raw_logits = self.reward_head(temporal_input).squeeze(-1) # (B, T_eff)
|
||||
|
||||
# Logit regression: transform targets to logit space and compute MSE on logits
|
||||
eps = self.config.logit_eps
|
||||
@@ -537,9 +574,19 @@ class RLearNPolicy(PreTrainedPolicy):
|
||||
mismatch_tokens = self.frame_mlp(attended_video_mm) # (B, T, D)
|
||||
mismatch_embeds = self.mlp_predictor(mismatch_tokens)
|
||||
|
||||
# Predict near-zero progress for mismatched pairs
|
||||
# Predict near-zero progress for mismatched pairs with temporal awareness
|
||||
normalized_mismatch_embeds = self.pre_reward_norm(mismatch_embeds)
|
||||
mismatch_raw_logits = self.reward_head(normalized_mismatch_embeds).squeeze(-1)
|
||||
|
||||
# Add temporal position information for mismatch computation
|
||||
T_mismatch = normalized_mismatch_embeds.shape[1]
|
||||
temporal_pos_mm = torch.linspace(0, 1, T_mismatch, device=normalized_mismatch_embeds.device)
|
||||
temporal_pos_mm = temporal_pos_mm.unsqueeze(0).unsqueeze(-1).expand(B, T_mismatch, 1)
|
||||
|
||||
# Concatenate mismatch embeddings with temporal position
|
||||
temporal_input_mm = torch.cat([normalized_mismatch_embeds, temporal_pos_mm], dim=-1)
|
||||
|
||||
# Forward through temporal-aware head
|
||||
mismatch_raw_logits = self.reward_head(temporal_input_mm).squeeze(-1)
|
||||
|
||||
# Create mask tensor for loss calculation
|
||||
mismatch_tensor = torch.tensor(mismatch_mask, device=device, dtype=torch.bool)
|
||||
|
||||
Reference in New Issue
Block a user