mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor
This commit is contained in:
@@ -45,7 +45,6 @@ class XVLAConfig(PreTrainedConfig):
|
|||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
chunk_size: int = 32
|
chunk_size: int = 32
|
||||||
n_action_steps: int = 32
|
n_action_steps: int = 32
|
||||||
num_actions: int = 32
|
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -99,10 +98,8 @@ class XVLAConfig(PreTrainedConfig):
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
if self.num_actions <= 0:
|
if self.chunk_size <= 0:
|
||||||
raise ValueError("`num_actions` must be strictly positive.")
|
raise ValueError("`chunk_size` must be strictly positive.")
|
||||||
if self.chunk_size != self.num_actions:
|
|
||||||
self.chunk_size = self.num_actions
|
|
||||||
if self.n_action_steps > self.chunk_size:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class XVLAModel(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.num_actions: int = config.num_actions
|
self.chunk_size: int = config.chunk_size
|
||||||
self.use_proprio: bool = config.use_proprio
|
self.use_proprio: bool = config.use_proprio
|
||||||
self.action_space = build_action_space(config.action_mode.lower())
|
self.action_space = build_action_space(config.action_mode.lower())
|
||||||
self.dim_action = self.action_space.dim_action
|
self.dim_action = self.action_space.dim_action
|
||||||
@@ -165,7 +165,7 @@ class XVLAModel(nn.Module):
|
|||||||
batch_size = input_ids.shape[0]
|
batch_size = input_ids.shape[0]
|
||||||
action_dim = self.dim_action
|
action_dim = self.dim_action
|
||||||
|
|
||||||
x1 = torch.randn(batch_size, self.num_actions, action_dim, device=proprio.device, dtype=proprio.dtype)
|
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=proprio.dtype)
|
||||||
action = torch.zeros_like(x1)
|
action = torch.zeros_like(x1)
|
||||||
|
|
||||||
steps = max(1, int(steps))
|
steps = max(1, int(steps))
|
||||||
@@ -274,7 +274,7 @@ class XVLAPolicy(PreTrainedPolicy):
|
|||||||
actions = batch[ACTION]
|
actions = batch[ACTION]
|
||||||
if actions.ndim == 2:
|
if actions.ndim == 2:
|
||||||
actions = actions.unsqueeze(1)
|
actions = actions.unsqueeze(1)
|
||||||
actions = pad_tensor_along_dim(actions, self.config.num_actions, dim=1)
|
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
|
||||||
if actions.shape[-1] != self.model.dim_action:
|
if actions.shape[-1] != self.model.dim_action:
|
||||||
actions = pad_vector(actions, self.model.dim_action)
|
actions = pad_vector(actions, self.model.dim_action)
|
||||||
return actions
|
return actions
|
||||||
|
|||||||
Reference in New Issue
Block a user