mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix tests
This commit is contained in:
@@ -23,6 +23,27 @@ def test_pi05_model_architecture():
|
|||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set up input_features and output_features in the config
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
config.input_features = {
|
||||||
|
"observation.state": PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(14,),
|
||||||
|
),
|
||||||
|
"observation.images.base_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
config.output_features = {
|
||||||
|
"action": PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(7,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
assert config.tokenizer_max_length == 200, (
|
assert config.tokenizer_max_length == 200, (
|
||||||
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
|
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
|
||||||
)
|
)
|
||||||
@@ -40,6 +61,10 @@ def test_pi05_model_architecture():
|
|||||||
"mean": torch.zeros(7),
|
"mean": torch.zeros(7),
|
||||||
"std": torch.ones(7),
|
"std": torch.ones(7),
|
||||||
},
|
},
|
||||||
|
"observation.images.base_0_rgb": {
|
||||||
|
"mean": torch.zeros(3, 224, 224),
|
||||||
|
"std": torch.ones(3, 224, 224),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Instantiate policy
|
# Instantiate policy
|
||||||
@@ -82,6 +107,27 @@ def test_pi05_forward_pass():
|
|||||||
n_action_steps=16, # Shorter action steps for testing
|
n_action_steps=16, # Shorter action steps for testing
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set up input_features and output_features in the config
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
config.input_features = {
|
||||||
|
"observation.state": PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(14,),
|
||||||
|
),
|
||||||
|
"observation.images.base_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
config.output_features = {
|
||||||
|
"action": PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(7,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# Create dummy dataset stats
|
# Create dummy dataset stats
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
@@ -92,6 +138,10 @@ def test_pi05_forward_pass():
|
|||||||
"mean": torch.zeros(7),
|
"mean": torch.zeros(7),
|
||||||
"std": torch.ones(7),
|
"std": torch.ones(7),
|
||||||
},
|
},
|
||||||
|
"observation.images.base_0_rgb": {
|
||||||
|
"mean": torch.zeros(3, 224, 224),
|
||||||
|
"std": torch.ones(3, 224, 224),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Instantiate policy
|
# Instantiate policy
|
||||||
|
|||||||
@@ -18,6 +18,27 @@ def test_policy_instantiation():
|
|||||||
# Create config
|
# Create config
|
||||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||||
|
|
||||||
|
# Set up input_features and output_features in the config
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
|
||||||
|
config.input_features = {
|
||||||
|
"observation.state": PolicyFeature(
|
||||||
|
type=FeatureType.STATE,
|
||||||
|
shape=(14,),
|
||||||
|
),
|
||||||
|
"observation.images.base_0_rgb": PolicyFeature(
|
||||||
|
type=FeatureType.VISUAL,
|
||||||
|
shape=(3, 224, 224),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
config.output_features = {
|
||||||
|
"action": PolicyFeature(
|
||||||
|
type=FeatureType.ACTION,
|
||||||
|
shape=(7,),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
# Create dummy dataset stats
|
# Create dummy dataset stats
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
@@ -28,6 +49,10 @@ def test_policy_instantiation():
|
|||||||
"mean": torch.zeros(7),
|
"mean": torch.zeros(7),
|
||||||
"std": torch.ones(7),
|
"std": torch.ones(7),
|
||||||
},
|
},
|
||||||
|
"observation.images.base_0_rgb": {
|
||||||
|
"mean": torch.zeros(3, 224, 224),
|
||||||
|
"std": torch.ones(3, 224, 224),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Instantiate policy
|
# Instantiate policy
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ def create_dummy_stats(config):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Add stats for image keys if they exist
|
# Add stats for image keys if they exist
|
||||||
for key in config.image_keys:
|
for key in config.image_features.keys():
|
||||||
dummy_stats[key] = {
|
dummy_stats[key] = {
|
||||||
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
|
"mean": torch.zeros(3, config.image_resolution[0], config.image_resolution[1]),
|
||||||
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
|
"std": torch.ones(3, config.image_resolution[0], config.image_resolution[1]),
|
||||||
@@ -205,7 +205,7 @@ def _test_hub_loading(model_id, model_name):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Add images if they're in the config
|
# Add images if they're in the config
|
||||||
for key in policy.config.image_keys:
|
for key in policy.config.image_features.keys():
|
||||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -358,7 +358,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Add images based on config
|
# Add images based on config
|
||||||
for key in policy.config.image_keys:
|
for key in policy.config.image_features.keys():
|
||||||
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
batch[key] = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
# Test forward pass
|
# Test forward pass
|
||||||
|
|||||||
Reference in New Issue
Block a user