This commit is contained in:
Jade Choghari
2025-11-17 16:08:51 +01:00
parent 9896ba4ee4
commit a6404f61e1
2 changed files with 5 additions and 8 deletions
@@ -45,7 +45,6 @@ class XVLAConfig(PreTrainedConfig):
n_obs_steps: int = 1
chunk_size: int = 32
n_action_steps: int = 32
num_actions: int = 32
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
@@ -99,10 +98,8 @@ class XVLAConfig(PreTrainedConfig):
def __post_init__(self) -> None:
super().__post_init__()
if self.num_actions <= 0:
raise ValueError("`num_actions` must be strictly positive.")
if self.chunk_size != self.num_actions:
self.chunk_size = self.num_actions
if self.chunk_size <= 0:
raise ValueError("`chunk_size` must be strictly positive.")
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
+3 -3
View File
@@ -50,7 +50,7 @@ class XVLAModel(nn.Module):
) -> None:
super().__init__()
self.config = config
self.num_actions: int = config.num_actions
self.chunk_size: int = config.chunk_size
self.use_proprio: bool = config.use_proprio
self.action_space = build_action_space(config.action_mode.lower())
self.dim_action = self.action_space.dim_action
@@ -165,7 +165,7 @@ class XVLAModel(nn.Module):
batch_size = input_ids.shape[0]
action_dim = self.dim_action
x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype)
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype)
action = torch.zeros_like(x1)
steps = max(1, int(steps))
@@ -274,7 +274,7 @@ class XVLAPolicy(PreTrainedPolicy):
actions = batch[ACTION]
if actions.ndim == 2:
actions = actions.unsqueeze(1)
actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1)
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
if actions.shape[-1] != self.model.dim_action:
actions = pad_vector(actions, self.model.dim_action)
return actions