mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
Add inference for training time rtc
This commit is contained in:
@@ -12,7 +12,7 @@ LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model par
|
||||
|
||||
## How It Works
|
||||
|
||||
At training time:
|
||||
### At Training Time
|
||||
|
||||
- Sample a delay `d` per batch element.
|
||||
- Keep the first `d` action steps as **ground truth** (no noise).
|
||||
@@ -20,6 +20,13 @@ At training time:
|
||||
- Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens.
|
||||
- Mask the loss to only train on the postfix.
|
||||
|
||||
### At Inference Time
|
||||
|
||||
When `rtc_training_config.enabled=true`, the model uses training-time RTC inference:
|
||||
|
||||
- Replace prefix positions in `x_t` with previous chunk's leftover actions.
|
||||
- Set timestep to **1.0** for prefix positions.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start (CLI)
|
||||
@@ -36,11 +43,28 @@ lerobot-train \
|
||||
|
||||
---
|
||||
|
||||
## Inference with Training-Time RTC
|
||||
|
||||
After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference:
|
||||
|
||||
```python
|
||||
policy = PI0Policy.from_pretrained("path/to/trained/model")
|
||||
# rtc_training_config is loaded from the saved config
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
inference_delay=5, # estimated delay in timesteps
|
||||
prev_chunk_left_over=previous_actions, # from previous chunk
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`):
|
||||
|
||||
- **`enabled`**: Toggle training-time RTC.
|
||||
- **`enabled`**: Toggle training-time RTC (both training and inference).
|
||||
- **`min_delay` / `max_delay`**: Delay range (inclusive).
|
||||
- **`delay_distribution`**:
|
||||
- `UNIFORM`: uniform in `[min_delay, max_delay]`
|
||||
|
||||
@@ -44,7 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -612,6 +617,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -861,9 +869,27 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
@@ -875,11 +901,6 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
@@ -889,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
@@ -44,7 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -609,6 +614,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -832,9 +840,26 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t_cond,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
@@ -845,11 +870,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
@@ -859,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=x_t,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
@@ -60,3 +60,51 @@ def masked_mean(
|
||||
masked = losses * mask
|
||||
denom = mask.sum(dim=reduce_dims).clamp_min(eps)
|
||||
return masked.sum(dim=reduce_dims) / denom
|
||||
|
||||
|
||||
def apply_training_time_rtc_inference(
|
||||
x_t: torch.Tensor,
|
||||
time: float,
|
||||
inference_delay: int | None,
|
||||
prev_chunk_left_over: torch.Tensor | None,
|
||||
chunk_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply training-time RTC conditioning during inference.
|
||||
|
||||
Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking".
|
||||
|
||||
At each denoising step:
|
||||
1. Replace prefix positions in x_t with ground truth from previous chunk
|
||||
2. Create per-token timesteps with 1.0 for prefix positions
|
||||
|
||||
Args:
|
||||
x_t: Current noisy actions (B, T, D)
|
||||
time: Current flow matching timestep (scalar)
|
||||
inference_delay: Number of prefix actions to condition on
|
||||
prev_chunk_left_over: Previous chunk's leftover actions (B, T, D)
|
||||
chunk_size: Total chunk size T
|
||||
|
||||
Returns:
|
||||
x_t_conditioned: x_t with prefix replaced by previous actions
|
||||
time_per_token: Per-token timesteps (B, T) with 1.0 for prefix
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
device = x_t.device
|
||||
|
||||
if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None:
|
||||
time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32)
|
||||
return x_t, time_scalar
|
||||
|
||||
delay = min(inference_delay, chunk_size)
|
||||
prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay
|
||||
|
||||
x_t_conditioned = torch.where(
|
||||
prefix_mask[:, :, None].expand_as(x_t),
|
||||
prev_chunk_left_over[:, :chunk_size, :],
|
||||
x_t,
|
||||
)
|
||||
|
||||
time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32)
|
||||
time_per_token = time_per_token.masked_fill(prefix_mask, 1.0)
|
||||
|
||||
return x_t_conditioned, time_per_token
|
||||
|
||||
@@ -63,7 +63,12 @@ from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay
|
||||
from lerobot.policies.rtc.training_time import (
|
||||
apply_rtc_training_time,
|
||||
apply_training_time_rtc_inference,
|
||||
masked_mean,
|
||||
sample_rtc_delay,
|
||||
)
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -613,6 +618,9 @@ class VLAFlowMatching(nn.Module):
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _training_time_rtc_inference_enabled(self):
|
||||
return self.config.rtc_training_config is not None and self.config.rtc_training_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
@@ -851,9 +859,26 @@ class VLAFlowMatching(nn.Module):
|
||||
num_steps = self.config.num_steps
|
||||
dt = -1.0 / num_steps
|
||||
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
use_training_time_rtc = self._training_time_rtc_inference_enabled()
|
||||
|
||||
x_t = noise
|
||||
for step in range(num_steps):
|
||||
time = 1.0 + step * dt
|
||||
|
||||
if use_training_time_rtc:
|
||||
x_t_cond, time_tensor = apply_training_time_rtc_inference(
|
||||
x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size
|
||||
)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t_cond,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
elif self._rtc_enabled():
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
|
||||
@@ -864,11 +889,6 @@ class VLAFlowMatching(nn.Module):
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
@@ -878,7 +898,13 @@ class VLAFlowMatching(nn.Module):
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
x_t=x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=time_tensor,
|
||||
)
|
||||
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
|
||||
Reference in New Issue
Block a user