mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
TODO: Make test works
This commit is contained in:
@@ -14,13 +14,22 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.factory import make_policy_config # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
|
||||
from lerobot.policies.pi0_openpi import ( # noqa: E402
|
||||
PI0OpenPIConfig,
|
||||
PI0OpenPIPolicy,
|
||||
make_pi0_openpi_pre_post_processors, # noqa: E402
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
# Set seed
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_policy_instantiation():
|
||||
# Create config
|
||||
|
||||
set_seed(42)
|
||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Set up input_features and output_features in the config
|
||||
@@ -61,11 +70,13 @@ def test_policy_instantiation():
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0OpenPIPolicy(config, dataset_stats)
|
||||
|
||||
policy = PI0OpenPIPolicy(config)
|
||||
preprocessor, postprocessor = make_pi0_openpi_pre_post_processors(
|
||||
config=config, dataset_stats=dataset_stats
|
||||
)
|
||||
# Test forward pass with dummy data
|
||||
batch_size = 1
|
||||
device = policy.device if hasattr(policy, "device") else "cpu"
|
||||
device = config.device
|
||||
batch = {
|
||||
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
|
||||
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
|
||||
@@ -74,7 +85,7 @@ def test_policy_instantiation():
|
||||
), # Use rand for [0,1] range
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
|
||||
batch = preprocessor(batch)
|
||||
try:
|
||||
loss, loss_dict = policy.forward(batch)
|
||||
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
|
||||
@@ -85,6 +96,8 @@ def test_policy_instantiation():
|
||||
try:
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
action = postprocessor(action)
|
||||
print(f"Action: {action}")
|
||||
print(f"Action prediction successful. Action shape: {action.shape}")
|
||||
except Exception as e:
|
||||
print(f"Action prediction failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user