diff --git a/src/lerobot/policies/xvla/configuration_xvla.py b/src/lerobot/policies/xvla/configuration_xvla.py index 08d232aff..60ebfe911 100644 --- a/src/lerobot/policies/xvla/configuration_xvla.py +++ b/src/lerobot/policies/xvla/configuration_xvla.py @@ -45,7 +45,6 @@ class XVLAConfig(PreTrainedConfig): n_obs_steps: int = 1 chunk_size: int = 32 n_action_steps: int = 32 - num_actions: int = 32 normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { @@ -99,10 +98,8 @@ class XVLAConfig(PreTrainedConfig): def __post_init__(self) -> None: super().__post_init__() - if self.num_actions <= 0: - raise ValueError("`num_actions` must be strictly positive.") - if self.chunk_size != self.num_actions: - self.chunk_size = self.num_actions + if self.chunk_size <= 0: + raise ValueError("`chunk_size` must be strictly positive.") if self.n_action_steps > self.chunk_size: raise ValueError( f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})." diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index bad553fbe..6d1d0dbdf 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -50,7 +50,7 @@ class XVLAModel(nn.Module): ) -> None: super().__init__() self.config = config - self.num_actions: int = config.num_actions + self.chunk_size: int = config.chunk_size self.use_proprio: bool = config.use_proprio self.action_space = build_action_space(config.action_mode.lower()) self.dim_action = self.action_space.dim_action @@ -165,7 +165,7 @@ class XVLAModel(nn.Module): batch_size = input_ids.shape[0] action_dim = self.dim_action - x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype) + x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype) action = torch.zeros_like(x1) steps = max(1, int(steps)) @@ -274,7 +274,7 @@ class XVLAPolicy(PreTrainedPolicy): actions = batch[ACTION] if actions.ndim == 2: actions = actions.unsqueeze(1) - actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1) + actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1) if actions.shape[-1] != self.model.dim_action: actions = pad_vector(actions, self.model.dim_action) return actions