mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f147a4cd48 | |||
| c3fa269b21 | |||
| 385ba8d1b7 | |||
| f4ccf911fa | |||
| 0cb8c92fe4 |
@@ -57,6 +57,8 @@
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
- local: training_time_rtc
|
||||
title: Training-Time RTC
|
||||
title: "Inference"
|
||||
- sections:
|
||||
- local: envhub
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
# Training-Time RTC
|
||||
|
||||
Training-Time RTC teaches the model to handle inference delay during training.
|
||||
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
|
||||
This keeps chunk transitions smooth without doing any inference-time inpainting.
|
||||
|
||||
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
|
||||
|
||||
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
|
||||
|
||||
---
|
||||
|
||||
## How It Works
|
||||
|
||||
### At Training Time
|
||||
|
||||
- Sample a delay `d` per batch element.
|
||||
- Keep the first `d` action steps as **ground truth** (no noise).
|
||||
- Add noise only to the postfix actions.
|
||||
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
|
||||
- Mask the loss to only train on the postfix.
|
||||
|
||||
### At Inference Time
|
||||
|
||||
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
|
||||
|
||||
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
|
||||
- Set timestep to **1.0** for prefix positions.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=pi0 \
|
||||
--dataset.repo_id=your/dataset \
|
||||
--policy.rtc_training_config.enabled=true \
|
||||
--policy.rtc_training_config.min_delay=0 \
|
||||
--policy.rtc_training_config.max_delay=6 \
|
||||
--policy.rtc_training_config.delay_distribution=UNIFORM
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference with Training-Time RTC
|
||||
|
||||
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
|
||||
|
||||
```python
|
||||
policy = PI0Policy.from_pretrained("path/to/trained/model")
|
||||
# rtc_training_config is loaded from the saved config
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
inference_delay=5, # estimated delay in timesteps
|
||||
prev_chunk_left_over=previous_actions, # from previous chunk
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
|
||||
|
||||
- **`enabled`**: Toggle training-time RTC (both training and inference).
|
||||
- **`min_delay` / `max_delay`**: Delay range (inclusive).
|
||||
- **`delay_distribution`**:
|
||||
- `UNIFORM`: uniform in `[min_delay, max_delay]`
|
||||
- `EXP`: exponentially decayed distribution over delays
|
||||
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
|
||||
|
||||
---
|
||||
|
||||
## Notes and Recommendations
|
||||
|
||||
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
|
||||
- Use `EXP` if you want more supervision on smaller delays.
|
||||
|
||||
---
|
||||
|
||||
## Related Docs
|
||||
|
||||
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
|
||||
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
|
||||
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
|
||||
class RTCTrainingDelayDistribution(str, Enum):
|
||||
UNIFORM = "UNIFORM"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
def mlp_func(action_time_emb):
|
||||
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(
|
||||
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
|
||||
)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (
|
||||
DEFAULT_IMAGE_SIZE,
|
||||
|
||||
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
|
||||
else:
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
# Truncate losses to actual action dimensions
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ Based on:
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,3 +53,22 @@ class RTCConfig:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCTrainingConfig:
|
||||
"""Configuration for training-time RTC action prefix conditioning."""
|
||||
|
||||
enabled: bool = False
|
||||
min_delay: int = 0
|
||||
max_delay: int = 0
|
||||
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
|
||||
exp_decay: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.min_delay < 0:
|
||||
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
|
||||
if self.max_delay < self.min_delay:
|
||||
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
|
||||
if self.exp_decay <= 0:
|
||||
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
|
||||
|
||||
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
|
||||
if cfg.max_delay == cfg.min_delay:
|
||||
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
|
||||
|
||||
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
|
||||
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
|
||||
|
||||
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
|
||||
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
|
||||
probs = weights / weights.sum()
|
||||
samples = torch.multinomial(probs, batch_size, replacement=True)
|
||||
return delay_values[samples]
|
||||
|
||||
|
||||
def apply_rtc_training_time(
|
||||
time: torch.Tensor, delay: torch.Tensor, seq_len: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
device = time.device
|
||||
delay = torch.clamp(delay, max=seq_len)
|
||||
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
|
||||
time_tokens = time[:, None].expand(-1, seq_len)
|
||||
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
|
||||
postfix_mask = ~prefix_mask
|
||||
return time_tokens, postfix_mask
|
||||
|
||||
|
||||
def masked_mean(
|
||||
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
|
||||
) -> torch.Tensor:
|
||||
if mask is None:
|
||||
return losses.mean(dim=reduce_dims)
|
||||
|
||||
mask = mask.to(dtype=losses.dtype)
|
||||
while mask.dim() < losses.dim():
|
||||
mask = mask.unsqueeze(-1)
|
||||
masked = losses * mask
|
||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||
return masked.sum(dim=reduce_dims) / denom
|
||||
|
||||
|
||||
def apply_training_time_rtc_inference(
|
||||
x_t: torch.Tensor,
|
||||
time: float,
|
||||
inference_delay: int | None,
|
||||
prev_chunk_left_over: torch.Tensor | None,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply training-time RTC conditioning during inference.
|
||||
|
||||
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
|
||||
|
||||
At each denoising step:
|
||||
1. Replace prefix positions in x_t with ground truth from previous chunk
|
||||
2. Create per-token timesteps with 1.0 for prefix positions
|
||||
|
||||
Args:
|
||||
x_t: Current noisy actions (B, T, D)
|
||||
time: Current flow matching timestep (scalar)
|
||||
inference_delay: Number of prefix actions to condition on
|
||||
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
|
||||
chunk_size: Total chunk size T
|
||||
|
||||
Returns:
|
||||
x_t_conditioned: x_t with prefix replaced by previous actions
|
||||
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
device = x_t.device
|
||||
|
||||
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
|
||||
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
|
||||
return x_t, time_scalar
|
||||
|
||||
delay = min(inference_delay, chunk_size)
|
||||
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
|
||||
|
||||
x_t_conditioned = torch.where(
|
||||
prefix_mask[:, :, None].expand_as(x_t),
|
||||
prev_chunk_left_over[:, :chunk_size, :],
|
||||
x_t,
|
||||
)
|
||||
|
||||
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
|
||||
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
|
||||
|
||||
return x_t_conditioned, time_per_token
|
||||
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
# Real-Time Chunking (RTC) configurations
|
||||
rtc_config: RTCConfig | None = None
|
||||
rtc_training_config: RTCTrainingConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
if time.ndim not in (1, 2):
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
if time.ndim == 1:
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
|
||||
time_flat = time.reshape(-1)
|
||||
sin_input = scaling_factor[None, :] * time_flat[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
return pos_emb.reshape(*time.shape, dimension)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
actions = self.prepare_action(batch)
|
||||
postfix_mask = None
|
||||
rtc_cfg = self.config.rtc_training_config
|
||||
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
|
||||
batch_size = actions.shape[0]
|
||||
if time is None:
|
||||
time = self.model.sample_time(batch_size, actions.device)
|
||||
if noise is None:
|
||||
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
|
||||
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
if reduction == "none":
|
||||
# Return per-sample losses (B,) by averaging over time and action dims
|
||||
per_sample_loss = losses.mean(dim=(1, 2))
|
||||
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
|
||||
loss_dict["loss"] = per_sample_loss.mean().item()
|
||||
return per_sample_loss, loss_dict
|
||||
else:
|
||||
# Default: return scalar mean loss
|
||||
loss = losses.mean()
|
||||
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
if time_emb.dim() == 2:
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
elif time_emb.shape[:2] != action_emb.shape[:2]:
|
||||
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
if time.ndim == 1:
|
||||
time_expanded = time[:, None, None]
|
||||
elif time.ndim == 2:
|
||||
time_expanded = time[:, :, None]
|
||||
else:
|
||||
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t_cond,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for training-time RTC helpers."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCTrainingDelayDistribution
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
||||
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
|
||||
|
||||
|
||||
def test_rtc_training_config_defaults():
|
||||
config = RTCTrainingConfig()
|
||||
assert config.enabled is False
|
||||
assert config.min_delay == 0
|
||||
assert config.max_delay == 0
|
||||
assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
|
||||
assert config.exp_decay == 1.0
|
||||
|
||||
|
||||
def test_sample_rtc_delay_uniform_range():
|
||||
cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
|
||||
delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu"))
|
||||
assert delays.min().item() >= 1
|
||||
assert delays.max().item() <= 4
|
||||
|
||||
|
||||
def test_apply_rtc_training_time_prefix_mask():
|
||||
time = torch.tensor([0.5])
|
||||
delays = torch.tensor([2])
|
||||
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4)
|
||||
assert time_tokens.shape == (1, 4)
|
||||
assert postfix_mask.shape == (1, 4)
|
||||
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
|
||||
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
|
||||
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))
|
||||
Reference in New Issue
Block a user