mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
rename action_dim, state_dim to max_action_dim, max_state_dim
This commit is contained in:
@@ -21,8 +21,8 @@ def test_pi05_model_architecture():
|
|||||||
|
|
||||||
# Create config
|
# Create config
|
||||||
config = PI05OpenPIConfig(
|
config = PI05OpenPIConfig(
|
||||||
action_dim=7,
|
max_action_dim=7,
|
||||||
state_dim=14,
|
max_state_dim=14,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,8 +89,8 @@ def test_pi05_forward_pass():
|
|||||||
|
|
||||||
# Create config
|
# Create config
|
||||||
config = PI05OpenPIConfig(
|
config = PI05OpenPIConfig(
|
||||||
action_dim=7,
|
max_action_dim=7,
|
||||||
state_dim=14,
|
max_state_dim=14,
|
||||||
dtype="float32",
|
dtype="float32",
|
||||||
chunk_size=16, # Shorter chunk_size for testing
|
chunk_size=16, # Shorter chunk_size for testing
|
||||||
n_action_steps=16, # Shorter action steps for testing
|
n_action_steps=16, # Shorter action steps for testing
|
||||||
@@ -152,8 +152,8 @@ def test_pi0_vs_pi05_differences():
|
|||||||
print("\nComparing PI0 vs PI0.5 architectures...")
|
print("\nComparing PI0 vs PI0.5 architectures...")
|
||||||
|
|
||||||
# Create both configurations
|
# Create both configurations
|
||||||
config_pi0 = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
config_pi0 = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||||
config_pi05 = PI05OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
config_pi05 = PI05OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||||
|
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def test_policy_instantiation():
|
|||||||
print("Testing PI0OpenPI policy instantiation...")
|
print("Testing PI0OpenPI policy instantiation...")
|
||||||
|
|
||||||
# Create config
|
# Create config
|
||||||
config = PI0OpenPIConfig(action_dim=7, state_dim=14, dtype="float32")
|
config = PI0OpenPIConfig(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||||
|
|
||||||
# Create dummy dataset stats
|
# Create dummy dataset stats
|
||||||
dataset_stats = {
|
dataset_stats = {
|
||||||
@@ -75,8 +75,8 @@ def test_config_creation():
|
|||||||
try:
|
try:
|
||||||
config = make_policy_config(
|
config = make_policy_config(
|
||||||
policy_type="pi0_openpi",
|
policy_type="pi0_openpi",
|
||||||
action_dim=7,
|
max_action_dim=7,
|
||||||
state_dim=14,
|
max_state_dim=14,
|
||||||
)
|
)
|
||||||
print("✓ Config created successfully through factory")
|
print("✓ Config created successfully through factory")
|
||||||
print(f" Config type: {type(config).__name__}")
|
print(f" Config type: {type(config).__name__}")
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False):
|
|||||||
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
config = PI0OpenPIConfig(action_dim=DUMMY_ACTION_DIM, state_dim=DUMMY_STATE_DIM, dtype="float32")
|
config = PI0OpenPIConfig(
|
||||||
|
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
|
||||||
|
)
|
||||||
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
|
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
|
||||||
policy.to(DEVICE)
|
policy.to(DEVICE)
|
||||||
return policy
|
return policy
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ def create_dummy_stats(config):
|
|||||||
"""Create dummy dataset statistics for testing."""
|
"""Create dummy dataset statistics for testing."""
|
||||||
dummy_stats = {
|
dummy_stats = {
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
"mean": torch.zeros(config.state_dim),
|
"mean": torch.zeros(config.max_state_dim),
|
||||||
"std": torch.ones(config.state_dim),
|
"std": torch.ones(config.max_state_dim),
|
||||||
},
|
},
|
||||||
"action": {
|
"action": {
|
||||||
"mean": torch.zeros(config.action_dim),
|
"mean": torch.zeros(config.max_action_dim),
|
||||||
"std": torch.ones(config.action_dim),
|
"std": torch.ones(config.max_action_dim),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,8 +97,8 @@ def _test_hub_loading(model_id, model_name):
|
|||||||
print(f" - Model type: {model_name}")
|
print(f" - Model type: {model_name}")
|
||||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||||
print(f" - Action dimension: {policy.config.action_dim}")
|
print(f" - Action dimension: {policy.config.max_action_dim}")
|
||||||
print(f" - State dimension: {policy.config.state_dim}")
|
print(f" - State dimension: {policy.config.max_state_dim}")
|
||||||
print(f" - Chunk_size: {policy.config.chunk_size}")
|
print(f" - Chunk_size: {policy.config.chunk_size}")
|
||||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||||
if model_name == "PI0.5":
|
if model_name == "PI0.5":
|
||||||
@@ -185,12 +185,12 @@ def _test_hub_loading(model_id, model_name):
|
|||||||
# Create test batch
|
# Create test batch
|
||||||
batch = {
|
batch = {
|
||||||
"observation.state": torch.randn(
|
"observation.state": torch.randn(
|
||||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"action": torch.randn(
|
"action": torch.randn(
|
||||||
batch_size,
|
batch_size,
|
||||||
policy.config.chunk_size,
|
policy.config.chunk_size,
|
||||||
policy.config.action_dim,
|
policy.config.max_action_dim,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
@@ -282,8 +282,8 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
|||||||
print(f" - Model type: {model_type}")
|
print(f" - Model type: {model_type}")
|
||||||
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
print(f" - PaliGemma variant: {policy.config.paligemma_variant}")
|
||||||
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
print(f" - Action expert variant: {policy.config.action_expert_variant}")
|
||||||
print(f" - Action dimension: {policy.config.action_dim}")
|
print(f" - Action dimension: {policy.config.max_action_dim}")
|
||||||
print(f" - State dimension: {policy.config.state_dim}")
|
print(f" - State dimension: {policy.config.max_state_dim}")
|
||||||
print(f" - Chunk size: {policy.config.chunk_size}")
|
print(f" - Chunk size: {policy.config.chunk_size}")
|
||||||
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
print(f" - Tokenizer max length: {policy.config.tokenizer_max_length}")
|
||||||
print(f" - Device: {device}")
|
print(f" - Device: {device}")
|
||||||
@@ -339,10 +339,14 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
|||||||
batch_size = 1
|
batch_size = 1
|
||||||
batch = {
|
batch = {
|
||||||
"observation.state": torch.randn(
|
"observation.state": torch.randn(
|
||||||
batch_size, policy.config.state_dim, dtype=torch.float32, device=device
|
batch_size, policy.config.max_state_dim, dtype=torch.float32, device=device
|
||||||
),
|
),
|
||||||
"action": torch.randn(
|
"action": torch.randn(
|
||||||
batch_size, policy.config.chunk_size, policy.config.action_dim, dtype=torch.float32, device=device
|
batch_size,
|
||||||
|
policy.config.chunk_size,
|
||||||
|
policy.config.max_action_dim,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
),
|
),
|
||||||
"task": ["Pick up the object"] * batch_size,
|
"task": ["Pick up the object"] * batch_size,
|
||||||
}
|
}
|
||||||
@@ -369,7 +373,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
|
|||||||
policy.eval()
|
policy.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action = policy.select_action(batch)
|
action = policy.select_action(batch)
|
||||||
expected_shape = (batch_size, policy.config.action_dim)
|
expected_shape = (batch_size, policy.config.max_action_dim)
|
||||||
assert action.shape == expected_shape, (
|
assert action.shape == expected_shape, (
|
||||||
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
|
f"{model_id}: Expected action shape {expected_shape}, got {action.shape}"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user