build obs frame with name from training

This commit is contained in:
Pepijn
2026-01-08 17:28:39 +01:00
parent 8039a76e77
commit 8025ab0594
+13 -6
View File
@@ -318,13 +318,20 @@ def run_ee_inference_loop(
ee_state = joints_to_ee(joint_state.copy())
# Store current EE position for relative action conversion
current_ee_pos = torch.tensor([ee_state.get(k, 0.0) for k in sorted(ee_state.keys())])
# 3. Build observation frame with EE state for policy input
# Build state array from EE values (sorted to match training order)
ee_keys = sorted(ee_state.keys())
ee_values = [ee_state[k] for k in ee_keys]
# Use state names from dataset features to match training order
state_names = dataset.features.get("observation.state", {}).get("names", [])
if state_names:
# Build state array using the exact names from training
ee_values = [ee_state.get(name, 0.0) for name in state_names]
else:
# Fallback: use sorted keys
state_names = sorted(ee_state.keys())
ee_values = [ee_state[k] for k in state_names]
# Store current EE position for relative action conversion (using same order)
current_ee_pos = torch.tensor(ee_values)
# Convert to relative state if enabled (UMI-style)
if use_relative_state: