Right kwargs for the policy

This commit is contained in:
Eugene Mironov
2025-11-06 02:43:44 +07:00
parent b27570039c
commit 11b35dfa11
@@ -54,10 +54,12 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
import math
from collections import deque
from typing import TypedDict
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
@@ -70,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
from lerobot.utils.utils import get_safe_dtype
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
@@ -261,7 +269,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
def _get_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
# In the case of offline inference, we have the action in the batch
@@ -296,7 +306,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
return batch
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
def predict_action_chunk(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
@@ -306,7 +318,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
def select_action(
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -737,7 +751,14 @@ class VLAFlowMatching(nn.Module):
return losses
def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, **kwargs
self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
@@ -783,7 +804,7 @@ class VLAFlowMatching(nn.Module):
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon", self.config.rtc_config.execution_horizon)
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,