This commit is contained in:
Pepijn
2026-01-20 20:08:28 +01:00
parent 0cb8c92fe4
commit f4ccf911fa
7 changed files with 5 additions and 18 deletions
-1
View File
@@ -60,4 +60,3 @@ lerobot-train \
- [Real-Time Chunking (Inference-Time RTC)](./rtc) - [Real-Time Chunking (Inference-Time RTC)](./rtc)
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla), [WALL-OSS](./walloss) - [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla), [WALL-OSS](./walloss)
+1 -3
View File
@@ -724,9 +724,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time_emb.dim() == 2: if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb) time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]: elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError( raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}"
)
action_time_emb = torch.cat([action_emb, time_emb], dim=2) action_time_emb = torch.cat([action_emb, time_emb], dim=2)
def mlp_func(action_time_emb): def mlp_func(action_time_emb):
+1 -3
View File
@@ -1270,9 +1270,7 @@ class PI05Policy(PreTrainedPolicy):
noise = self.model.sample_noise(actions.shape, actions.device) noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device) delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1]) time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward( losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
images, img_masks, tokens, masks, actions, noise=noise, time=time
)
else: else:
losses = self.model.forward(images, img_masks, tokens, masks, actions) losses = self.model.forward(images, img_masks, tokens, masks, actions)
@@ -69,8 +69,6 @@ class RTCTrainingConfig:
if self.min_delay < 0: if self.min_delay < 0:
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}") raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
if self.max_delay < self.min_delay: if self.max_delay < self.min_delay:
raise ValueError( raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})"
)
if self.exp_decay <= 0: if self.exp_decay <= 0:
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}") raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
+1 -4
View File
@@ -27,9 +27,7 @@ def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.devi
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long) return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM: if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
return torch.randint( return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long
)
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long) delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32)) weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
@@ -62,4 +60,3 @@ def masked_mean(
masked = losses * mask masked = losses * mask
denom = mask.sum(dim=reduce_dims).clamp_min(eps) denom = mask.sum(dim=reduce_dims).clamp_min(eps)
return masked.sum(dim=reduce_dims) / denom return masked.sum(dim=reduce_dims) / denom
@@ -751,9 +751,7 @@ class VLAFlowMatching(nn.Module):
if time_emb.dim() == 2: if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb) time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]: elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError( raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}"
)
action_time_emb = torch.cat([action_emb, time_emb], dim=2) action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb) action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -48,4 +48,3 @@ def test_apply_rtc_training_time_prefix_mask():
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix. # Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5])) assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True])) assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))