mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
format
This commit is contained in:
@@ -60,4 +60,3 @@ lerobot-train \
|
|||||||
|
|
||||||
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
||||||
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla), [WALL-OSS](./walloss)
|
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla), [WALL-OSS](./walloss)
|
||||||
|
|
||||||
|
|||||||
@@ -724,9 +724,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
if time_emb.dim() == 2:
|
if time_emb.dim() == 2:
|
||||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||||
raise ValueError(
|
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||||
f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}"
|
|
||||||
)
|
|
||||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||||
|
|
||||||
def mlp_func(action_time_emb):
|
def mlp_func(action_time_emb):
|
||||||
|
|||||||
@@ -1270,9 +1270,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||||
losses = self.model.forward(
|
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
|
||||||
images, img_masks, tokens, masks, actions, noise=noise, time=time
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||||
|
|
||||||
|
|||||||
@@ -69,8 +69,6 @@ class RTCTrainingConfig:
|
|||||||
if self.min_delay < 0:
|
if self.min_delay < 0:
|
||||||
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
||||||
if self.max_delay < self.min_delay:
|
if self.max_delay < self.min_delay:
|
||||||
raise ValueError(
|
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
|
||||||
f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})"
|
|
||||||
)
|
|
||||||
if self.exp_decay <= 0:
|
if self.exp_decay <= 0:
|
||||||
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
||||||
|
|||||||
@@ -27,9 +27,7 @@ def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.devi
|
|||||||
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
||||||
|
|
||||||
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
||||||
return torch.randint(
|
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
|
||||||
cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
||||||
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
||||||
@@ -62,4 +60,3 @@ def masked_mean(
|
|||||||
masked = losses * mask
|
masked = losses * mask
|
||||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||||
return masked.sum(dim=reduce_dims) / denom
|
return masked.sum(dim=reduce_dims) / denom
|
||||||
|
|
||||||
|
|||||||
@@ -751,9 +751,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
if time_emb.dim() == 2:
|
if time_emb.dim() == 2:
|
||||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||||
raise ValueError(
|
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||||
f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}"
|
|
||||||
)
|
|
||||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||||
|
|
||||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||||
|
|||||||
@@ -48,4 +48,3 @@ def test_apply_rtc_training_time_prefix_mask():
|
|||||||
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
|
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
|
||||||
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
|
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
|
||||||
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
|
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user