diff --git a/tests/policies/test_pi05_openpi.py b/tests/policies/test_pi05_openpi.py index cee7d3248..aa6e732e7 100644 --- a/tests/policies/test_pi05_openpi.py +++ b/tests/policies/test_pi05_openpi.py @@ -21,8 +21,8 @@ def test_pi05_model_architecture(): # Create config config = PI05OpenPIConfig( - action_dim=7, - state_dim=14, + max_action_dim=7, + max_state_dim=14, dtype="float32", ) @@ -89,8 +89,8 @@ def test_pi05_forward_pass(): # Create config config = PI05OpenPIConfig( - action_dim=7, - state_dim=14, + max_action_dim=7, + max_state_dim=14, dtype="float32", chunk_size=16, # Shorter chunk_size for testing n_action_steps=16, # Shorter action steps for testing @@ -152,8 +152,8 @@ def test_pi0_vs_pi05_differences(): print("\nComparing PI0 vs PI0.5 architectures...") # Create both configurations - config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32") - config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32") + config_pi0 = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") + config_pi05 = PI05OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") dataset_stats = { "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, diff --git a/tests/policies/test_pi0_openpi.py b/tests/policies/test_pi0_openpi.py index 966323aba..6254d4f50 100644 --- a/tests/policies/test_pi0_openpi.py +++ b/tests/policies/test_pi0_openpi.py @@ -19,7 +19,7 @@ def test_policy_instantiation(): print("Testing PI0OpenPI policy instantiation...") # Create config - config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32") + config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") # Create dummy dataset stats dataset_stats = { @@ -75,8 +75,8 @@ def test_config_creation(): try: config = make_policy_config( policy_type="pi0_openpi", - action_dim=7, - state_dim=14, + max_action_dim=7, + max_state_dim=14, ) print("✓ Config created successfully through factory") print(f" Config type: {type(config).__name__}") diff --git a/tests/policies/test_pi0_original_vs_lerobot.py b/tests/policies/test_pi0_original_vs_lerobot.py index 688d24749..77339708f 100644 --- a/tests/policies/test_pi0_original_vs_lerobot.py +++ b/tests/policies/test_pi0_original_vs_lerobot.py @@ -79,7 +79,9 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False): policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS ) else: - config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32") + config = PI0OpenPIConfig( + max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32" + ) policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS) policy.to(DEVICE) return policy diff --git a/tests/policies/test_pi0_pi05_hub.py b/tests/policies/test_pi0_pi05_hub.py index 8d123a1a4..350f1552a 100644 --- a/tests/policies/test_pi0_pi05_hub.py +++ b/tests/policies/test_pi0_pi05_hub.py @@ -17,12 +17,12 @@ def create_dummy_stats(config): """Create dummy dataset statistics for testing.""" dummy_stats = { "observation.state": { - "mean": torch.zeros(config.state_dim), - "std": torch.ones(config.state_dim), + "mean": torch.zeros(config.max_state_dim), + "std": torch.ones(config.max_state_dim), }, "action": { - "mean": torch.zeros(config.action_dim), - "std": torch.ones(config.action_dim), + "mean": torch.zeros(config.max_action_dim), + "std": torch.ones(config.max_action_dim), }, } @@ -97,8 +97,8 @@ def _test_hub_loading(model_id, model_name): print(f" - Model type: {model_name}") print(f" - PaliGemma variant: {policy.config.paligemma_variant}") print(f" - Action expert variant: {policy.config.action_expert_variant}") - print(f" - Action dimension: {policy.config.action_dim}") - print(f" - State dimension: {policy.config.state_dim}") + print(f" - Action dimension: {policy.config.max_action_dim}") + print(f" - State dimension: {policy.config.max_state_dim}") print(f" - Chunk_size: {policy.config.chunk_size}") print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}") if model_name == "PI0.5": @@ -185,12 +185,12 @@ def _test_hub_loading(model_id, model_name): # Create test batch batch = { "observation.state": torch.randn( - batch_size, policy.config.state_dim, dtype=torch.float32, device=device + batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device ), "action": torch.randn( batch_size, policy.config.chunk_size, - policy.config.action_dim, + policy.config.max_action_dim, dtype=torch.float32, device=device, ), @@ -282,8 +282,8 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class): print(f" - Model type: {model_type}") print(f" - PaliGemma variant: {policy.config.paligemma_variant}") print(f" - Action expert variant: {policy.config.action_expert_variant}") - print(f" - Action dimension: {policy.config.action_dim}") - print(f" - State dimension: {policy.config.state_dim}") + print(f" - Action dimension: {policy.config.max_action_dim}") + print(f" - State dimension: {policy.config.max_state_dim}") print(f" - Chunk size: {policy.config.chunk_size}") print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}") print(f" - Device: {device}") @@ -339,10 +339,14 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class): batch_size = 1 batch = { "observation.state": torch.randn( - batch_size, policy.config.state_dim, dtype=torch.float32, device=device + batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device ), "action": torch.randn( - batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device + batch_size, + policy.config.chunk_size, + policy.config.max_action_dim, + dtype=torch.float32, + device=device, ), "task": ["Pick up the object"] * batch_size, } @@ -369,7 +373,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class): policy.eval() with torch.no_grad(): action = policy.select_action(batch) - expected_shape = (batch_size, policy.config.action_dim) + expected_shape = (batch_size, policy.config.max_action_dim) assert action.shape == expected_shape, ( f"{model_id}: Expected action shape {expected_shape}, got {action.shape}" )