mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
refactor
This commit is contained in:
@@ -45,7 +45,6 @@ class XVLAConfig(PreTrainedConfig):
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32
|
||||
n_action_steps: int = 32
|
||||
num_actions: int = 32
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -99,10 +98,8 @@ class XVLAConfig(PreTrainedConfig):
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
if self.num_actions <= 0:
|
||||
raise ValueError("`num_actions` must be strictly positive.")
|
||||
if self.chunk_size != self.num_actions:
|
||||
self.chunk_size = self.num_actions
|
||||
if self.chunk_size <= 0:
|
||||
raise ValueError("`chunk_size` must be strictly positive.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||
|
||||
@@ -50,7 +50,7 @@ class XVLAModel(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_actions: int = config.num_actions
|
||||
self.chunk_size: int = config.chunk_size
|
||||
self.use_proprio: bool = config.use_proprio
|
||||
self.action_space = build_action_space(config.action_mode.lower())
|
||||
self.dim_action = self.action_space.dim_action
|
||||
@@ -165,7 +165,7 @@ class XVLAModel(nn.Module):
|
||||
batch_size = input_ids.shape[0]
|
||||
action_dim = self.dim_action
|
||||
|
||||
x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype)
|
||||
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype)
|
||||
action = torch.zeros_like(x1)
|
||||
|
||||
steps = max(1, int(steps))
|
||||
@@ -274,7 +274,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1)
|
||||
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
|
||||
if actions.shape[-1] != self.model.dim_action:
|
||||
actions = pad_vector(actions, self.model.dim_action)
|
||||
return actions
|
||||
|
||||
Reference in New Issue
Block a user