fix abolute

This commit is contained in:
Pepijn
2026-01-08 17:02:00 +01:00
parent 037747da82
commit a9cf770b99
+16 -9
View File
@@ -335,24 +335,31 @@ def run_ee_inference_loop(
batch[key] = batch[key].unsqueeze(0)
with torch.inference_mode():
action = policy.select_action(batch)
action_tensor = policy.select_action(batch)
# 5. Postprocess to get EE action
ee_action = postprocessor(action)
# 5. Postprocess and convert tensor to dict
action_tensor = postprocessor(action_tensor)
# Flatten to 1D: take first timestep if chunks, squeeze batch dims
while action_tensor.dim() > 1:
action_tensor = action_tensor[0]
# Convert tensor to dict using action names from dataset
action_names = dataset.features[ACTION]["names"]
ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)}
# 6. Convert relative action back to absolute if needed
if use_relative_actions:
# Convert dict to tensor for relative->absolute conversion
action_keys = sorted([k for k in ee_action.keys() if "ee." in k or k.endswith(".pos")])
action_tensor = torch.tensor([ee_action.get(k, 0.0) for k in action_keys])
action_keys = sorted(ee_action.keys())
action_vals = torch.tensor([ee_action[k] for k in action_keys])
# Unnormalize if we have a normalizer
if relative_normalizer is not None:
action_tensor = relative_normalizer.unnormalize(action_tensor.unsqueeze(0).unsqueeze(0))
action_tensor = action_tensor.squeeze(0).squeeze(0)
action_vals = relative_normalizer.unnormalize(action_vals.unsqueeze(0).unsqueeze(0))
action_vals = action_vals.squeeze(0).squeeze(0)
# Convert from relative to absolute
absolute_action = convert_from_relative_actions(action_tensor.unsqueeze(0), current_ee_pos)
absolute_action = convert_from_relative_actions(action_vals.unsqueeze(0), current_ee_pos)
# Convert back to dict
ee_action = {k: float(absolute_action[0, i]) for i, k in enumerate(action_keys)}