match state and expected obs dimensions

This commit is contained in:
Pepijn
2026-01-08 17:52:19 +01:00
parent cbeb9ce00a
commit 99cdb07dda
+24 -20
View File
@@ -42,7 +42,6 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.utils import make_robot_action
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import predict_action
@@ -319,18 +318,17 @@ def run_ee_inference_loop(
ee_state = joints_to_ee(joint_state.copy())
# 3. Build observation frame with EE state for policy input
# Get expected state dimension from policy's input features
state_feature = policy.config.input_features.get("observation.state")
expected_dim = state_feature.shape[0] if state_feature else None
# Build state array from EE values (sorted keys)
ee_keys = sorted(ee_state.keys())
# Filter to only EE keys (FK may include other keys in output)
# Expected: left_ee.{x,y,z,wx,wy,wz,gripper_pos}, right_ee.{...} = 14 total
ee_keys = sorted([k for k in ee_state.keys() if "_ee." in k])
ee_values = [ee_state[k] for k in ee_keys]
# Truncate to match expected dimension (FK may output more than policy expects)
if expected_dim and len(ee_values) > expected_dim:
ee_values = ee_values[:expected_dim]
ee_keys = ee_keys[:expected_dim]
# Debug: print on first step
if step == 0:
print(f" FK output keys ({len(ee_keys)}): {ee_keys}")
state_feature = policy.config.input_features.get("observation.state")
if state_feature:
print(f" Policy expects state dim: {state_feature.shape[0]}")
# Store current EE position for relative action conversion (using same order)
current_ee_pos = torch.tensor(ee_values)
@@ -364,8 +362,12 @@ def run_ee_inference_loop(
robot_type=robot.robot_type,
)
# 5. Convert action tensor to dict
ee_action = make_robot_action(action_tensor, dataset.features)
# 5. Convert action tensor to dict using EE keys (not joint keys from eval dataset)
action_tensor = action_tensor.squeeze(0).cpu()
while action_tensor.dim() > 1:
action_tensor = action_tensor[0]
# Use the same EE keys we used for state (truncated to match policy's action dim)
ee_action = {ee_keys[i]: float(action_tensor[i]) for i in range(len(action_tensor))}
# 6. Convert relative action back to absolute if needed
if use_relative_actions:
@@ -515,17 +517,17 @@ def main():
)
print(" Dataset created")
# Load policy
# Load policy directly using from_pretrained to preserve original EE features
# (make_policy would overwrite output_features with joint features from eval dataset)
print(f"\n[4/5] Loading policy from {HF_MODEL_ID}...")
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
from lerobot.policies.factory import get_policy_class
# Pass dataset meta for policy creation (required by make_policy)
policy = make_policy(policy_config, ds_meta=dataset.meta)
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_cls = get_policy_class(policy_config.type)
policy = policy_cls.from_pretrained(HF_MODEL_ID)
# Load preprocessor/postprocessor from pretrained model
# DO NOT pass dataset_stats - let it load from pretrained model
# (evaluation dataset has different features than training dataset)
# (uses the trained EE features, not joint features from eval dataset)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=HF_MODEL_ID,
@@ -534,6 +536,8 @@ def main():
},
)
print(" Policy loaded")
print(f" State dim: {policy.config.input_features['observation.state'].shape[0]}")
print(f" Action dim: {policy.config.output_features['action'].shape[0]}")
# Auto-detect relative action/state settings from checkpoint
relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)