mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
Merge branch 'feat/add_pi' into feat/validate_pi_libero
This commit is contained in:
+1
-1
@@ -38,7 +38,7 @@ def create_and_push_model(
|
|||||||
# Input/output dimensions
|
# Input/output dimensions
|
||||||
action_dim=32, # see openpi `Pi0Config`
|
action_dim=32, # see openpi `Pi0Config`
|
||||||
state_dim=32,
|
state_dim=32,
|
||||||
action_horizon=50,
|
chunk_size=50,
|
||||||
n_action_steps=50,
|
n_action_steps=50,
|
||||||
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
# Image inputs, see openpi `model.py, IMAGE_KEYS`
|
||||||
image_keys=(
|
image_keys=(
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Input / output structure
|
# Input / output structure
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
action_horizon: int = 50 # Number of action steps to predict
|
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||||
n_action_steps: int = 50 # Number of action steps to execute
|
n_action_steps: int = 50 # Number of action steps to execute
|
||||||
action_dim: int = 32 # Action dimension (will be padded to 32)
|
action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||||
state_dim: int = 32 # State dimension (will be padded to 32)
|
state_dim: int = 32 # State dimension (will be padded to 32)
|
||||||
@@ -87,9 +87,9 @@ class PI05OpenPIConfig(PreTrainedConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
if self.n_action_steps > self.action_horizon:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
|
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||||
|
|||||||
@@ -630,7 +630,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||||
pad_masks.append(action_time_mask)
|
pad_masks.append(action_time_mask)
|
||||||
|
|
||||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
@@ -688,7 +688,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||||
|
|
||||||
def action_out_proj_func(suffix_out):
|
def action_out_proj_func(suffix_out):
|
||||||
@@ -711,7 +711,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||||
actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing
|
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||||
@@ -789,7 +789,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = outputs_embeds[1]
|
suffix_out = outputs_embeds[1]
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||||
return self.action_out_proj(suffix_out)
|
return self.action_out_proj(suffix_out)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Input / output structure
|
# Input / output structure
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
action_horizon: int = 50 # Number of action steps to predict
|
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
|
||||||
n_action_steps: int = 50 # Number of action steps to execute
|
n_action_steps: int = 50 # Number of action steps to execute
|
||||||
action_dim: int = 32 # Action dimension (will be padded to 32)
|
action_dim: int = 32 # Action dimension (will be padded to 32)
|
||||||
state_dim: int = 32 # State dimension (will be padded to 32)
|
state_dim: int = 32 # State dimension (will be padded to 32)
|
||||||
@@ -84,9 +84,9 @@ class PI0OpenPIConfig(PreTrainedConfig):
|
|||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
# Validate configuration
|
# Validate configuration
|
||||||
if self.n_action_steps > self.action_horizon:
|
if self.n_action_steps > self.chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"n_action_steps ({self.n_action_steps}) cannot be greater than action_horizon ({self.action_horizon})"
|
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
|
||||||
|
|||||||
@@ -647,7 +647,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device)
|
||||||
pad_masks.append(action_time_mask)
|
pad_masks.append(action_time_mask)
|
||||||
|
|
||||||
att_masks += [1] + ([0] * (self.config.action_horizon - 1))
|
att_masks += [1] + ([0] * (self.config.chunk_size - 1))
|
||||||
|
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
@@ -705,7 +705,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||||
|
|
||||||
def action_out_proj_func(suffix_out):
|
def action_out_proj_func(suffix_out):
|
||||||
@@ -728,7 +728,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
# Sample noise with padded dimension (32) as expected by action_in_proj
|
# Sample noise with padded dimension (32) as expected by action_in_proj
|
||||||
actions_shape = (bsize, self.config.action_horizon, 32) # Use 32 for internal processing
|
actions_shape = (bsize, self.config.chunk_size, 32) # Use 32 for internal processing
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||||
@@ -806,7 +806,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
)
|
)
|
||||||
|
|
||||||
suffix_out = outputs_embeds[1]
|
suffix_out = outputs_embeds[1]
|
||||||
suffix_out = suffix_out[:, -self.config.action_horizon :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||||
return self.action_out_proj(suffix_out)
|
return self.action_out_proj(suffix_out)
|
||||||
|
|
||||||
|
|||||||
+2
-2
@@ -87,7 +87,7 @@ def test_pi05_forward_pass():
|
|||||||
action_dim=7,
|
action_dim=7,
|
||||||
state_dim=14,
|
state_dim=14,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
action_horizon=16, # Shorter horizon for testing
|
chunk_size=16, # Shorter chunk_size for testing
|
||||||
n_action_steps=16, # Shorter action steps for testing
|
n_action_steps=16, # Shorter action steps for testing
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -111,7 +111,7 @@ def test_pi05_forward_pass():
|
|||||||
device = next(policy.parameters()).device
|
device = next(policy.parameters()).device
|
||||||
batch = {
|
batch = {
|
||||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||||
"observation.images.base_0_rgb": torch.rand(
|
"observation.images.base_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
|
|||||||
+1
-1
@@ -36,7 +36,7 @@ def test_policy_instantiation():
|
|||||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||||
batch = {
|
batch = {
|
||||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||||
"action": torch.randn(batch_size, config.action_horizon, 7, dtype=torch.float32, device=device),
|
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||||
"observation.images.base_0_rgb": torch.rand(
|
"observation.images.base_0_rgb": torch.rand(
|
||||||
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
batch_size, 3, 224, 224, dtype=torch.float32, device=device
|
||||||
), # Use rand for [0,1] range
|
), # Use rand for [0,1] range
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
|||||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||||
print(f" - Action dimension: {policy.config.action_dim}")
|
print(f" - Action dimension: {policy.config.action_dim}")
|
||||||
print(f" - State dimension: {policy.config.state_dim}")
|
print(f" - State dimension: {policy.config.state_dim}")
|
||||||
print(f" - Action horizon: {policy.config.action_horizon}")
|
print(f" - Chunk_size: {policy.config.chunk_size}")
|
||||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||||
if model_name == "PI0.5":
|
if model_name == "PI0.5":
|
||||||
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
print(f" - discrete_state_input: {policy.config.discrete_state_input}")
|
||||||
@@ -172,7 +172,7 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"):
|
|||||||
),
|
),
|
||||||
"action": torch.randn(
|
"action": torch.randn(
|
||||||
batch_size,
|
batch_size,
|
||||||
policy.config.action_horizon,
|
policy.config.chunk_size,
|
||||||
policy.config.action_dim,
|
policy.config.action_dim,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
Reference in New Issue
Block a user