From 11b35dfa11faddfda77d0640b3e9a74577818d45 Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Thu, 6 Nov 2025 02:43:44 +0700 Subject: [PATCH] Right kwargs for the policy --- .../policies/smolvla/modeling_smolvla.py | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index a7cf150ea..61e1c1f6c 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -54,10 +54,12 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") import math from collections import deque +from typing import TypedDict import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn +from typing_extensions import Unpack from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc.modeling_rtc import RTCProcessor @@ -70,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN from lerobot.utils.utils import get_safe_dtype +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: @@ -261,7 +269,9 @@ class SmolVLAPolicy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() - def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor: + def _get_action_chunk( + self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs] + ) -> Tensor: # TODO: Check if this for loop is needed. # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch # In the case of offline inference, we have the action in the batch @@ -296,7 +306,9 @@ class SmolVLAPolicy(PreTrainedPolicy): return batch @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor: + def predict_action_chunk( + self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs] + ) -> Tensor: self.eval() batch = self._prepare_batch(batch) @@ -306,7 +318,9 @@ class SmolVLAPolicy(PreTrainedPolicy): return actions @torch.no_grad() - def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor: + def select_action( + self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs] + ) -> Tensor: """Select a single action given environment observations. This method wraps `select_actions` in order to return one action at a time for execution in the @@ -737,7 +751,14 @@ class VLAFlowMatching(nn.Module): return losses def sample_actions( - self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs + self, + images, + img_masks, + lang_tokens, + lang_masks, + state, + noise=None, + **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] @@ -783,7 +804,7 @@ class VLAFlowMatching(nn.Module): if self._rtc_enabled(): inference_delay = kwargs.get("inference_delay") prev_chunk_left_over = kwargs.get("prev_chunk_left_over") - execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon) + execution_horizon = kwargs.get("execution_horizon") v_t = self.rtc_processor.denoise_step( x_t=x_t,