diff --git a/docs/source/training_time_rtc.mdx b/docs/source/training_time_rtc.mdx index f8edff6ce..7e7e64fac 100644 --- a/docs/source/training_time_rtc.mdx +++ b/docs/source/training_time_rtc.mdx @@ -12,7 +12,7 @@ LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model par ## How It Works -At training time: +### At Training Time - Sample a delay `d` per batch element. - Keep the first `d` action steps as **ground truth** (no noise). @@ -20,6 +20,13 @@ At training time: - Set the flow-matching timestep to **1.0** for prefix tokens and normal timesteps for postfix tokens. - Mask the loss to only train on the postfix. +### At Inference Time + +When `rtc_training_config.enabled=true`, the model uses training-time RTC inference: + +- Replace prefix positions in `x_t` with the previous chunk's leftover actions. +- Set the timestep to **1.0** for prefix positions. + --- ## Quick Start (CLI) @@ -36,11 +43,28 @@ lerobot-train \ --- +## Inference with Training-Time RTC + +After training with `rtc_training_config`, use the same config at inference. The model will automatically use training-time RTC inference: + +```python +policy = PI0Policy.from_pretrained("path/to/trained/model") +# rtc_training_config is loaded from the saved config + +actions = policy.predict_action_chunk( + batch, + inference_delay=5, # estimated delay in timesteps + prev_chunk_left_over=previous_actions, # from previous chunk +) +``` + +--- + ## Key Parameters `RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`, `xvla`): -- **`enabled`**: Toggle training-time RTC. +- **`enabled`**: Toggle training-time RTC (both training and inference). - **`min_delay` / `max_delay`**: Delay range (inclusive). 
def _training_time_rtc_inference_enabled(self):
    """Tell whether training-time RTC inference should be used.

    Returns False when ``self.config.rtc_training_config`` is None; otherwise
    returns that config's ``enabled`` attribute unchanged.
    """
    rtc_training_cfg = self.config.rtc_training_config
    if rtc_training_cfg is None:
        return False
    return rtc_training_cfg.enabled
use_training_time_rtc: + x_t_cond, time_tensor = apply_training_time_rtc_inference( + x_t, time, inference_delay, prev_chunk_left_over, self.config.chunk_size + ) + v_t = self.denoise_step( state=state, prefix_pad_masks=prefix_pad_masks, past_key_values=past_key_values, - x_t=input_x_t, - timestep=current_timestep, + x_t=x_t_cond, + timestep=time_tensor, ) + elif self._rtc_enabled(): + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) - if self._rtc_enabled(): - inference_delay = kwargs.get("inference_delay") - prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - execution_horizon = kwargs.get("execution_horizon") + def denoise_step_partial_call(input_x_t, current_timestep=time_tensor): + return self.denoise_step( + state=state, + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) v_t = self.rtc_processor.denoise_step( x_t=x_t, @@ -889,7 +910,14 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` execution_horizon=execution_horizon, ) else: - v_t = denoise_step_partial_call(x_t) + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + v_t = self.denoise_step( + state=state, + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=x_t, + timestep=time_tensor, + ) x_t = x_t + dt * v_t diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 46c6b5c51..cbca282c9 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -44,7 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor -from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay +from 
def _training_time_rtc_inference_enabled(self):
    """Tell whether training-time RTC inference should be used.

    Returns False when ``self.config.rtc_training_config`` is None; otherwise
    returns that config's ``enabled`` attribute unchanged.
    """
    rtc_training_cfg = self.config.rtc_training_config
    if rtc_training_cfg is None:
        return False
    return rtc_training_cfg.enabled
denoise_step_partial_call(input_x_t, current_timestep=time_tensor): + return self.denoise_step( + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) v_t = self.rtc_processor.denoise_step( x_t=x_t, @@ -859,7 +879,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` execution_horizon=execution_horizon, ) else: - v_t = denoise_step_partial_call(x_t) + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + v_t = self.denoise_step( + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=x_t, + timestep=time_tensor, + ) x_t = x_t + dt * v_t diff --git a/src/lerobot/policies/rtc/training_time.py b/src/lerobot/policies/rtc/training_time.py index 44de4a1d9..a47bd6cec 100644 --- a/src/lerobot/policies/rtc/training_time.py +++ b/src/lerobot/policies/rtc/training_time.py @@ -60,3 +60,51 @@ def masked_mean( masked = losses * mask denom = mask.sum(dim=reduce_dims).clamp_min(eps) return masked.sum(dim=reduce_dims) / denom + + +def apply_training_time_rtc_inference( + x_t: torch.Tensor, + time: float, + inference_delay: int | None, + prev_chunk_left_over: torch.Tensor | None, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply training-time RTC conditioning during inference. + + Based on Algorithm 1 from "Training-Time Action Conditioning for Efficient Real-Time Chunking". + + At each denoising step: + 1. Replace prefix positions in x_t with ground truth from previous chunk + 2. 
Create per-token timesteps with 1.0 for prefix positions + + Args: + x_t: Current noisy actions (B, T, D) + time: Current flow matching timestep (scalar) + inference_delay: Number of prefix actions to condition on + prev_chunk_left_over: Previous chunk's leftover actions (B, T, D) + chunk_size: Total chunk size T + + Returns: + x_t_conditioned: x_t with prefix replaced by previous actions + time_per_token: Per-token timesteps (B, T) with 1.0 for prefix + """ + batch_size = x_t.shape[0] + device = x_t.device + + if inference_delay is None or inference_delay <= 0 or prev_chunk_left_over is None: + time_scalar = torch.full((batch_size,), time, device=device, dtype=torch.float32) + return x_t, time_scalar + + delay = min(inference_delay, chunk_size) + prefix_mask = torch.arange(chunk_size, device=device)[None, :] < delay + + x_t_conditioned = torch.where( + prefix_mask[:, :, None].expand_as(x_t), + prev_chunk_left_over[:, :chunk_size, :], + x_t, + ) + + time_per_token = torch.full((batch_size, chunk_size), time, device=device, dtype=torch.float32) + time_per_token = time_per_token.masked_fill(prefix_mask, 1.0) + + return x_t_conditioned, time_per_token diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 855621465..cd53ef78d 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -63,7 +63,12 @@ from typing_extensions import Unpack from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc.modeling_rtc import RTCProcessor -from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay +from lerobot.policies.rtc.training_time import ( + apply_rtc_training_time, + apply_training_time_rtc_inference, + masked_mean, + sample_rtc_delay, +) from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import 
def _training_time_rtc_inference_enabled(self):
    """Tell whether training-time RTC inference should be used.

    Returns False when ``self.config.rtc_training_config`` is None; otherwise
    returns that config's ``enabled`` attribute unchanged.
    """
    rtc_training_cfg = self.config.rtc_training_config
    if rtc_training_cfg is None:
        return False
    return rtc_training_cfg.enabled
self.rtc_processor.denoise_step( x_t=x_t, @@ -878,7 +898,13 @@ class VLAFlowMatching(nn.Module): execution_horizon=execution_horizon, ) else: - v_t = denoise_step_partial_call(x_t) + time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize) + v_t = self.denoise_step( + x_t=x_t, + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + timestep=time_tensor, + ) x_t = x_t + dt * v_t