From 455d347b490b8b6365005af6d0dbfb1e8c1dcf19 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Mon, 3 Nov 2025 18:03:52 +0700 Subject: [PATCH] Add RTCConfig field to SmolVLAConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add rtc_config as an optional field in SmolVLAConfig to properly support Real-Time Chunking configuration. This replaces the previous getattr() workarounds with direct attribute access, making the code cleaner and more maintainable. Changes: - Import RTCConfig in configuration_smolvla.py - Add rtc_config: RTCConfig | None = None field - Revert getattr() calls to direct attribute access in modeling_smolvla.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare Co-Authored-By: Claude --- .../policies/smolvla/configuration_smolvla.py | 4 +++ .../policies/smolvla/modeling_smolvla.py | 30 ++++++++++--------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index eedf477a5..c32c8a60e 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.utils.constants import OBS_IMAGES @@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig): min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding max_period: float = 4.0 + # Real-Time Chunking (RTC) configuration + rtc_config: RTCConfig | None = None + def __post_init__(self): super().__post_init__() diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 2153f806a..dd49a45f7 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -257,9 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy): """ self.rtc_processor = None - rtc_config = getattr(self.config, "rtc_config", None) - if rtc_config is not None and rtc_config.enabled: - self.rtc_processor = RTCProcessor(rtc_config) + if self.config.rtc_config is not None and self.config.rtc_config.enabled: + self.rtc_processor = RTCProcessor(self.config.rtc_config) # In case of calling init_rtc_processor after the model is created # We need to set the rtc_processor to the model @@ -344,8 +343,7 @@ class SmolVLAPolicy(PreTrainedPolicy): return len(self._queues[ACTION]) == 0 def _rtc_enabled(self) -> bool: - rtc_config = getattr(self.config, "rtc_config", None) - return rtc_config is not None and rtc_config.enabled + return self.config.rtc_config is not None and self.config.rtc_config.enabled def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: """Do a full training forward pass to compute the loss""" @@ -808,11 +806,10 @@ class VLAFlowMatching(nn.Module): timestep=current_timestep, ) - rtc_config = getattr(self.config, "rtc_config", None) - if rtc_config is not None and rtc_config.enabled: + if self.config.rtc_config is not None and self.config.rtc_config.enabled: inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - execution_horizon = kwargs.get("execution_horizon", rtc_config.execution_horizon) + execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon) v_t = self.rtc_processor.denoise_step( x_t=x_t, @@ -830,8 +827,11 @@ class VLAFlowMatching(nn.Module): time += dt # Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step) - rtc_config = getattr(self.config, "rtc_config", None) - if rtc_config is not None and rtc_config.enabled and correction is not None: + if ( + self.config.rtc_config is not None + and self.config.rtc_config.enabled + and correction is not None + ): self.rtc_processor.track_debug(time=time, x_t=x_t) # Visualize x_t using plot_waypoints - accumulate all denoise steps @@ -902,8 +902,7 @@ class VLAFlowMatching(nn.Module): xt_name = "smolvla_x_t_denoise_steps.png" v_name = "smolvla_v_denoise_steps.png" - rtc_config = getattr(self.config, "rtc_config", None) - if rtc_config is not None and rtc_config.enabled: + if self.config.rtc_config is not None and self.config.rtc_config.enabled: xt_name = "smolvla_x_t_with_rtc_denoise_steps.png" v_name = "smolvla_v_with_rtc_denoise_steps.png" @@ -932,8 +931,11 @@ class VLAFlowMatching(nn.Module): # Plot ground truth on provided axes if available if use_provided_axes: prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - rtc_config = getattr(self.config, "rtc_config", None) - if prev_chunk_left_over is not None and rtc_config is not None and rtc_config.enabled: + if ( + prev_chunk_left_over is not None + and self.config.rtc_config is not None + and self.config.rtc_config.enabled + ): plot_waypoints( viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth" )