rename action_dim, state_dim to max_action_dim, max_state_dim

Author: Pepijn
Date:   2025-09-17 16:34:07 +02:00
Parent: 8c0cdb00a6
Commit: 53577f5f1a
4 changed files with 29 additions and 23 deletions
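
The rename is mechanical: every action_dim / state_dim keyword argument and every config attribute read becomes max_action_dim / max_state_dim, with values unchanged. For downstream code migrating across this commit, a small forwarding helper can accept either spelling during the transition. This helper is hypothetical (it is not part of the commit) and assumes the config classes reject unknown keyword arguments:

def make_config(config_cls, **kwargs):
    # Map the pre-rename keyword names onto the new ones so stale call
    # sites keep working while they are migrated.
    renames = {"action_dim": "max_action_dim", "state_dim": "max_state_dim"}
    for old, new in renames.items():
        if old in kwargs:
            if new in kwargs:
                raise TypeError(f"pass only one of {old!r} or {new!r}")
            kwargs[new] = kwargs.pop(old)
    return config_cls(**kwargs)

# Either spelling then works, e.g.:
#   config = make_config(PI05OpenPIConfig, action_dim=7, state_dim=14, dtype="float32")
#   config = make_config(PI05OpenPIConfig, max_action_dim=7, max_state_dim=14, dtype="float32")
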
+6 -6
@@ -21,8 +21,8 @@ def test_pi05_model_architecture():
     # Create config
     config = PI05OpenPIConfig(
-        action_dim=7,
-        state_dim=14,
+        max_action_dim=7,
+        max_state_dim=14,
         dtype="float32",
     )
@@ -89,8 +89,8 @@ def test_pi05_forward_pass():
     # Create config
     config = PI05OpenPIConfig(
-        action_dim=7,
-        state_dim=14,
+        max_action_dim=7,
+        max_state_dim=14,
         dtype="float32",
         chunk_size=16,  # Shorter chunk_size for testing
         n_action_steps=16,  # Shorter action steps for testing
@@ -152,8 +152,8 @@ def test_pi0_vs_pi05_differences():
     print("\nComparing PI0 vs PI0.5 architectures...")
     # Create both configurations
-    config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
-    config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
+    config_pi0 = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
+    config_pi05 = PI05OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
     dataset_stats = {
         "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
+3 -3
@@ -19,7 +19,7 @@ def test_policy_instantiation():
     print("Testing PI0OpenPI policy instantiation...")
     # Create config
-    config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
+    config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
     # Create dummy dataset stats
     dataset_stats = {
@@ -75,8 +75,8 @@ def test_config_creation():
     try:
         config = make_policy_config(
             policy_type="pi0_openpi",
-            action_dim=7,
-            state_dim=14,
+            max_action_dim=7,
+            max_state_dim=14,
         )
         print("✓ Config created successfully through factory")
         print(f" Config type: {type(config).__name__}")
+3 -1
@@ -79,7 +79,9 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False):
             policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
         )
     else:
-        config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32")
+        config = PI0OpenPIConfig(
+            max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
+        )
         policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
     policy.to(DEVICE)
     return policy
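
A side effect of renaming rather than aliasing: call sites this commit missed now fail immediately at construction time instead of silently misconfiguring the model. A self-contained illustration with a stand-in dataclass (the real PI0OpenPIConfig / PI05OpenPIConfig are assumed to behave like dataclasses that reject unknown keywords):

from dataclasses import dataclass

@dataclass
class ConfigLike:  # stand-in for the renamed config classes
    max_action_dim: int = 32
    max_state_dim: int = 32
    dtype: str = "float32"

ConfigLike(max_action_dim=7, max_state_dim=14)  # new names: accepted
try:
    ConfigLike(action_dim=7, state_dim=14)      # old names: TypeError
except TypeError as err:
    print(f"rejected as expected: {err}")
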
+17 -13
@@ -17,12 +17,12 @@ def create_dummy_stats(config):
     """Create dummy dataset statistics for testing."""
     dummy_stats = {
         "observation.state": {
-            "mean": torch.zeros(config.state_dim),
-            "std": torch.ones(config.state_dim),
+            "mean": torch.zeros(config.max_state_dim),
+            "std": torch.ones(config.max_state_dim),
         },
         "action": {
-            "mean": torch.zeros(config.action_dim),
-            "std": torch.ones(config.action_dim),
+            "mean": torch.zeros(config.max_action_dim),
+            "std": torch.ones(config.max_action_dim),
         },
     }
@@ -97,8 +97,8 @@ def _test_hub_loading(model_id, model_name):
     print(f" - Model type: {model_name}")
     print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
     print(f" - Action expert variant: {policy.config.action_expert_variant}")
-    print(f" - Action dimension: {policy.config.action_dim}")
-    print(f" - State dimension: {policy.config.state_dim}")
+    print(f" - Action dimension: {policy.config.max_action_dim}")
+    print(f" - State dimension: {policy.config.max_state_dim}")
     print(f" - Chunk_size: {policy.config.chunk_size}")
     print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
     if model_name == "PI0.5":
@@ -185,12 +185,12 @@ def _test_hub_loading(model_id, model_name):
     # Create test batch
     batch = {
         "observation.state": torch.randn(
-            batch_size, policy.config.state_dim, dtype=torch.float32, device=device
+            batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
         ),
         "action": torch.randn(
             batch_size,
             policy.config.chunk_size,
-            policy.config.action_dim,
+            policy.config.max_action_dim,
             dtype=torch.float32,
             device=device,
         ),
@@ -282,8 +282,8 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
     print(f" - Model type: {model_type}")
     print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
     print(f" - Action expert variant: {policy.config.action_expert_variant}")
-    print(f" - Action dimension: {policy.config.action_dim}")
-    print(f" - State dimension: {policy.config.state_dim}")
+    print(f" - Action dimension: {policy.config.max_action_dim}")
+    print(f" - State dimension: {policy.config.max_state_dim}")
     print(f" - Chunk size: {policy.config.chunk_size}")
     print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
     print(f" - Device: {device}")
@@ -339,10 +339,14 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
     batch_size = 1
     batch = {
         "observation.state": torch.randn(
-            batch_size, policy.config.state_dim, dtype=torch.float32, device=device
+            batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
         ),
         "action": torch.randn(
-            batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device
+            batch_size,
+            policy.config.chunk_size,
+            policy.config.max_action_dim,
+            dtype=torch.float32,
+            device=device,
         ),
         "task": ["Pick up the object"] * batch_size,
     }
@@ -369,7 +373,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
     policy.eval()
     with torch.no_grad():
         action = policy.select_action(batch)
-    expected_shape = (batch_size, policy.config.action_dim)
+    expected_shape = (batch_size, policy.config.max_action_dim)
     assert action.shape == expected_shape, (
         f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
     )
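
The assertion above also documents the inference contract: select_action returns one vector of width max_action_dim per batch element, so a consumer driving a lower-DoF robot slices its real dimensions back out. A minimal sketch (the keep-the-leading-dims slicing convention is an assumption; the commit itself only renames the config fields):

import torch

batch_size, max_action_dim, robot_dof = 1, 32, 7
action = torch.randn(batch_size, max_action_dim)  # stands in for policy.select_action(batch)
robot_action = action[:, :robot_dof]              # un-pad: keep the robot's real DoF
assert robot_action.shape == (batch_size, robot_dof)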