mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
Update evaluate_ee.py
This commit is contained in:
@@ -319,17 +319,18 @@ def run_ee_inference_loop(
|
|||||||
ee_state = joints_to_ee(joint_state.copy())
|
ee_state = joints_to_ee(joint_state.copy())
|
||||||
|
|
||||||
# 3. Build observation frame with EE state for policy input
|
# 3. Build observation frame with EE state for policy input
|
||||||
# Use state names from policy's input features (from training) to match expected order
|
# Get expected state dimension from policy's input features
|
||||||
state_feature = policy.config.input_features.get("observation.state")
|
state_feature = policy.config.input_features.get("observation.state")
|
||||||
state_names = getattr(state_feature, "names", None) if state_feature else None
|
expected_dim = state_feature.shape[0] if state_feature else None
|
||||||
|
|
||||||
if state_names:
|
# Build state array from EE values (sorted keys)
|
||||||
# Build state array using the exact names from training
|
ee_keys = sorted(ee_state.keys())
|
||||||
ee_values = [ee_state.get(name, 0.0) for name in state_names]
|
ee_values = [ee_state[k] for k in ee_keys]
|
||||||
else:
|
|
||||||
# Fallback: use sorted keys
|
# Truncate to match expected dimension (FK may output more than policy expects)
|
||||||
state_names = sorted(ee_state.keys())
|
if expected_dim and len(ee_values) > expected_dim:
|
||||||
ee_values = [ee_state[k] for k in state_names]
|
ee_values = ee_values[:expected_dim]
|
||||||
|
ee_keys = ee_keys[:expected_dim]
|
||||||
|
|
||||||
# Store current EE position for relative action conversion (using same order)
|
# Store current EE position for relative action conversion (using same order)
|
||||||
current_ee_pos = torch.tensor(ee_values)
|
current_ee_pos = torch.tensor(ee_values)
|
||||||
@@ -519,10 +520,11 @@ def main():
|
|||||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
policy_config.pretrained_path = HF_MODEL_ID
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
# Create policy without new dataset meta (use pretrained config)
|
# Pass dataset meta for policy creation (required by make_policy)
|
||||||
policy = make_policy(policy_config, ds_meta=None)
|
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||||
|
|
||||||
# Use pretrained stats from model, NOT from new evaluation dataset
|
# Load preprocessor/postprocessor from pretrained model
|
||||||
|
# DO NOT pass dataset_stats - let it load from pretrained model
|
||||||
# (evaluation dataset has different features than training dataset)
|
# (evaluation dataset has different features than training dataset)
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=policy.config,
|
policy_cfg=policy.config,
|
||||||
|
|||||||
Reference in New Issue
Block a user