mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
rename action_dim, state_dim to max_action_dim, max_state_dim
This commit is contained in:
@@ -21,8 +21,8 @@ def test_pi05_model_architecture():
|
||||
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
@@ -89,8 +89,8 @@ def test_pi05_forward_pass():
|
||||
|
||||
# Create config
|
||||
config = PI05OpenPIConfig(
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
dtype="float32",
|
||||
chunk_size=16, # Shorter chunk_size for testing
|
||||
n_action_steps=16, # Shorter action steps for testing
|
||||
@@ -152,8 +152,8 @@ def test_pi0_vs_pi05_differences():
|
||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||
|
||||
# Create both configurations
|
||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
config_pi0 = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
config_pi05 = PI05OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_policy_instantiation():
|
||||
print("Testing PI0OpenPI policy instantiation...")
|
||||
|
||||
# Create config
|
||||
config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
||||
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Create dummy dataset stats
|
||||
dataset_stats = {
|
||||
@@ -75,8 +75,8 @@ def test_config_creation():
|
||||
try:
|
||||
config = make_policy_config(
|
||||
policy_type="pi0_openpi",
|
||||
action_dim=7,
|
||||
state_dim=14,
|
||||
max_action_dim=7,
|
||||
max_state_dim=14,
|
||||
)
|
||||
print("✓ Config created successfully through factory")
|
||||
print(f" Config type: {type(config).__name__}")
|
||||
|
||||
@@ -79,7 +79,9 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False):
|
||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||
)
|
||||
else:
|
||||
config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32")
|
||||
config = PI0OpenPIConfig(
|
||||
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
|
||||
)
|
||||
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
|
||||
policy.to(DEVICE)
|
||||
return policy
|
||||
|
||||
@@ -17,12 +17,12 @@ def create_dummy_stats(config):
|
||||
"""Create dummy dataset statistics for testing."""
|
||||
dummy_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(config.state_dim),
|
||||
"std": torch.ones(config.state_dim),
|
||||
"mean": torch.zeros(config.max_state_dim),
|
||||
"std": torch.ones(config.max_state_dim),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(config.action_dim),
|
||||
"std": torch.ones(config.action_dim),
|
||||
"mean": torch.zeros(config.max_action_dim),
|
||||
"std": torch.ones(config.max_action_dim),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -97,8 +97,8 @@ def _test_hub_loading(model_id, model_name):
|
||||
print(f" - Model type: {model_name}")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action dimension: {policy.config.max_action_dim}")
|
||||
print(f" - State dimension: {policy.config.max_state_dim}")
|
||||
print(f" - Chunk_size: {policy.config.chunk_size}")
|
||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||
if model_name == "PI0.5":
|
||||
@@ -185,12 +185,12 @@ def _test_hub_loading(model_id, model_name):
|
||||
# Create test batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(
|
||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size,
|
||||
policy.config.chunk_size,
|
||||
policy.config.action_dim,
|
||||
policy.config.max_action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
@@ -282,8 +282,8 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
print(f" - Model type: {model_type}")
|
||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||
print(f" - Action dimension: {policy.config.action_dim}")
|
||||
print(f" - State dimension: {policy.config.state_dim}")
|
||||
print(f" - Action dimension: {policy.config.max_action_dim}")
|
||||
print(f" - State dimension: {policy.config.max_state_dim}")
|
||||
print(f" - Chunk size: {policy.config.chunk_size}")
|
||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||
print(f" - Device: {device}")
|
||||
@@ -339,10 +339,14 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
batch_size = 1
|
||||
batch = {
|
||||
"observation.state": torch.randn(
|
||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
||||
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
|
||||
),
|
||||
"action": torch.randn(
|
||||
batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device
|
||||
batch_size,
|
||||
policy.config.chunk_size,
|
||||
policy.config.max_action_dim,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
"task": ["Pick up the object"] * batch_size,
|
||||
}
|
||||
@@ -369,7 +373,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
action = policy.select_action(batch)
|
||||
expected_shape = (batch_size, policy.config.action_dim)
|
||||
expected_shape = (batch_size, policy.config.max_action_dim)
|
||||
assert action.shape == expected_shape, (
|
||||
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user