Add RTCConfig field to SmolVLAConfig

Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eugene Mironov
2025-11-03 18:03:52 +07:00
parent c835f03478
commit 455d347b49
2 changed files with 20 additions and 14 deletions
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
max_period: float = 4.0
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
def __post_init__(self):
super().__post_init__()
@@ -257,9 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
"""
self.rtc_processor = None
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
self.rtc_processor = RTCProcessor(rtc_config)
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# In case of calling init_rtc_processor after the model is created
# We need to set the rtc_processor to the model
@@ -344,8 +343,7 @@ class SmolVLAPolicy(PreTrainedPolicy):
return len(self._queues[ACTION]) == 0
def _rtc_enabled(self) -> bool:
rtc_config = getattr(self.config, "rtc_config", None)
return rtc_config is not None and rtc_config.enabled
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
"""Do a full training forward pass to compute the loss"""
@@ -808,11 +806,10 @@ class VLAFlowMatching(nn.Module):
timestep=current_timestep,
)
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon", rtc_config.execution_horizon)
execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
@@ -830,8 +827,11 @@ class VLAFlowMatching(nn.Module):
time += dt
# Record x_t after Euler step (other params are recorded in rtc_processor.denoise_step)
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled and correction is not None:
if (
self.config.rtc_config is not None
and self.config.rtc_config.enabled
and correction is not None
):
self.rtc_processor.track_debug(time=time, x_t=x_t)
# Visualize x_t using plot_waypoints - accumulate all denoise steps
@@ -902,8 +902,7 @@ class VLAFlowMatching(nn.Module):
xt_name = "smolvla_x_t_denoise_steps.png"
v_name = "smolvla_v_denoise_steps.png"
rtc_config = getattr(self.config, "rtc_config", None)
if rtc_config is not None and rtc_config.enabled:
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
xt_name = "smolvla_x_t_with_rtc_denoise_steps.png"
v_name = "smolvla_v_with_rtc_denoise_steps.png"
@@ -932,8 +931,11 @@ class VLAFlowMatching(nn.Module):
# Plot ground truth on provided axes if available
if use_provided_axes:
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
rtc_config = getattr(self.config, "rtc_config", None)
if prev_chunk_left_over is not None and rtc_config is not None and rtc_config.enabled:
if (
prev_chunk_left_over is not None
and self.config.rtc_config is not None
and self.config.rtc_config.enabled
):
plot_waypoints(
viz_xt_axs, prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)