mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix tests
This commit is contained in:
@@ -23,6 +23,27 @@ def test_pi05_model_architecture():
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(14,),
|
||||
),
|
||||
"observation.images.base_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
}
|
||||
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(7,),
|
||||
),
|
||||
}
|
||||
|
||||
assert config.tokenizer_max_length == 200, (
|
||||
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
|
||||
)
|
||||
@@ -40,6 +61,10 @@ def test_pi05_model_architecture():
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
@@ -82,6 +107,27 @@ def test_pi05_forward_pass():
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
)
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(14,),
|
||||
),
|
||||
"observation.images.base_0_rgb": PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, 224, 224),
|
||||
),
|
||||
}
|
||||
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(7,),
|
||||
),
|
||||
}
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
@@ -92,6 +138,10 @@ def test_pi05_forward_pass():
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {
|
||||
"mean": torch.zeros(3, 224, 224),
|
||||
"std": torch.ones(3, 224, 224),
|
||||
},
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
|
||||
Reference in New Issue
Block a user