From 0cb8c92fe4223bb0d1e4f0b938dcf1a692c94803 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 20 Jan 2026 20:02:10 +0100 Subject: [PATCH] Implement training time rtc for pi0, pi0.5 and smolvla --- docs/source/_toctree.yml | 2 + docs/source/training_time_rtc.mdx | 63 ++++++++++++++++++ src/lerobot/configs/types.py | 5 ++ src/lerobot/policies/pi0/configuration_pi0.py | 5 +- src/lerobot/policies/pi0/modeling_pi0.py | 47 +++++++++++--- .../policies/pi05/configuration_pi05.py | 3 +- src/lerobot/policies/pi05/modeling_pi05.py | 40 +++++++++--- src/lerobot/policies/rtc/configuration_rtc.py | 23 ++++++- src/lerobot/policies/rtc/training_time.py | 65 +++++++++++++++++++ .../policies/smolvla/configuration_smolvla.py | 5 +- .../policies/smolvla/modeling_smolvla.py | 43 +++++++++--- tests/policies/rtc/test_training_time_rtc.py | 51 +++++++++++++++ 12 files changed, 321 insertions(+), 31 deletions(-) create mode 100644 docs/source/training_time_rtc.mdx create mode 100644 src/lerobot/policies/rtc/training_time.py create mode 100644 tests/policies/rtc/test_training_time_rtc.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2b8086cd7..b9bd9a2e2 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -57,6 +57,8 @@ title: Use Async Inference - local: rtc title: Real-Time Chunking (RTC) + - local: training_time_rtc + title: Training-Time RTC title: "Inference" - sections: - local: envhub diff --git a/docs/source/training_time_rtc.mdx b/docs/source/training_time_rtc.mdx new file mode 100644 index 000000000..3394b7f11 --- /dev/null +++ b/docs/source/training_time_rtc.mdx @@ -0,0 +1,63 @@ +# Training-Time RTC + +Training-Time RTC teaches the model to handle inference delay during training. +It feeds the **ground-truth action prefix** to the model and trains only on the remaining postfix actions. +This keeps chunk transitions smooth without doing any inference-time inpainting. 
+ + Based on: [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964). + +LeRobot supports this for `pi0`, `pi05` and `smolvla` without changing model parameters. + +--- + +## How It Works + +At training time: + +- Sample a delay `d` per batch element. +- Keep the first `d` action steps as **ground truth** (no noise). +- Add noise only to the postfix actions. +- Set the flow-matching timestep to **0.0** (clean actions) for prefix tokens and normal timesteps for postfix tokens. +- Mask the loss to only train on the postfix. + +--- + +## Quick Start (CLI) + +```bash +lerobot-train \ + --policy.type=pi0 \ + --dataset.repo_id=your/dataset \ + --policy.rtc_training_config.enabled=true \ + --policy.rtc_training_config.min_delay=0 \ + --policy.rtc_training_config.max_delay=6 \ + --policy.rtc_training_config.delay_distribution=UNIFORM +``` + +--- + +## Key Parameters + +`RTCTrainingConfig` is available on the policy config (`pi0`, `pi05`, `smolvla`): + +- **`enabled`**: Toggle training-time RTC. +- **`min_delay` / `max_delay`**: Delay range (inclusive). +- **`delay_distribution`**: + - `UNIFORM`: uniform in `[min_delay, max_delay]` + - `EXP`: exponentially decayed distribution over delays +- **`exp_decay`**: Exponential decay factor for `EXP` sampling. + +--- + +## Notes and Recommendations + +- Start with `min_delay=0` and `max_delay` around your expected worst-case inference delay. +- Use `EXP` if you want more supervision on smaller delays. 
+ +--- + +## Related Docs + +- [Real-Time Chunking (Inference-Time RTC)](./rtc) +- [Pi0](./pi0), [Pi0.5](./pi05), [SmolVLA](./smolvla), [WALL-OSS](./walloss) + diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 18359ef05..8426afe55 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -50,3 +50,8 @@ class RTCAttentionSchedule(str, Enum): ONES = "ONES" LINEAR = "LINEAR" EXP = "EXP" + + +class RTCTrainingDelayDistribution(str, Enum): + UNIFORM = "UNIFORM" + EXP = "EXP" diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index be9b4530f..a54ec5db0 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -50,8 +50,9 @@ class PI0Config(PreTrainedConfig): min_period: float = 4e-3 max_period: float = 4.0 - # Real-Time Chunking (RTC) configuration + # Real-Time Chunking (RTC) configurations rtc_config: RTCConfig | None = None + rtc_training_config: RTCTrainingConfig | None = None image_resolution: tuple[int, int] = ( DEFAULT_IMAGE_SIZE, diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 58b5dc07b..b4435eccc 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -44,6 +44,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config 
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -79,8 +80,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + if time.ndim not in (1, 2): + raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.") dtype = get_safe_dtype(torch.float64, device.type) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) @@ -88,8 +89,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + if time.ndim == 1: + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + time_flat = time.reshape(-1) + sin_input = scaling_factor[None, :] * time_flat[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb.reshape(*time.shape, dimension) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) @@ -714,7 +721,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) - time_emb = time_emb[:, None, :].expand_as(action_emb) + if time_emb.dim() == 2: + time_emb = time_emb[:, None, :].expand_as(action_emb) + elif time_emb.shape[:2] != action_emb.shape[:2]: + raise ValueError( + f"Expected time_emb shape 
{action_emb.shape[:2]}, got {time_emb.shape[:2]}" + ) action_time_emb = torch.cat([action_emb, time_emb], dim=2) def mlp_func(action_time_emb): @@ -750,7 +762,12 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` if time is None: time = self.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] + if time.ndim == 1: + time_expanded = time[:, None, None] + elif time.ndim == 2: + time_expanded = time[:, :, None] + else: + raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}") x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -1277,7 +1294,19 @@ class PI0Policy(PreTrainedPolicy): actions = self.prepare_action(batch) # Compute loss - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) + postfix_mask = None + rtc_cfg = self.config.rtc_training_config + if rtc_cfg is not None and rtc_cfg.enabled and self.training: + batch_size = actions.shape[0] + time = self.model.sample_time(batch_size, actions.device) + noise = self.model.sample_noise(actions.shape, actions.device) + delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device) + time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1]) + losses = self.model.forward( + images, img_masks, lang_tokens, lang_masks, state, actions, noise=noise, time=time + ) + else: + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] @@ -1289,12 +1318,12 @@ class PI0Policy(PreTrainedPolicy): if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims - per_sample_loss = losses.mean(dim=(1, 2)) + per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2)) loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: # Default: return scalar mean loss - loss = losses.mean() + 
loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2)) loss_dict["loss"] = loss.item() return loss, loss_dict diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index b96e6d196..f8be7c8bb 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -20,7 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -52,6 +52,7 @@ class PI05Config(PreTrainedConfig): # Real-Time Chunking (RTC) configuration rtc_config: RTCConfig | None = None + rtc_training_config: RTCTrainingConfig | None = None image_resolution: tuple[int, int] = ( DEFAULT_IMAGE_SIZE, diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 104ec63bf..bdf9f272e 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -44,6 +44,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -78,8 +79,8 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") 
- if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + if time.ndim not in (1, 2): + raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.") dtype = get_safe_dtype(torch.float64, device.type) fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) @@ -87,8 +88,14 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] - return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + if time.ndim == 1: + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + time_flat = time.reshape(-1) + sin_input = scaling_factor[None, :] * time_flat[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb.reshape(*time.shape, dimension) def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) @@ -729,7 +736,12 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` if time is None: time = self.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] + if time.ndim == 1: + time_expanded = time[:, None, None] + elif time.ndim == 2: + time_expanded = time[:, :, None] + else: + raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}") x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -1250,7 +1262,19 @@ class PI05Policy(PreTrainedPolicy): actions = self.prepare_action(batch) # Compute loss (no separate state needed for PI05) - losses = self.model.forward(images, img_masks, tokens, masks, actions) + postfix_mask = None + rtc_cfg = self.config.rtc_training_config + if rtc_cfg is not None and rtc_cfg.enabled and self.training: + batch_size = actions.shape[0] + time = 
self.model.sample_time(batch_size, actions.device) + noise = self.model.sample_noise(actions.shape, actions.device) + delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device) + time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1]) + losses = self.model.forward( + images, img_masks, tokens, masks, actions, noise=noise, time=time + ) + else: + losses = self.model.forward(images, img_masks, tokens, masks, actions) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] @@ -1262,12 +1286,12 @@ class PI05Policy(PreTrainedPolicy): if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims - per_sample_loss = losses.mean(dim=(1, 2)) + per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2)) loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: # Default: return scalar mean loss - loss = losses.mean() + loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2)) loss_dict["loss"] = loss.item() return loss, loss_dict diff --git a/src/lerobot/policies/rtc/configuration_rtc.py b/src/lerobot/policies/rtc/configuration_rtc.py index 70a8dfb09..b7e55723c 100644 --- a/src/lerobot/policies/rtc/configuration_rtc.py +++ b/src/lerobot/policies/rtc/configuration_rtc.py @@ -23,7 +23,7 @@ Based on: from dataclasses import dataclass -from lerobot.configs.types import RTCAttentionSchedule +from lerobot.configs.types import RTCAttentionSchedule, RTCTrainingDelayDistribution @dataclass @@ -53,3 +53,24 @@ class RTCConfig: raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}") if self.debug_maxlen <= 0: raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}") + + +@dataclass +class RTCTrainingConfig: + """Configuration for training-time RTC action prefix conditioning.""" + + enabled: bool = False + min_delay: int = 0 + max_delay: int = 0 + delay_distribution: 
RTCTrainingDelayDistribution = RTCTrainingDelayDistribution.UNIFORM + exp_decay: float = 1.0 + + def __post_init__(self): + if self.min_delay < 0: + raise ValueError(f"min_delay must be >= 0, got {self.min_delay}") + if self.max_delay < self.min_delay: + raise ValueError( + f"max_delay ({self.max_delay}) must be >= min_delay ({self.min_delay})" + ) + if self.exp_decay <= 0: + raise ValueError(f"exp_decay must be positive, got {self.exp_decay}") diff --git a/src/lerobot/policies/rtc/training_time.py b/src/lerobot/policies/rtc/training_time.py new file mode 100644 index 000000000..b6fb05ec6 --- /dev/null +++ b/src/lerobot/policies/rtc/training_time.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import torch + +from lerobot.configs.types import RTCTrainingDelayDistribution +from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig + + +def sample_rtc_delay(cfg: RTCTrainingConfig, batch_size: int, device: torch.device) -> torch.Tensor: + if cfg.max_delay == cfg.min_delay: + return torch.full((batch_size,), cfg.min_delay, device=device, dtype=torch.long) + + if cfg.delay_distribution == RTCTrainingDelayDistribution.UNIFORM: + return torch.randint( + cfg.min_delay, cfg.max_delay + 1, (batch_size,), device=device, dtype=torch.long + ) + + delay_values = torch.arange(cfg.min_delay, cfg.max_delay + 1, device=device, dtype=torch.long) + weights = torch.exp(-cfg.exp_decay * delay_values.to(dtype=torch.float32)) + probs = weights / weights.sum() + samples = torch.multinomial(probs, batch_size, replacement=True) + return delay_values[samples] + + +def apply_rtc_training_time( + time: torch.Tensor, delay: torch.Tensor, seq_len: int +) -> tuple[torch.Tensor, torch.Tensor]: + device = time.device + delay = torch.clamp(delay, max=seq_len) + prefix_mask = torch.arange(seq_len, device=device)[None, :] < delay[:, None] + time_tokens = time[:, None].expand(-1, seq_len) + time_tokens = time_tokens.masked_fill(prefix_mask, 0.0) + postfix_mask = ~prefix_mask + return time_tokens, postfix_mask + + +def masked_mean( + losses: torch.Tensor, mask: torch.Tensor | None, reduce_dims: tuple[int, ...], eps: float = 1e-8 +) -> torch.Tensor: + if mask is None: + return losses.mean(dim=reduce_dims) + + mask = mask.to(dtype=losses.dtype) + while mask.dim() < losses.dim(): + mask = mask.unsqueeze(-1) + masked = losses * mask + denom = mask.sum(dim=reduce_dims).clamp_min(eps) + return masked.sum(dim=reduce_dims) / denom + diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index c32c8a60e..fa773213d 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ 
b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -20,7 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) -from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.policies.rtc.configuration_rtc import RTCConfig, RTCTrainingConfig from lerobot.utils.constants import OBS_IMAGES @@ -103,8 +103,9 @@ class SmolVLAConfig(PreTrainedConfig): min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding max_period: float = 4.0 - # Real-Time Chunking (RTC) configuration + # Real-Time Chunking (RTC) configurations rtc_config: RTCConfig | None = None + rtc_training_config: RTCTrainingConfig | None = None def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index c611e9ba2..a1b66991d 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -63,6 +63,7 @@ from typing_extensions import Unpack from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.policies.rtc.training_time import apply_rtc_training_time, masked_mean, sample_rtc_delay from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel from lerobot.policies.utils import ( @@ -85,8 +86,8 @@ def create_sinusoidal_pos_embedding( if dimension % 2 != 0: raise ValueError(f"dimension ({dimension}) must be divisible by 2") - if time.ndim != 1: - raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + if time.ndim not in (1, 2): + raise ValueError("The time tensor is expected to be of shape `(batch_size,)` or `(batch_size, T)`.") dtype = get_safe_dtype(torch.float64, device.type) fraction = torch.linspace(0.0, 1.0, 
dimension // 2, dtype=dtype, device=device) @@ -94,9 +95,14 @@ def create_sinusoidal_pos_embedding( # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi - sin_input = scaling_factor[None, :] * time[:, None] + if time.ndim == 1: + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + time_flat = time.reshape(-1) + sin_input = scaling_factor[None, :] * time_flat[:, None] pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - return pos_emb + return pos_emb.reshape(*time.shape, dimension) def make_att_2d_masks(pad_masks, att_masks): @@ -375,6 +381,16 @@ class SmolVLAPolicy(PreTrainedPolicy): lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) + postfix_mask = None + rtc_cfg = self.config.rtc_training_config + if rtc_cfg is not None and rtc_cfg.enabled and self.training: + batch_size = actions.shape[0] + if time is None: + time = self.model.sample_time(batch_size, actions.device) + if noise is None: + noise = self.model.sample_noise(actions.shape, actions.device) + delay = sample_rtc_delay(rtc_cfg, batch_size, actions.device) + time, postfix_mask = apply_rtc_training_time(time, delay, actions.shape[1]) actions_is_pad = batch.get("actions_id_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) @@ -384,6 +400,7 @@ class SmolVLAPolicy(PreTrainedPolicy): in_episode_bound = ~actions_is_pad losses = losses * in_episode_bound.unsqueeze(-1) loss_dict["losses_after_in_ep_bound"] = losses.clone() + postfix_mask = in_episode_bound if postfix_mask is None else (postfix_mask & in_episode_bound) # Remove padding losses = losses[:, :, : self.config.max_action_dim] @@ -391,12 +408,12 @@ class SmolVLAPolicy(PreTrainedPolicy): if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims - 
per_sample_loss = losses.mean(dim=(1, 2)) + per_sample_loss = masked_mean(losses, postfix_mask, reduce_dims=(1, 2)) loss_dict["loss"] = per_sample_loss.mean().item() return per_sample_loss, loss_dict else: # Default: return scalar mean loss - loss = losses.mean() + loss = masked_mean(losses, postfix_mask, reduce_dims=(0, 1, 2)) loss_dict["loss"] = loss.item() return loss, loss_dict @@ -731,7 +748,12 @@ class VLAFlowMatching(nn.Module): ) time_emb = time_emb.type(dtype=dtype) - time_emb = time_emb[:, None, :].expand_as(action_emb) + if time_emb.dim() == 2: + time_emb = time_emb[:, None, :].expand_as(action_emb) + elif time_emb.shape[:2] != action_emb.shape[:2]: + raise ValueError( + f"Expected time_emb shape {action_emb.shape[:2]}, got {time_emb.shape[:2]}" + ) action_time_emb = torch.cat([action_emb, time_emb], dim=2) action_time_emb = self.action_time_mlp_in(action_time_emb) @@ -763,7 +785,12 @@ class VLAFlowMatching(nn.Module): if time is None: time = self.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] + if time.ndim == 1: + time_expanded = time[:, None, None] + elif time.ndim == 2: + time_expanded = time[:, :, None] + else: + raise ValueError(f"Expected time shape (B,) or (B, T), got {time.shape}") x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( diff --git a/tests/policies/rtc/test_training_time_rtc.py b/tests/policies/rtc/test_training_time_rtc.py new file mode 100644 index 000000000..22d208895 --- /dev/null +++ b/tests/policies/rtc/test_training_time_rtc.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for training-time RTC helpers.""" + +import torch + +from lerobot.configs.types import RTCTrainingDelayDistribution +from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig +from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay + + +def test_rtc_training_config_defaults(): + config = RTCTrainingConfig() + assert config.enabled is False + assert config.min_delay == 0 + assert config.max_delay == 0 + assert config.delay_distribution == RTCTrainingDelayDistribution.UNIFORM + assert config.exp_decay == 1.0 + + +def test_sample_rtc_delay_uniform_range(): + cfg = RTCTrainingConfig(enabled=True, min_delay=1, max_delay=4) + delays = sample_rtc_delay(cfg, batch_size=100, device=torch.device("cpu")) + assert delays.min().item() >= 1 + assert delays.max().item() <= 4 + + +def test_apply_rtc_training_time_prefix_mask(): + time = torch.tensor([0.5]) + delays = torch.tensor([2]) + time_tokens, postfix_mask = apply_rtc_training_time(time, delays, seq_len=4) + assert time_tokens.shape == (1, 4) + assert postfix_mask.shape == (1, 4) + # Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix. + assert torch.allclose(time_tokens[0], torch.tensor([0.0, 0.0, 0.5, 0.5])) + assert torch.equal(postfix_mask[0], torch.tensor([False, False, True, True])) +