mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
rename action_horizon to chunk_size
This commit is contained in:
+1
-1
@@ -38,7 +38,7 @@ def create_and_push_model(
|
||||
# Input/output dimensions
|
||||
action_dim=32, # see openpi `Pi0Config`
|
||||
state_dim=32,
|
||||
action_horizon=50,
|
||||
chunk_size=50,
|
||||
n_action_steps=50,
|
||||
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
||||
image_keys=(
|
||||
|
||||
@@ -34,7 +34,7 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
action_horizon: int = 50 # Number of action steps to predict
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||
state_dim: int = 32 # State dimension (will be padded to 32)
|
||||
@@ -87,9 +87,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
|
||||
@@ -630,7 +630,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
@@ -688,7 +688,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
def action_out_proj_func(suffix_out):
|
||||
@@ -711,7 +711,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||
actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing
|
||||
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -789,7 +789,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
action_horizon: int = 50 # Number of action steps to predict
|
||||
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||
n_action_steps: int = 50 # Number of action steps to execute
|
||||
action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||
state_dim: int = 32 # State dimension (will be padded to 32)
|
||||
@@ -84,9 +84,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
||||
super().__post_init__()
|
||||
|
||||
# Validate configuration
|
||||
if self.n_action_steps > self.action_horizon:
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
|
||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||
|
||||
@@ -647,7 +647,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
||||
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
|
||||
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
@@ -705,7 +705,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||
)
|
||||
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
|
||||
def action_out_proj_func(suffix_out):
|
||||
@@ -728,7 +728,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
if noise is None:
|
||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||
actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing
|
||||
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
@@ -806,7 +806,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
)
|
||||
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
return self.action_out_proj(suffix_out)
|
||||
|
||||
|
||||
+2
-2
@@ -87,7 +87,7 @@ def test_pi05_forward_pass():
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
dtype="float32",
|
||||
action_horizon=16, # Shorter horizon for testing
|
||||
chunk_size=16, # Shorter chunk_size for testing
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
)
|
||||
|
||||
@@ -111,7 +111,7 @@ def test_pi05_forward_pass():
|
||||
device = next(policy.parameters()).device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
),
|
||||
|
||||
+1
-1
@@ -36,7 +36,7 @@ def test_policy_instantiation():
|
||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(
|
||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||
), # Use rand for [0,1] range
|
||||
|
||||
@@ -82,7 +82,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
||||
print(f" - Chunk_size: {policy.config.chunk_size}")
|
||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||
if model_name == "PI0.5":
|
||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||
@@ -172,7 +172,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
policy.config.action_horizon,
|
||||
policy.config.chunk_size,
|
||||
policy.config.action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
|
||||
Reference in New Issue
Block a user