diff --git a/test_pi0_pi05_hub.py b/test_pi0_pi05_hub.py index 522b32b2e..d737bc561 100644 --- a/test_pi0_pi05_hub.py +++ b/test_pi0_pi05_hub.py @@ -84,7 +84,8 @@ def test_hub_loading(model_id="pepijn223/pi0_base_fp32", model_name="PI0"): print(f" - State dimension: {policy.config.state_dim}") print(f" - Action horizon: {policy.config.action_horizon}") print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}") - print(f" - discrete_state_input: {policy.config.discrete_state_input}") + if model_name == "PI0.5": + print(f" - discrete_state_input: {policy.config.discrete_state_input}") print(f" - Device: {device}") print(f" - Dtype: {next(policy.parameters()).dtype}")