fix tests

This commit is contained in:
Pepijn
2025-09-17 19:06:23 +02:00
parent 9461b9f8d5
commit 6467ce10d4
3 changed files with 78 additions and 3 deletions
+50
View File
@@ -23,6 +23,27 @@ def test_pi05_model_architecture():
dtype="float32",
)
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
assert config.tokenizer_max_length == 200, (
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
)
@@ -40,6 +61,10 @@ def test_pi05_model_architecture():
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
@@ -82,6 +107,27 @@ def test_pi05_forward_pass():
n_action_steps=16, # Shorter action steps for testing
)
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
@@ -92,6 +138,10 @@ def test_pi05_forward_pass():
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy