use same name for action and state dim as lerobot pi0 and remove fixed image keys

This commit is contained in:
Pepijn
2025-09-13 13:08:41 +02:00
parent 5361346bec
commit b9df1a4ac5
4 changed files with 38 additions and 72 deletions
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@@ -36,23 +36,20 @@ class PI05OpenPIConfig(PreTrainedConfig):
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute n_action_steps: int = 50 # Number of action steps to execute
action_dim: int = 32 # Action dimension (will be padded to 32)
state_dim: int = 32 # State dimension (will be padded to 32) # Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32 # State dimension (will be padded to 32)
max_action_dim: int = 32 # Action dimension (will be padded to 32)
# Flow matching parameters: see openpi `PI0Pytorch` # Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
min_period: float = 4e-3 # Min period for sinusoidal positional encoding min_period: float = 4e-3 # Min period for sinusoidal positional encoding
max_period: float = 4.0 # Max period for sinusoidal positional encodingis my max_period: float = 4.0 # Max period for sinusoidal positional encoding
# Image preprocessing # Image preprocessing
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_keys: tuple[str, ...] = (
"observation.images.base_0_rgb",
"observation.images.left_wrist_0_rgb",
"observation.images.right_wrist_0_rgb",
)
# Normalization # Normalization
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
@@ -103,26 +100,12 @@ class PI05OpenPIConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
# Add image features # Image features are now handled dynamically through dataset configuration
for key in self.image_keys: # No need to auto-add hardcoded image keys
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Default shape, will be resized
)
# Ensure state and action features exist # State and action features are also handled dynamically through dataset configuration
if "observation.state" not in self.input_features: # The actual dimensions come from the feature shapes, max dimensions are used for padding only
self.input_features["observation.state"] = PolicyFeature( pass
type=FeatureType.STATE,
shape=(self.state_dim,),
)
if "action" not in self.output_features:
self.output_features["action"] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(
@@ -503,8 +503,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
precision=config.dtype, precision=config.dtype,
) )
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width)
self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
@@ -739,8 +739,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
actions_shape = ( actions_shape = (
bsize, bsize,
self.config.chunk_size, self.config.chunk_size,
self.config.action_dim, self.config.max_action_dim,
) # Use config action_dim for internal processing ) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -1235,12 +1235,12 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
"""Pad state""" """Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.state_dim) state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state return state
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
"""Pad action""" """Pad action"""
actions = pad_vector(batch[ACTION], self.config.action_dim) actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions return actions
@torch.no_grad() @torch.no_grad()
@@ -1294,8 +1294,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
# Truncate losses to actual action dimensions # Truncate losses to actual action dimensions
if self.config.action_dim < 32: original_action_dim = self.config.output_features[ACTION].shape[0]
losses = losses[:, :, : self.config.action_dim] losses = losses[:, :, :original_action_dim]
loss = losses.mean() loss = losses.mean()
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@@ -33,23 +33,20 @@ class PI0OpenPIConfig(PreTrainedConfig):
n_obs_steps: int = 1 n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute n_action_steps: int = 50 # Number of action steps to execute
action_dim: int = 32 # Action dimension (will be padded to 32)
state_dim: int = 32 # State dimension (will be padded to 32) # Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32 # State dimension (will be padded to 32)
max_action_dim: int = 32 # Action dimension (will be padded to 32)
# Flow matching parameters: see openpi `PI0Pytorch` # Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling time_sampling_beta_alpha: float = 1.5 # Beta distribution alpha parameter for time sampling
time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling time_sampling_beta_beta: float = 1.0 # Beta distribution beta parameter for time sampling
min_period: float = 4e-3 # Min period for sinusoidal positional encoding min_period: float = 4e-3 # Min period for sinusoidal positional encoding
max_period: float = 4.0 # Max period for sinusoidal positional encodingis my max_period: float = 4.0 # Max period for sinusoidal positional encoding
# Image preprocessing # Image preprocessing
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_keys: tuple[str, ...] = (
"observation.images.base_0_rgb",
"observation.images.left_wrist_0_rgb",
"observation.images.right_wrist_0_rgb",
)
# Normalization # Normalization
normalization_mapping: dict[str, NormalizationMode] = field( normalization_mapping: dict[str, NormalizationMode] = field(
@@ -100,26 +97,12 @@ class PI0OpenPIConfig(PreTrainedConfig):
def validate_features(self) -> None: def validate_features(self) -> None:
"""Validate and set up input/output features.""" """Validate and set up input/output features."""
# Add image features # Image features are now handled dynamically through dataset configuration
for key in self.image_keys: # No need to auto-add hardcoded image keys
if key not in self.input_features:
self.input_features[key] = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224), # Default shape, will be resized
)
# Ensure state and action features exist # State and action features are also handled dynamically through dataset configuration
if "observation.state" not in self.input_features: # The actual dimensions come from the feature shapes, max dimensions are used for padding only
self.input_features["observation.state"] = PolicyFeature( pass
type=FeatureType.STATE,
shape=(self.state_dim,),
)
if "action" not in self.output_features:
self.output_features["action"] = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.action_dim,),
)
def get_optimizer_preset(self) -> AdamWConfig: def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig( return AdamWConfig(
@@ -503,10 +503,10 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
precision=config.dtype, precision=config.dtype,
) )
self.action_in_proj = nn.Linear(config.action_dim, action_expert_config.width) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
self.action_out_proj = nn.Linear(action_expert_config.width, config.action_dim) self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim)
self.state_proj = nn.Linear(config.state_dim, action_expert_config.width) self.state_proj = nn.Linear(config.max_state_dim, action_expert_config.width)
self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width)
self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width)
@@ -758,8 +758,8 @@ $(python -c "import transformers, os; print(os.path.dirname(transformers.__file_
actions_shape = ( actions_shape = (
bsize, bsize,
self.config.chunk_size, self.config.chunk_size,
self.config.action_dim, self.config.max_action_dim,
) # Use config action_dim for internal processing ) # Use config max_action_dim for internal processing
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
@@ -1250,12 +1250,12 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy) def prepare_state(self, batch): # see lerobot pi0 `prepare_state` (exact copy)
"""Pad state""" """Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.state_dim) state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state return state
def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy) def prepare_action(self, batch): # see lerobot pi0 `prepare_action` (exact copy)
"""Pad action""" """Pad action"""
actions = pad_vector(batch[ACTION], self.config.action_dim) actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions return actions
@torch.no_grad() @torch.no_grad()
@@ -1314,7 +1314,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
loss = losses.mean() loss = losses.mean()
loss_dict = { loss_dict = {
"loss": loss.item(), "l2_loss": loss.item(),
"loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(),
} }