rename action_horizon to chunk_size

This commit is contained in:
Pepijn
2025-09-11 19:42:25 +02:00
parent b044f3104b
commit 2234b851c0
8 changed files with 20 additions and 20 deletions
+1 -1
View File
@@ -38,7 +38,7 @@ def create_and_push_model(
# Input/output dimensions
action_dim=32, # see openpi `Pi0Config`
state_dim=32,
-    action_horizon=50,
+    chunk_size=50,
n_action_steps=50,
# Image inputs, see openpi `model.py, IMAGE_KEYS`
image_keys=(
@@ -34,7 +34,7 @@ class PI05OpenPIConfig(PreTrainedConfig):
# Input / output structure
n_obs_steps: int = 1
-    action_horizon: int = 50  # Number of action steps to predict
+    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
action_dim: int = 32 # Action dimension (will be padded to 32)
state_dim: int = 32 # State dimension (will be padded to 32)
@@ -87,9 +87,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
super().__post_init__()
# Validate configuration
-        if self.n_action_steps > self.action_horizon:
+        if self.n_action_steps > self.chunk_size:
raise ValueError(
-                f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
+                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
@@ -630,7 +630,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
-        att_masks += [1] + ([0] * (self.config.action_horizon - 1))
+        att_masks += [1] + ([0] * (self.config.chunk_size - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
@@ -688,7 +688,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
-        suffix_out = suffix_out[:, -self.config.action_horizon :]
+        suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
def action_out_proj_func(suffix_out):
@@ -711,7 +711,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if noise is None:
# Sample noise with padded dimension (32) as expected by action_in_proj
-            actions_shape = (bsize, self.config.action_horizon, 32)  # Use 32 for internal processing
+            actions_shape = (bsize, self.config.chunk_size, 32)  # Use 32 for internal processing
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -789,7 +789,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
suffix_out = outputs_embeds[1]
-        suffix_out = suffix_out[:, -self.config.action_horizon :]
+        suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
@@ -31,7 +31,7 @@ class PI0OpenPIConfig(PreTrainedConfig):
# Input / output structure
n_obs_steps: int = 1
-    action_horizon: int = 50  # Number of action steps to predict
+    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
action_dim: int = 32 # Action dimension (will be padded to 32)
state_dim: int = 32 # State dimension (will be padded to 32)
@@ -84,9 +84,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
super().__post_init__()
# Validate configuration
-        if self.n_action_steps > self.action_horizon:
+        if self.n_action_steps > self.chunk_size:
raise ValueError(
-                f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
+                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
@@ -647,7 +647,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
pad_masks.append(action_time_mask)
-        att_masks += [1] + ([0] * (self.config.action_horizon - 1))
+        att_masks += [1] + ([0] * (self.config.chunk_size - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
@@ -705,7 +705,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
)
-        suffix_out = suffix_out[:, -self.config.action_horizon :]
+        suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
def action_out_proj_func(suffix_out):
@@ -728,7 +728,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
if noise is None:
# Sample noise with padded dimension (32) as expected by action_in_proj
-            actions_shape = (bsize, self.config.action_horizon, 32)  # Use 32 for internal processing
+            actions_shape = (bsize, self.config.chunk_size, 32)  # Use 32 for internal processing
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -806,7 +806,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)
suffix_out = outputs_embeds[1]
-        suffix_out = suffix_out[:, -self.config.action_horizon :]
+        suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32)
return self.action_out_proj(suffix_out)
+2 -2
View File
@@ -87,7 +87,7 @@ def test_pi05_forward_pass():
action_dim=7,
state_dim=14,
dtype="float32",
-        action_horizon=16,  # Shorter horizon for testing
+        chunk_size=16,  # Shorter chunk_size for testing
n_action_steps=16, # Shorter action steps for testing
)
@@ -111,7 +111,7 @@ def test_pi05_forward_pass():
device = next(policy.parameters()).device
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
-        "action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
+        "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
+1 -1
View File
@@ -36,7 +36,7 @@ def test_policy_instantiation():
device = policy.device if hasattr(policy, "device") else "cpu"
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
-        "action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
+        "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
+2 -2
View File
@@ -82,7 +82,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
print(f" - Action expert variant: {policy.config.action_expert_variant}")
print(f" - Action dimension: {policy.config.action_dim}")
print(f" - State dimension: {policy.config.state_dim}")
-    print(f"   - Action horizon: {policy.config.action_horizon}")
+    print(f"   - Chunk_size: {policy.config.chunk_size}")
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
if model_name == "PI0.5":
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
@@ -172,7 +172,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
),
"action": torch.randn(
batch_size,
-            policy.config.action_horizon,
+            policy.config.chunk_size,
policy.config.action_dim,
dtype=torch.float32,
device=device,