From a9cf770b995638d32c70d4a88cb65c2050cf7bd8 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 8 Jan 2026 17:02:00 +0100 Subject: [PATCH] fix abolute --- examples/openarms/evaluate_ee.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/openarms/evaluate_ee.py b/examples/openarms/evaluate_ee.py index 50be357dd..39505e6be 100644 --- a/examples/openarms/evaluate_ee.py +++ b/examples/openarms/evaluate_ee.py @@ -335,24 +335,31 @@ def run_ee_inference_loop( batch[key] = batch[key].unsqueeze(0) with torch.inference_mode(): - action = policy.select_action(batch) + action_tensor = policy.select_action(batch) - # 5. Postprocess to get EE action - ee_action = postprocessor(action) + # 5. Postprocess and convert tensor to dict + action_tensor = postprocessor(action_tensor) + + # Flatten to 1D: take first timestep if chunks, squeeze batch dims + while action_tensor.dim() > 1: + action_tensor = action_tensor[0] + + # Convert tensor to dict using action names from dataset + action_names = dataset.features[ACTION]["names"] + ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)} # 6. Convert relative action back to absolute if needed if use_relative_actions: - # Convert dict to tensor for relative->absolute conversion - action_keys = sorted([k for k in ee_action.keys() if "ee." in k or k.endswith(".pos")]) - action_tensor = torch.tensor([ee_action.get(k, 0.0) for k in action_keys]) + action_keys = sorted(ee_action.keys()) + action_vals = torch.tensor([ee_action[k] for k in action_keys]) # Unnormalize if we have a normalizer if relative_normalizer is not None: - action_tensor = relative_normalizer.unnormalize(action_tensor.unsqueeze(0).unsqueeze(0)) - action_tensor = action_tensor.squeeze(0).squeeze(0) + action_vals = relative_normalizer.unnormalize(action_vals.unsqueeze(0).unsqueeze(0)) + action_vals = action_vals.squeeze(0).squeeze(0) # Convert from relative to absolute - absolute_action = convert_from_relative_actions(action_tensor.unsqueeze(0), current_ee_pos) + absolute_action = convert_from_relative_actions(action_vals.unsqueeze(0), current_ee_pos) # Convert back to dict ee_action = {k: float(absolute_action[0, i]) for i, k in enumerate(action_keys)}