diff --git a/examples/openarms/evaluate_ee.py b/examples/openarms/evaluate_ee.py index 4593a38f3..d6d68de0f 100644 --- a/examples/openarms/evaluate_ee.py +++ b/examples/openarms/evaluate_ee.py @@ -42,7 +42,6 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.model.kinematics import RobotKinematics from lerobot.policies.factory import make_policy, make_pre_post_processors -from lerobot.policies.utils import make_robot_action from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import predict_action @@ -319,18 +318,17 @@ def run_ee_inference_loop( ee_state = joints_to_ee(joint_state.copy()) # 3. Build observation frame with EE state for policy input - # Get expected state dimension from policy's input features - state_feature = policy.config.input_features.get("observation.state") - expected_dim = state_feature.shape[0] if state_feature else None - - # Build state array from EE values (sorted keys) - ee_keys = sorted(ee_state.keys()) + # Filter to only EE keys (FK may include other keys in output) + # Expected: left_ee.{x,y,z,wx,wy,wz,gripper_pos}, right_ee.{...} = 14 total + ee_keys = sorted([k for k in ee_state.keys() if "_ee." in k]) ee_values = [ee_state[k] for k in ee_keys] - # Truncate to match expected dimension (FK may output more than policy expects) - if expected_dim and len(ee_values) > expected_dim: - ee_values = ee_values[:expected_dim] - ee_keys = ee_keys[:expected_dim] + # Debug: print on first step + if step == 0: + print(f" FK output keys ({len(ee_keys)}): {ee_keys}") + state_feature = policy.config.input_features.get("observation.state") + if state_feature: + print(f" Policy expects state dim: {state_feature.shape[0]}") # Store current EE position for relative action conversion (using same order) current_ee_pos = torch.tensor(ee_values) @@ -364,8 +362,12 @@ def run_ee_inference_loop( robot_type=robot.robot_type, ) - # 5. Convert action tensor to dict - ee_action = make_robot_action(action_tensor, dataset.features) + # 5. Convert action tensor to dict using EE keys (not joint keys from eval dataset) + action_tensor = action_tensor.squeeze(0).cpu() + while action_tensor.dim() > 1: + action_tensor = action_tensor[0] + # Use the same EE keys we used for state (truncated to match policy's action dim) + ee_action = {ee_keys[i]: float(action_tensor[i]) for i in range(len(action_tensor))} # 6. Convert relative action back to absolute if needed if use_relative_actions: @@ -515,17 +517,17 @@ def main(): ) print(" Dataset created") - # Load policy + # Load policy directly using from_pretrained to preserve original EE features + # (make_policy would overwrite output_features with joint features from eval dataset) print(f"\n[4/5] Loading policy from {HF_MODEL_ID}...") - policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) - policy_config.pretrained_path = HF_MODEL_ID + from lerobot.policies.factory import get_policy_class - # Pass dataset meta for policy creation (required by make_policy) - policy = make_policy(policy_config, ds_meta=dataset.meta) + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_cls = get_policy_class(policy_config.type) + policy = policy_cls.from_pretrained(HF_MODEL_ID) # Load preprocessor/postprocessor from pretrained model - # DO NOT pass dataset_stats - let it load from pretrained model - # (evaluation dataset has different features than training dataset) + # (uses the trained EE features, not joint features from eval dataset) preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy.config, pretrained_path=HF_MODEL_ID, @@ -534,6 +536,8 @@ def main(): }, ) print(" Policy loaded") + print(f" State dim: {policy.config.input_features['observation.state'].shape[0]}") + print(f" Action dim: {policy.config.output_features['action'].shape[0]}") # Auto-detect relative action/state settings from checkpoint relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)