Add inference for training time rtc

This commit is contained in:
Pepijn
2026-01-29 11:05:42 +01:00
parent c3fa269b21
commit f147a4cd48
5 changed files with 187 additions and 35 deletions
+26 -2
View File
@@ -12,7 +12,7 @@ LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model par
## How It Works
At training time:
### At Training Time
- Sample a delay `d` per batch element.
- Keep the first `d` action steps as **ground truth** (no noise).
@@ -20,6 +20,13 @@ At training time:
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
- Mask the loss to only train on the postfix.
### At Inference Time
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
- Replace the prefix positions in `x_t` with the previous chunk's leftover actions.
- Set the timestep to **1.0** for the prefix positions.
---
## Quick Start (CLI)
@@ -36,11 +43,28 @@ lerobot-train \
---
## Inference with Training-Time RTC
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
```python
policy = PI0Policy.from_pretrained("path/to/trained/model")
# rtc_training_config is loaded from the saved config
actions = policy.predict_action_chunk(
batch,
inference_delay=5, # estimated delay in timesteps
prev_chunk_left_over=previous_actions, # from previous chunk
)
```
---
## Key Parameters
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
- **`enabled`**: Toggle training-time RTC.
- **`enabled`**: Toggle training-time RTC (both training and inference).
- **`min_delay` / `max_delay`**: Delay range (inclusive).
- **`delay_distribution`**:
- `UNIFORM`: uniform in `[min_delay, max_delay]`
+35 -7
View File
@@ -44,7 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -612,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -861,9 +869,27 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
@@ -875,11 +901,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
@@ -889,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
+33 -7
View File
@@ -44,7 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -609,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training:
@@ -832,9 +840,26 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t_cond,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
@@ -845,11 +870,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
@@ -859,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
x_t=x_t,
timestep=time_tensor,
)
x_t = x_t + dt * v_t
+48
View File
@@ -60,3 +60,51 @@ def masked_mean(
masked = losses * mask
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
return masked.sum(dim=reduce_dims) / denom
def apply_training_time_rtc_inference(
x_t: torch.Tensor,
time: float,
inference_delay: int | None,
prev_chunk_left_over: torch.Tensor | None,
chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply training-time RTC conditioning during inference.
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
At each denoising step:
1. Replace prefix positions in x_t with ground truth from previous chunk
2. Create per-token timesteps with 1.0 for prefix positions
Args:
x_t: Current noisy actions (B, T, D)
time: Current flow matching timestep (scalar)
inference_delay: Number of prefix actions to condition on
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
chunk_size: Total chunk size T
Returns:
x_t_conditioned: x_t with prefix replaced by previous actions
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
"""
batch_size = x_t.shape[0]
device = x_t.device
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
return x_t, time_scalar
delay = min(inference_delay, chunk_size)
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
x_t_conditioned = torch.where(
prefix_mask[:, :, None].expand_as(x_t),
prev_chunk_left_over[:, :chunk_size, :],
x_t,
)
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
return x_t_conditioned, time_per_token
@@ -63,7 +63,12 @@ from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
from lerobot.policies.rtc.training_time import (
apply_rtc_training_time,
apply_training_time_rtc_inference,
masked_mean,
sample_rtc_delay,
)
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.utils import (
@@ -613,6 +618,9 @@ class VLAFlowMatching(nn.Module):
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _training_time_rtc_inference_enabled(self):
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
@@ -851,9 +859,26 @@ class VLAFlowMatching(nn.Module):
num_steps = self.config.num_steps
dt = -1.0 / num_steps
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
use_training_time_rtc = self._training_time_rtc_inference_enabled()
x_t = noise
for step in range(num_steps):
time = 1.0 + step * dt
if use_training_time_rtc:
x_t_cond, time_tensor = apply_training_time_rtc_inference(
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
)
v_t = self.denoise_step(
x_t=x_t_cond,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
elif self._rtc_enabled():
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
@@ -864,11 +889,6 @@ class VLAFlowMatching(nn.Module):
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
@@ -878,7 +898,13 @@ class VLAFlowMatching(nn.Module):
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
v_t = self.denoise_step(
x_t=x_t,
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,
timestep=time_tensor,
)
x_t = x_t + dt * v_t