diff --git a/push_pi0_to_hub.py b/push_pi0_to_hub.py index bc238f0a5..af77d226b 100644 --- a/push_pi0_to_hub.py +++ b/push_pi0_to_hub.py @@ -38,7 +38,7 @@ def create_and_push_model( # Input/output dimensions action_dim=32, # see openpi `Pi0Config` state_dim=32, - action_horizon=50, + chunk_size=50, n_action_steps=50, # Image inputs, see openpi `model.py, IMAGE_KEYS` image_keys=( diff --git a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py b/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py index c828d63bc..4566fe0e7 100644 --- a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py @@ -34,7 +34,7 @@ class PI05OpenPIConfig(PreTrainedConfig): # Input / output structure n_obs_steps: int = 1 - action_horizon: int = 50 # Number of action steps to predict + chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" n_action_steps: int = 50 # Number of action steps to execute action_dim: int = 32 # Action dimension (will be padded to 32) state_dim: int = 32 # State dimension (will be padded to 32) @@ -87,9 +87,9 @@ class PI05OpenPIConfig(PreTrainedConfig): super().__post_init__() # Validate configuration - if self.n_action_steps > self.action_horizon: + if self.n_action_steps > self.chunk_size: raise ValueError( - f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})" + f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})" ) if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index f33e2d5fd..a912deac2 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -630,7 +630,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` 
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) - att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) @@ -688,7 +688,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) - suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) def action_out_proj_func(suffix_out): @@ -711,7 +711,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` if noise is None: # Sample noise with padded dimension (32) as expected by action_in_proj - actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing + actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -789,7 +789,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` ) suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) return self.action_out_proj(suffix_out) diff --git a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py index 3750787bc..59d0085ae 100644 --- a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py @@ -31,7 +31,7 @@ class PI0OpenPIConfig(PreTrainedConfig): # Input / output structure n_obs_steps: int = 1 - action_horizon: int = 50 # Number of action steps to predict + chunk_size: int = 50 # Number of action steps to predict, in openpi 
called "action_horizon" n_action_steps: int = 50 # Number of action steps to execute action_dim: int = 32 # Action dimension (will be padded to 32) state_dim: int = 32 # State dimension (will be padded to 32) @@ -84,9 +84,9 @@ class PI0OpenPIConfig(PreTrainedConfig): super().__post_init__() # Validate configuration - if self.n_action_steps > self.action_horizon: + if self.n_action_steps > self.chunk_size: raise ValueError( - f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})" + f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})" ) if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index 94a2d7bf6..44d1d6a50 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -647,7 +647,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) - att_masks += [1] + ([0] * (self.config.action_horizon - 1)) + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) @@ -705,7 +705,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) - suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) def action_out_proj_func(suffix_out): @@ -728,7 +728,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` if noise is None: # Sample noise with padded dimension (32) as expected by action_in_proj - actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing + 
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -806,7 +806,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` ) suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.action_horizon :] + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) return self.action_out_proj(suffix_out) diff --git a/test_pi05_openpi.py b/test_pi05_openpi.py index 05fe27732..8e51d829c 100644 --- a/test_pi05_openpi.py +++ b/test_pi05_openpi.py @@ -87,7 +87,7 @@ def test_pi05_forward_pass(): action_dim=7, state_dim=14, dtype="float32", - action_horizon=16, # Shorter horizon for testing + chunk_size=16, # Shorter chunk size for testing n_action_steps=16, # Shorter action steps for testing ) @@ -111,7 +111,7 @@ def test_pi05_forward_pass(): device = next(policy.parameters()).device batch = { "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), - "action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), "observation.images.base_0_rgb": torch.rand( batch_size, 3, 224, 224, dtype=torch.float32, device=device ), diff --git a/test_pi0_openpi.py b/test_pi0_openpi.py index ad8419f7a..cc09015d3 100644 --- a/test_pi0_openpi.py +++ b/test_pi0_openpi.py @@ -36,7 +36,7 @@ def test_policy_instantiation(): device = policy.device if hasattr(policy, "device") else "cpu" batch = { "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), - "action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), "observation.images.base_0_rgb": torch.rand( batch_size, 3,
224, 224, dtype=torch.float32, device=device ), # Use rand for [0,1] range diff --git a/test_pi0_pi05_hub.py b/test_pi0_pi05_hub.py index d737bc561..a63654bbe 100644 --- a/test_pi0_pi05_hub.py +++ b/test_pi0_pi05_hub.py @@ -82,7 +82,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"): print(f" - Action expert variant: {policy.config.action_expert_variant}") print(f" - Action dimension: {policy.config.action_dim}") print(f" - State dimension: {policy.config.state_dim}") - print(f" - Action horizon: {policy.config.action_horizon}") + print(f" - Chunk size: {policy.config.chunk_size}") print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}") if model_name == "PI0.5": print(f" - discrete_state_input: {policy.config.discrete_state_input}") @@ -172,7 +172,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"): ), "action": torch.randn( batch_size, - policy.config.action_horizon, + policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device,