Compare commits

...

5 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Pepijn | f147a4cd48 | Add inference for training time rtc | 2026-01-29 11:05:42 +01:00 |
| Pepijn | c3fa269b21 | Merge branch 'main' into feat/training_time_rtc | 2026-01-27 17:34:56 +01:00 |
| Pepijn | 385ba8d1b7 | remove wall-oss from doc links | 2026-01-20 20:11:56 +01:00 |
| Pepijn | f4ccf911fa | format | 2026-01-20 20:08:28 +01:00 |
| Pepijn | 0cb8c92fe4 | Implement training time rtc for pi0, pi0.5 and smolvla | 2026-01-20 20:02:10 +01:00 |
12 changed files with 490 additions and 61 deletions
+2
@@ -57,6 +57,8 @@
title: Use Async Inference
- local: rtc
title: Real-Time Chunking (RTC)
- local: training_time_rtc
title: Training-Time RTC
title: "Inference"
- sections:
- local: envhub
+86
@@ -0,0 +1,86 @@
# Training-Time RTC
Training-Time RTC teaches the model to handle inference delay during training rather than correcting for it at inference.
It feeds the model the **ground-truth action prefix** and trains only on the remaining postfix actions.
This keeps chunk transitions smooth without any inference-time inpainting.
Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
LeRobot supports this for `pi0`, `pi05`, and `smolvla` without any architecture or parameter changes.
---
## How It Works
### At Training Time
- Sample a delay `d` per batch element.
- Keep the first `d` action steps as **ground truth** (no noise).
- Add noise only to the postfix actions.
- Set the flow-matching timestep to **0.0** (fully denoised, i.e. ground truth) for prefix tokens and keep the normal sampled timesteps for postfix tokens.
- Mask the loss so that only the postfix tokens are trained on (see the sketch below).
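A minimal sketch of these steps, using the `sample_rtc_delay` and `apply_rtc_training_time` helpers added in this PR (`lerobot.policies.rtc.training_time`); the batch/chunk sizes and the plain `torch.rand` timestep are illustrative stand-ins for the policy's own sampler:
```python
import torch

from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay

cfg = RTCTrainingConfig(enabled=True, min_delay=0, max_delay=6)
batch_size, chunk_size = 4, 50

# One flow-matching timestep per batch element (the policy uses its own sampler).
time = torch.rand(batch_size)

# Sample a delay d per element, then expand `time` to per-token timesteps:
# the first d (prefix) tokens get timestep 0.0 (ground truth, no noise), and
# `postfix_mask` marks the tokens that contribute to the loss.
delay = sample_rtc_delay(cfg, batch_size, torch.device("cpu"))
time_tokens, postfix_mask = apply_rtc_training_time(time, delay, chunk_size)

assert time_tokens.shape == (4, 50) and postfix_mask.shape == (4, 50)
```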
### At Inference Time
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
- Replace prefix positions in `x_t` with the previous chunk's leftover actions.
- Set the timestep to **1.0** for prefix positions (see the sketch below).
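A sketch of one denoising step under this conditioning, using the `apply_training_time_rtc_inference` helper from this PR; the shapes and the 0.8 timestep are arbitrary:
```python
import torch

from lerobot.policies.rtc.training_time import apply_training_time_rtc_inference

batch_size, chunk_size, action_dim = 1, 50, 32
x_t = torch.randn(batch_size, chunk_size, action_dim)   # current noisy chunk
prev = torch.randn(batch_size, chunk_size, action_dim)  # leftover actions from the previous chunk

# Prefix positions are overwritten with the previous chunk's actions, and their
# per-token timestep is forced to 1.0; postfix positions keep the scalar time.
x_t_cond, time_per_token = apply_training_time_rtc_inference(
    x_t, time=0.8, inference_delay=5, prev_chunk_left_over=prev, chunk_size=chunk_size
)

assert torch.equal(x_t_cond[:, :5], prev[:, :5])
assert time_per_token[0, 0] == 1.0
assert torch.allclose(time_per_token[0, 5:], torch.tensor(0.8))
```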
---
## Quick Start (CLI)
```bash
lerobot-train \
--policy.type=pi0 \
--dataset.repo_id=your/dataset \
--policy.rtc_training_config.enabled=true \
--policy.rtc_training_config.min_delay=0 \
--policy.rtc_training_config.max_delay=6 \
--policy.rtc_training_config.delay_distribution=UNIFORM
```
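If you build the policy config in Python instead of via the CLI, the equivalent is (a sketch; the remaining policy config fields are omitted):
```python
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig

rtc_training_config = RTCTrainingConfig(
    enabled=True,
    min_delay=0,
    max_delay=6,
    delay_distribution=RTCTrainingDelayDistribution.UNIFORM,
)
# Pass it to any supported policy config, e.g.
# PI0Config(rtc_training_config=rtc_training_config, ...).
```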
---
## Inference with Training-Time RTC
After training with `rtc_training_config` enabled, keep the same config at inference; the model then applies training-time RTC conditioning automatically:
```python
from lerobot.policies.pi0.modeling_pi0 import PI0Policy  # PI05Policy / SmolVLAPolicy work the same way

policy = PI0Policy.from_pretrained("path/to/trained/model")
# rtc_training_config is loaded from the saved config
actions = policy.predict_action_chunk(
batch,
inference_delay=5, # estimated delay in timesteps
prev_chunk_left_over=previous_actions, # from previous chunk
)
```
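In a receding-horizon loop, `prev_chunk_left_over` holds the previously predicted actions that overlap the new chunk; how it is sliced and aligned depends on your async runtime, so the loop below is only a rough sketch with hypothetical `get_observation_batch` and `execute` helpers and an assumed execution horizon:
```python
n_exec = 25          # actions executed per chunk before re-planning (assumption)
prev_actions = None  # no leftover before the first chunk

while True:
    batch = get_observation_batch()  # hypothetical: returns a policy-ready batch
    actions = policy.predict_action_chunk(
        batch,
        inference_delay=5,
        prev_chunk_left_over=prev_actions,
    )
    execute(actions[:, :n_exec])        # hypothetical: sends actions to the robot
    prev_actions = actions[:, n_exec:]  # unexecuted tail conditions the next chunk
```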
---
## Key Parameters
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
- **`enabled`**: Toggle training-time RTC (both training and inference).
- **`min_delay` / `max_delay`**: Delay range (inclusive).
- **`delay_distribution`**:
- `UNIFORM`: uniform in `[min_delay, max_delay]`
- `EXP`: exponentially decayed distribution over delays
- **`exp_decay`**: Exponential decay factor for `EXP` sampling (see the example below).
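With `EXP`, delay `d` is sampled with probability proportional to `exp(-exp_decay * d)`, so smaller delays dominate; a quick check of the resulting distribution (values rounded):
```python
import torch

delays = torch.arange(0, 4, dtype=torch.float32)  # min_delay=0, max_delay=3
weights = torch.exp(-1.0 * delays)                # exp_decay=1.0
probs = weights / weights.sum()
print(probs)  # ~tensor([0.6439, 0.2369, 0.0871, 0.0321])
```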
---
## Notes and Recommendations
- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay.
- Use `EXP` to bias sampling toward smaller delays, i.e. more supervision at small delays.
---
## Related Docs
- [Real-Time Chunking (Inference-Time RTC)](./rtc)
- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla)
+5
@@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum):
ONES = "ONES"
LINEAR = "LINEAR"
EXP = "EXP"
class RTCTrainingDelayDistribution(str, Enum):
UNIFORM = "UNIFORM"
EXP = "EXP"
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
# Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+74 -19
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -79,8 +85,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -88,8 +94,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -605,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -714,7 +729,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_emb = self._apply_checkpoint(action_proj_func, noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
def mlp_func(action_time_emb):
@@ -750,7 +768,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -846,24 +869,37 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -874,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
@@ -1277,7 +1320,19 @@ class PI0Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(
images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time
)
else:
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1289,12 +1344,12 @@ class PI0Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
@@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
DEFAULT_IMAGE_SIZE = 224
@@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
+66 -18
@@ -44,6 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -78,8 +84,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -87,8 +93,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb.reshape(*time.shape, dimension)
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
@@ -602,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -729,7 +744,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
@@ -820,23 +840,35 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -847,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
@@ -1250,7 +1288,17 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
time = self.model.sample_time(batch_size, actions.device)
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise=noise, time=time)
else:
losses = self.model.forward(images, img_masks, tokens, masks, actions)
# Truncate losses to actual action dimensions
original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1262,12 +1310,12 @@ class PI05Policy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
+20 -1
@@ -23,7 +23,7 @@ Based on:
from dataclasses import dataclass
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution
@dataclass
@@ -53,3 +53,22 @@ class RTCConfig:
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
if self.debug_maxlen <= 0:
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
@dataclass
class RTCTrainingConfig:
"""Configuration for training-time RTC action prefix conditioning."""
enabled: bool = False
min_delay: int = 0
max_delay: int = 0
delay_distribution: RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM
exp_decay: float = 1.0
def __post_init__(self):
if self.min_delay < 0:
raise ValueError(f"min_delay must be >= 0, got {self.min_delay}")
if self.max_delay < self.min_delay:
raise ValueError(f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})")
if self.exp_decay <= 0:
raise ValueError(f"exp_decay must be positive, got {self.exp_decay}")
+110
@@ -0,0 +1,110 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor:
if cfg.max_delay == cfg.min_delay:
return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long)
if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM:
return torch.randint(cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long)
delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long)
weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32))
probs = weights / weights.sum()
samples = torch.multinomial(probs, batch_size, replacement=True)
return delay_values[samples]
def apply_rtc_training_time(
time: torch.Tensor, delay: torch.Tensor, seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
device = time.device
delay = torch.clamp(delay, max=seq_len)
prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None]
time_tokens = time[:, None].expand(-1, seq_len)
time_tokens = time_tokens.masked_fill(prefix_mask, 0.0)
postfix_mask = ~prefix_mask
return time_tokens, postfix_mask
def masked_mean(
losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8
) -> torch.Tensor:
if mask is None:
return losses.mean(dim=reduce_dims)
mask = mask.to(dtype=losses.dtype)
while mask.dim() < losses.dim():
mask = mask.unsqueeze(-1)
masked = losses * mask
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
return masked.sum(dim=reduce_dims) / denom
def apply_training_time_rtc_inference(
x_t: torch.Tensor,
time: float,
inference_delay: int | None,
prev_chunk_left_over: torch.Tensor | None,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply training-time RTC conditioning during inference.
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
At each denoising step:
1. Replace prefix positions in x_t with ground truth from previous chunk
2. Create per-token timesteps with 1.0 for prefix positions
Args:
x_t: Current noisy actions (B, T, D)
time: Current flow matching timestep (scalar)
inference_delay: Number of prefix actions to condition on
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
chunk_size: Total chunk size T
Returns:
x_t_conditioned: x_t with prefix replaced by previous actions
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
"""
batch_size = x_t.shape[0]
device = x_t.device
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
return x_t, time_scalar
delay = min(inference_delay, chunk_size)
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
x_t_conditioned = torch.where(
prefix_mask[:, :, None].expand_as(x_t),
prev_chunk_left_over[:, :chunk_size, :],
x_t,
)
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
return x_t_conditioned, time_per_token
@@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
# Real-Time Chunking (RTC) configurations
rtc_config: RTCConfig | None = None
rtc_training_config: RTCTrainingConfig | None = None
def __post_init__(self):
super().__post_init__()
@@ -63,6 +63,12 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -85,8 +91,8 @@ def create_sinusoidal_pos_embedding(
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
if time.ndim not in (1, 2):
raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
@@ -94,9 +100,14 @@ def create_sinusoidal_pos_embedding(
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
if time.ndim == 1:
sin_input = scaling_factor[None, :] * time[:, None]
return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
time_flat = time.reshape(-1)
sin_input = scaling_factor[None, :] * time_flat[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
return pos_emb.reshape(*time.shape, dimension)
def make_att_2d_masks(pad_masks, att_masks):
@@ -375,6 +386,16 @@ class SmolVLAPolicy(PreTrainedPolicy):
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
postfix_mask = None
rtc_cfg = self.config.rtc_training_config
if rtc_cfg is not None and rtc_cfg.enabled and self.training:
batch_size = actions.shape[0]
if time is None:
time = self.model.sample_time(batch_size, actions.device)
if noise is None:
noise = self.model.sample_noise(actions.shape, actions.device)
delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device)
time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1])
actions_is_pad = batch.get("actions_id_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
@@ -384,6 +405,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound)
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
@@ -391,12 +413,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
if reduction == "none":
# Return per-sample losses (B,) by averaging over time and action dims
per_sample_loss = losses.mean(dim=(1, 2))
per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2))
loss_dict["loss"] = per_sample_loss.mean().item()
return per_sample_loss, loss_dict
else:
# Default: return scalar mean loss
loss = losses.mean()
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2))
loss_dict["loss"] = loss.item()
return loss, loss_dict
@@ -596,6 +618,9 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -731,7 +756,10 @@ class VLAFlowMatching(nn.Module):
)
time_emb = time_emb.type(dtype=dtype)
time_emb = time_emb[:, None, :].expand_as(action_emb)
if time_emb.dim() == 2:
time_emb = time_emb[:, None, :].expand_as(action_emb)
elif time_emb.shape[:2] != action_emb.shape[:2]:
raise ValueError(f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}")
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
@@ -763,7 +791,12 @@ class VLAFlowMatching(nn.Module):
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
if time.ndim == 1:
time_expanded = time[:, None, None]
elif time.ndim == 2:
time_expanded = time[:, :, None]
else:
raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}")
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -826,23 +859,35 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=current_timestep,
)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -853,7 +898,13 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for training-time RTC helpers."""
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
def test_rtc_training_config_defaults():
config = RTCTrainingConfig()
assert config.enabled is False
assert config.min_delay == 0
assert config.max_delay == 0
assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM
assert config.exp_decay == 1.0
def test_sample_rtc_delay_uniform_range():
cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4)
delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu"))
assert delays.min().item() >= 1
assert delays.max().item() <= 4
def test_apply_rtc_training_time_prefix_mask():
time = torch.tensor([0.5])
delays = torch.tensor([2])
time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4)
assert time_tokens.shape == (1, 4)
assert postfix_mask.shape == (1, 4)
# Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5]))
assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True]))