From cbeb9ce00a9eba9e60ae655363a87177cef76c67 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 8 Jan 2026 17:36:53 +0100 Subject: [PATCH] Update evaluate_ee.py --- examples/openarms/evaluate_ee.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/openarms/evaluate_ee.py b/examples/openarms/evaluate_ee.py index afd1c1201..4593a38f3 100644 --- a/examples/openarms/evaluate_ee.py +++ b/examples/openarms/evaluate_ee.py @@ -319,17 +319,18 @@ def run_ee_inference_loop( ee_state = joints_to_ee(joint_state.copy()) # 3. Build observation frame with EE state for policy input - # Use state names from policy's input features (from training) to match expected order + # Get expected state dimension from policy's input features state_feature = policy.config.input_features.get("observation.state") - state_names = getattr(state_feature, "names", None) if state_feature else None + expected_dim = state_feature.shape[0] if state_feature else None - if state_names: - # Build state array using the exact names from training - ee_values = [ee_state.get(name, 0.0) for name in state_names] - else: - # Fallback: use sorted keys - state_names = sorted(ee_state.keys()) - ee_values = [ee_state[k] for k in state_names] + # Build state array from EE values (sorted keys) + ee_keys = sorted(ee_state.keys()) + ee_values = [ee_state[k] for k in ee_keys] + + # Truncate to match expected dimension (FK may output more than policy expects) + if expected_dim and len(ee_values) > expected_dim: + ee_values = ee_values[:expected_dim] + ee_keys = ee_keys[:expected_dim] # Store current EE position for relative action conversion (using same order) current_ee_pos = torch.tensor(ee_values) @@ -519,10 +520,11 @@ def main(): policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) policy_config.pretrained_path = HF_MODEL_ID - # Create policy without new dataset meta (use pretrained config) - policy = make_policy(policy_config, ds_meta=None) + # Pass dataset meta for policy creation (required by make_policy) + policy = make_policy(policy_config, ds_meta=dataset.meta) - # Use pretrained stats from model, NOT from new evaluation dataset + # Load preprocessor/postprocessor from pretrained model + # DO NOT pass dataset_stats - let it load from pretrained model # (evaluation dataset has different features than training dataset) preprocessor, postprocessor = make_pre_post_processors( policy_cfg=policy.config,