Compare commits

...

5 Commits

Author SHA1 Message Date
Pepijn f147a4cd48 Add inference for training time rtc 2026-01-29 11:05:42 +01:00
Pepijn c3fa269b21 Merge branch 'main' into feat/training_time_rtc 2026-01-27 17:34:56 +01:00
Pepijn 385ba8d1b7 remove wall-oss from doc links 2026-01-20 20:11:56 +01:00
Pepijn f4ccf911fa format 2026-01-20 20:08:28 +01:00
Pepijn 0cb8c92fe4 Implement training time rtc for pi0, pi0.5 and smolvla 2026-01-20 20:02:10 +01:00
12 changed files with 490 additions and 61 deletions
+2
View File
@@ -57,6 +57,8 @@
title: Use Async Inference title: Use Async Inference
- local: rtc - local: rtc
title: Real-Time Chunking (RTC) title: Real-Time Chunking (RTC)
- local: training_time_rtc
title: Training-Time RTC
title: "Inference" title: "Inference"
- sections: - sections:
- local: envhub - local: envhub
+86
View File
@@ -0,0 +1,86 @@
# Training-Time RTC
Training-Time RTC teaches the model to handle inference delay during training.
It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions.
This keeps chunk transitions smooth without doing any inference-time inpainting.
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters.
---
## How It Works
### At Training Time
- Sample a delay `d` per batch element.
- Keep the first `d` action steps as **ground truth** (no noise).
- Add noise only to the postfix actions.
- Set the flow-matching timestep to **0.0** (the clean-action timestep, since `x_t = t * noise + (1 - t) * actions`) for prefix tokens and normal timesteps for postfix tokens.
- Mask the loss to only train on the postfix.
### At Inference Time
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
- Set the timestep for prefix positions to the same clean-action timestep the model saw on the ground-truth prefix during training.
---
## Quick Start (CLI)
```bash
lerobot-train \
--policy.type=pi0 \
--dataset.repo_id=your/dataset \
--policy.rtc_training_config.enabled=true \
--policy.rtc_training_config.min_delay=0 \
--policy.rtc_training_config.max_delay=6 \
--policy.rtc_training_config.delay_distribution=UNIFORM
```
---
## Inference with Training-Time RTC
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
```python
policy = PI0Policy.from_pretrained("path/to/trained/model")
# rtc_training_config is loaded from the saved config
actions = policy.predict_action_chunk(
batch,
inference_delay=5, # estimated delay in timesteps
prev_chunk_left_over=previous_actions, # from previous chunk
)
```
---
## Key Parameters
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
- **`enabled`**: Toggle training-time RTC (both training and inference).
- **`min_delay` / `max_delay`**: Delay range (inclusive).
- **`delay_distribution`**:
- `UNIFORM`: uniform in `[min_delay, max_delay]`
- `EXP`: exponentially decayed distribution over delays
- **`exp_decay`**: Exponential decay factor for `EXP` sampling.
---
## Notes and Recommendations
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
- Use `EXP` if you want more supervision on smaller delays.
---
## Related Docs
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
+5
View File
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
ONES = "ONES" ONES = "ONES"
LINEAR = "LINEAR" LINEAR = "LINEAR"
EXP = "EXP" EXP = "EXP"
class RTCTrainingDelayDistribution(str, Enum):
    """Distribution used to sample the per-batch-element RTC training delay."""

    # Every delay in [min_delay, max_delay] is equally likely.
    UNIFORM = "UNIFORM"
    # Probability decays exponentially with the delay value (see `exp_decay`).
    EXP = "EXP"
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3 min_period: float = 4e-3
max_period: float = 4.0 max_period: float = 4.0
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = ( image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE,
+74 -19
View File
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0: if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2") raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1: if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type) dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product # Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None] if time.ndim == 1:
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self): def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs): def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled.""" """Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training: if self.gradient_checkpointing_enabled and self.training:
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb) if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2) action_time_emb = torch.cat([action_emb, time_emb], dim=2)
def mlp_func(action_time_emb): def mlp_func(action_time_emb):
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None: if time is None:
time = self.sample_time(actions.shape[0], actions.device) time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None] if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise x_t = noise
for step in range(num_steps): for step in range(num_steps):
time = 1.0 + step * dt time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): if use_training_time_rtc:
return self.denoise_step( x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
state=state, state=state,
prefix_pad_masks=prefix_pad_masks, prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values, past_key_values=past_key_values,
x_t=input_x_t, x_t=x_t_cond,
timestep=current_timestep, timestep=time_tensor,
) )
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled(): def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
inference_delay = kwargs.get("inference_delay") return self.denoise_step(
prev_chunk_left_over = kwargs.get("prev_chunk_left_over") state=state,
execution_horizon = kwargs.get("execution_horizon") prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step( v_t = self.rtc_processor.denoise_step(
x_t=x_t, x_t=x_t,
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon, execution_horizon=execution_horizon,
) )
else: else:
v_t = denoise_step_partial_call(x_t) time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t x_t = x_t + dt * v_t
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Compute loss # Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
)
else:
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
if reduction == "none": if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims # Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2)) per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item() loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict return per_sample_loss, loss_dict
else: else:
# Default: return scalar mean loss # Default: return scalar mean loss
loss = losses.mean() loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = ( image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE,
+66 -18
View File
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0: if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2") raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1: if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type) dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product # Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None] if time.ndim == 1:
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self): def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs): def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled.""" """Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training: if self.gradient_checkpointing_enabled and self.training:
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None: if time is None:
time = self.sample_time(actions.shape[0], actions.device) time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None] if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise x_t = noise
for step in range(num_steps): for step in range(num_steps):
time = 1.0 + step * dt time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): if use_training_time_rtc:
return self.denoise_step( x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks, prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values, past_key_values=past_key_values,
x_t=input_x_t, x_t=x_t_cond,
timestep=current_timestep, timestep=time_tensor,
) )
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled(): def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
inference_delay = kwargs.get("inference_delay") return self.denoise_step(
prev_chunk_left_over = kwargs.get("prev_chunk_left_over") prefix_pad_masks=prefix_pad_masks,
execution_horizon = kwargs.get("execution_horizon") past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step( v_t = self.rtc_processor.denoise_step(
x_t=x_t, x_t=x_t,
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon, execution_horizon=execution_horizon,
) )
else: else:
v_t = denoise_step_partial_call(x_t) time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t x_t = x_t + dt * v_t
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05) # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions) postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
else:
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
if reduction == "none": if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims # Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2)) per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item() loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict return per_sample_loss, loss_dict
else: else:
# Default: return scalar mean loss # Default: return scalar mean loss
loss = losses.mean() loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
+20 -1
View File
@@ -23,7 +23,7 @@ Based on:
from dataclasses import dataclass from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
@dataclass @dataclass
@@ -53,3 +53,22 @@ class RTCConfig:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}") raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0: if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}") raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
@dataclass
class RTCTrainingConfig:
    """Configuration for training-time RTC action prefix conditioning."""

    # Master switch for training-time RTC (affects both training and inference).
    enabled: bool = False
    # Inclusive lower bound of the sampled prefix delay.
    min_delay: int = 0
    # Inclusive upper bound of the sampled prefix delay.
    max_delay: int = 0
    # How delays are drawn from [min_delay, max_delay].
    delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
    # Decay factor for EXP sampling; larger values favor smaller delays.
    exp_decay: float = 1.0

    def __post_init__(self):
        # Validate eagerly so a bad config fails at construction, not mid-training.
        if self.min_delay < 0:
            raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
        if self.max_delay < self.min_delay:
            raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
        if self.exp_decay <= 0:
            raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
+110
View File
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
    """Sample one inference-delay value per batch element.

    Args:
        cfg: Training-time RTC configuration (delay range and distribution).
        batch_size: Number of delays to draw.
        device: Device on which to allocate the result.

    Returns:
        Long tensor of shape (batch_size,) with values in [cfg.min_delay, cfg.max_delay].

    Raises:
        ValueError: If cfg.delay_distribution is not a supported distribution.
    """
    # Degenerate range: every element gets the same delay, no sampling needed.
    if cfg.max_delay == cfg.min_delay:
        return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
    if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
        # randint's upper bound is exclusive, hence max_delay + 1 for an inclusive range.
        return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
    if cfg.delay_distribution == RTCTrainingDelayDistribution.EXP:
        # Exponentially decaying categorical distribution: smaller delays more likely.
        delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
        weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
        probs = weights / weights.sum()
        samples = torch.multinomial(probs, batch_size, replacement=True)
        return delay_values[samples]
    # Previously any unknown distribution silently fell through to EXP sampling;
    # fail loudly instead so config mistakes are caught.
    raise ValueError(f"Unsupported delay_distribution: {cfg.delay_distribution}")
def apply_rtc_training_time(
    time: torch.Tensor, delay: torch.Tensor, seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Expand per-sample timesteps into per-token timesteps with an RTC prefix.

    The first `delay[b]` tokens of sample `b` form the ground-truth prefix: their
    timestep is forced to 0.0 (the clean-action timestep, since
    `x_t = t * noise + (1 - t) * actions`), so no noise is mixed in. All other
    (postfix) tokens keep the sampled timestep and receive loss.

    Args:
        time: Per-sample flow-matching timesteps, shape (B,).
        delay: Per-sample prefix lengths, shape (B,); clamped to seq_len.
        seq_len: Number of action tokens per chunk.

    Returns:
        Tuple of per-token timesteps (B, seq_len) and a postfix mask
        (B, seq_len) that is True where the loss should apply.
    """
    clamped_delay = delay.clamp(max=seq_len)
    token_idx = torch.arange(seq_len, device=time.device)
    in_prefix = token_idx.unsqueeze(0) < clamped_delay.unsqueeze(1)
    per_token_time = time.unsqueeze(1).expand(-1, seq_len).masked_fill(in_prefix, 0.0)
    return per_token_time, ~in_prefix
def masked_mean(
    losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
) -> torch.Tensor:
    """Mean of `losses` over `reduce_dims`, counting only positions where `mask` is True.

    With `mask=None` this is exactly `losses.mean(dim=reduce_dims)`. With a mask,
    the result is sum(losses * mask) / (number of unmasked *elements*), keeping
    the scale consistent with the unmasked path. (The previous version summed the
    un-expanded (B, T, 1) mask for the denominator, so it divided by unmasked
    timesteps only and inflated the masked loss by the action-dim size.)

    Args:
        losses: Loss tensor, e.g. (B, T, D).
        mask: Boolean mask broadcastable to `losses` from the leading dims
            (e.g. (B, T)), or None for a plain mean.
        reduce_dims: Dimensions to reduce over.
        eps: Lower bound on the denominator so a sample with zero unmasked
            elements yields ~0 instead of NaN.

    Returns:
        Reduced loss tensor.
    """
    if mask is None:
        return losses.mean(dim=reduce_dims)
    mask = mask.to(dtype=losses.dtype)
    while mask.dim() < losses.dim():
        mask = mask.unsqueeze(-1)
    # Expand so the denominator counts every covered element, not one per broadcast row.
    mask = mask.expand_as(losses)
    denom = mask.sum(dim=reduce_dims).clamp_min(eps)
    return (losses * mask).sum(dim=reduce_dims) / denom
def apply_training_time_rtc_inference(
    x_t: torch.Tensor,
    time: float,
    inference_delay: int | None,
    prev_chunk_left_over: torch.Tensor | None,
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply training-time RTC conditioning during inference.

    Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient
    Real-Time Chunking". At each denoising step:

    1. Replace the first `delay` positions of `x_t` with the previous chunk's
       leftover actions (ground truth the robot is already executing).
    2. Build per-token timesteps where prefix positions get the clean-action
       timestep **0.0**, matching `apply_rtc_training_time`, which trains the
       model on the ground-truth prefix with timestep 0.0 under the convention
       `x_t = t * noise + (1 - t) * actions` (t == 0 is clean). The previous
       version used 1.0 here — the pure-noise timestep — contradicting what the
       model saw during training.

    Args:
        x_t: Current noisy actions (B, T, D).
        time: Current flow-matching timestep (scalar).
        inference_delay: Number of prefix actions to condition on, or None.
        prev_chunk_left_over: Previous chunk's leftover actions (B, T', D), or None.
        chunk_size: Total chunk size T.

    Returns:
        Tuple of (x_t with the prefix replaced by previous actions,
        per-token timesteps of shape (B, T) — or (B,) when not conditioning).
    """
    batch_size = x_t.shape[0]
    device = x_t.device
    if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
        # Nothing to condition on: return x_t untouched with a plain per-sample timestep.
        time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
        return x_t, time_scalar
    # Clamp to both the chunk size and what the previous chunk actually provides,
    # so a short leftover buffer cannot cause a shape mismatch (the old
    # `prev_chunk_left_over[:, :chunk_size, :]` slice required T' >= chunk_size).
    delay = min(inference_delay, chunk_size, prev_chunk_left_over.shape[1])
    x_t_conditioned = x_t.clone()
    x_t_conditioned[:, :delay] = prev_chunk_left_over[:, :delay].to(dtype=x_t.dtype, device=device)
    time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
    time_per_token[:, :delay] = 0.0
    return x_t_conditioned, time_per_token
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import ( from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig, CosineDecayWithWarmupSchedulerConfig,
) )
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import OBS_IMAGES from lerobot.utils.constants import OBS_IMAGES
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0 max_period: float = 4.0
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import ( from lerobot.policies.utils import (
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
if dimension % 2 != 0: if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2") raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1: if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type) dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product # Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None] if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb return pos_emb.reshape(*time.shape, dimension)
def make_att_2d_masks(pad_masks, att_masks): def make_att_2d_masks(pad_masks, att_masks):
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
if time is None:
time = self.model.sample_time(batch_size, actions.device)
if noise is None:
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
actions_is_pad = batch.get("actions_id_pad") actions_is_pad = batch.get("actions_id_pad")
loss_dict = {} loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
in_episode_bound = ~actions_is_pad in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1) losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone() loss_dict["losses_after_in_ep_bound"] = losses.clone()
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
# Remove padding # Remove padding
losses = losses[:, :, : self.config.max_action_dim] losses = losses[:, :, : self.config.max_action_dim]
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
if reduction == "none": if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims # Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2)) per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item() loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict return per_sample_loss, loss_dict
else: else:
# Default: return scalar mean loss # Default: return scalar mean loss
loss = losses.mean() loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item() loss_dict["loss"] = loss.item()
return loss, loss_dict return loss, loss_dict
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self): def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self): def set_requires_grad(self):
for params in self.state_proj.parameters(): for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj params.requires_grad = self.config.train_state_proj
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
) )
time_emb = time_emb.type(dtype=dtype) time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb) if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2) action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb) action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
if time is None: if time is None:
time = self.sample_time(actions.shape[0], actions.device) time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None] if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps num_steps = self.config.num_steps
dt = -1.0 / num_steps dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise x_t = noise
for step in range(num_steps): for step in range(num_steps):
time = 1.0 + step * dt time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): if use_training_time_rtc:
return self.denoise_step( x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t=input_x_t, x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
prefix_pad_masks=prefix_pad_masks, prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values, past_key_values=past_key_values,
timestep=current_timestep, timestep=time_tensor,
) )
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled(): def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
inference_delay = kwargs.get("inference_delay") return self.denoise_step(
prev_chunk_left_over = kwargs.get("prev_chunk_left_over") x_t=input_x_t,
execution_horizon = kwargs.get("execution_horizon") prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step( v_t = self.rtc_processor.denoise_step(
x_t=x_t, x_t=x_t,
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon, execution_horizon=execution_horizon,
) )
else: else:
v_t = denoise_step_partial_call(x_t) time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
x_t = x_t + dt * v_t x_t = x_t + dt * v_t
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for training-time RTC helpers."""
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
def test_rtc_training_config_defaults():
    """A freshly constructed RTCTrainingConfig is disabled with a zero-width delay window."""
    defaults = RTCTrainingConfig()
    # RTC training is opt-in: the flag must be off by default.
    assert defaults.enabled is False
    # Delay window collapses to [0, 0] when unconfigured.
    assert (defaults.min_delay, defaults.max_delay) == (0, 0)
    # Uniform sampling over the delay window is the default strategy.
    assert defaults.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
    assert defaults.exp_decay == 1.0
def test_sample_rtc_delay_uniform_range():
    """Delays drawn under a uniform [1, 4] config must stay inside that closed range."""
    config = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
    sampled = sample_rtc_delay(config, batch_size=100, device=torch.device("cpu"))
    # Every sampled delay lies within the configured inclusive bounds.
    lowest = sampled.min().item()
    highest = sampled.max().item()
    assert lowest >= 1
    assert highest <= 4
def test_apply_rtc_training_time_prefix_mask():
    """With delay=2 over 4 steps, the first two positions become prefix and the rest postfix."""
    base_time = torch.tensor([0.5])
    delay = torch.tensor([2])
    per_token_time, postfix = apply_rtc_training_time(base_time, delay, seq_len=4)
    # Both outputs are expanded to one value per (batch, action-step) position.
    assert per_token_time.shape == (1, 4)
    assert postfix.shape == (1, 4)
    # Prefix steps (the first `delay` positions) get their timestep forced to 0.0,
    # matching the x_t = t*noise + (1-t)*actions convention where t=0 is ground truth;
    # postfix steps keep the sampled timestep and are the only ones marked trainable.
    expected_time = torch.tensor([0.0, 0.0, 0.5, 0.5])
    expected_mask = torch.tensor([False, False, True, True])
    assert torch.allclose(per_token_time[0], expected_time)
    assert torch.equal(postfix[0], expected_mask)