mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
build obs frame with name from training
This commit is contained in:
@@ -318,13 +318,20 @@ def run_ee_inference_loop(
|
|||||||
|
|
||||||
ee_state = joints_to_ee(joint_state.copy())
|
ee_state = joints_to_ee(joint_state.copy())
|
||||||
|
|
||||||
# Store current EE position for relative action conversion
|
|
||||||
current_ee_pos = torch.tensor([ee_state.get(k, 0.0) for k in sorted(ee_state.keys())])
|
|
||||||
|
|
||||||
# 3. Build observation frame with EE state for policy input
|
# 3. Build observation frame with EE state for policy input
|
||||||
# Build state array from EE values (sorted to match training order)
|
# Use state names from dataset features to match training order
|
||||||
ee_keys = sorted(ee_state.keys())
|
state_names = dataset.features.get("observation.state", {}).get("names", [])
|
||||||
ee_values = [ee_state[k] for k in ee_keys]
|
|
||||||
|
if state_names:
|
||||||
|
# Build state array using the exact names from training
|
||||||
|
ee_values = [ee_state.get(name, 0.0) for name in state_names]
|
||||||
|
else:
|
||||||
|
# Fallback: use sorted keys
|
||||||
|
state_names = sorted(ee_state.keys())
|
||||||
|
ee_values = [ee_state[k] for k in state_names]
|
||||||
|
|
||||||
|
# Store current EE position for relative action conversion (using same order)
|
||||||
|
current_ee_pos = torch.tensor(ee_values)
|
||||||
|
|
||||||
# Convert to relative state if enabled (UMI-style)
|
# Convert to relative state if enabled (UMI-style)
|
||||||
if use_relative_state:
|
if use_relative_state:
|
||||||
|
|||||||
Reference in New Issue
Block a user