diff --git a/tests/policies/test_pi05_openpi.py b/tests/policies/test_pi05_openpi.py index ce1eb0e62..bddf0d5ea 100644 --- a/tests/policies/test_pi05_openpi.py +++ b/tests/policies/test_pi05_openpi.py @@ -23,6 +23,27 @@ def test_pi05_model_architecture(): dtype="float32", ) + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(14,), + ), + "observation.images.base_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + assert config.tokenizer_max_length == 200, ( f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}" ) @@ -40,6 +61,10 @@ def test_pi05_model_architecture(): "mean": torch.zeros(7), "std": torch.ones(7), }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, } # Instantiate policy @@ -82,6 +107,27 @@ def test_pi05_forward_pass(): n_action_steps=16, # Shorter action steps for testing ) + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(14,), + ), + "observation.images.base_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + # Create dummy dataset stats dataset_stats = { "observation.state": { @@ -92,6 +138,10 @@ def test_pi05_forward_pass(): "mean": torch.zeros(7), "std": torch.ones(7), }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, } # Instantiate policy diff --git a/tests/policies/test_pi0_openpi.py b/tests/policies/test_pi0_openpi.py index 1da016bf0..57f6f5625 100644 --- a/tests/policies/test_pi0_openpi.py +++ b/tests/policies/test_pi0_openpi.py @@ -18,6 +18,27 @@ def test_policy_instantiation(): # Create config config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32") + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(14,), + ), + "observation.images.base_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + # Create dummy dataset stats dataset_stats = { "observation.state": { @@ -28,6 +49,10 @@ def test_policy_instantiation(): "mean": torch.zeros(7), "std": torch.ones(7), }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, } # Instantiate policy diff --git a/tests/policies/test_pi0_pi05_hub.py b/tests/policies/test_pi0_pi05_hub.py index 94c9b4798..4cc2ad7c5 100644 --- a/tests/policies/test_pi0_pi05_hub.py +++ b/tests/policies/test_pi0_pi05_hub.py @@ -36,7 +36,7 @@ def create_dummy_stats(config): } # Add stats for image keys if they exist - for key in config.image_keys: + for key in config.image_features.keys(): dummy_stats[key] = { "mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]), "std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]), @@ -205,7 +205,7 @@ def _test_hub_loading(model_id, model_name): } # Add images if they're in the config - for key in policy.config.image_keys: + for key in policy.config.image_features.keys(): batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device) try: @@ -358,7 +358,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class): } # Add images based on config - for key in policy.config.image_keys: + for key in policy.config.image_features.keys(): batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device) # Test forward pass