build obs with policy names

This commit is contained in:
Pepijn
2026-01-08 17:31:48 +01:00
parent 8025ab0594
commit 09904e7797
+7 -4
View File
@@ -319,8 +319,9 @@ def run_ee_inference_loop(
ee_state = joints_to_ee(joint_state.copy())
# 3. Build observation frame with EE state for policy input
# Use state names from dataset features to match training order
state_names = dataset.features.get("observation.state", {}).get("names", [])
# Use state names from policy's input features (from training) to match expected order
state_feature = policy.config.input_features.get("observation.state")
state_names = getattr(state_feature, "names", None) if state_feature else None
if state_names:
# Build state array using the exact names from training
@@ -518,12 +519,14 @@ def main():
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID
policy = make_policy(policy_config, ds_meta=dataset.meta)
# Create policy without new dataset meta (use pretrained config)
policy = make_policy(policy_config, ds_meta=None)
# Use pretrained stats from model, NOT from new evaluation dataset
# (evaluation dataset has different features than training dataset)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats,
preprocessor_overrides={
"device_processor": {"device": str(policy.config.device)}
},