mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
fix abolute
This commit is contained in:
@@ -335,24 +335,31 @@ def run_ee_inference_loop(
|
||||
batch[key] = batch[key].unsqueeze(0)
|
||||
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(batch)
|
||||
action_tensor = policy.select_action(batch)
|
||||
|
||||
# 5. Postprocess to get EE action
|
||||
ee_action = postprocessor(action)
|
||||
# 5. Postprocess and convert tensor to dict
|
||||
action_tensor = postprocessor(action_tensor)
|
||||
|
||||
# Flatten to 1D: take first timestep if chunks, squeeze batch dims
|
||||
while action_tensor.dim() > 1:
|
||||
action_tensor = action_tensor[0]
|
||||
|
||||
# Convert tensor to dict using action names from dataset
|
||||
action_names = dataset.features[ACTION]["names"]
|
||||
ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)}
|
||||
|
||||
# 6. Convert relative action back to absolute if needed
|
||||
if use_relative_actions:
|
||||
# Convert dict to tensor for relative->absolute conversion
|
||||
action_keys = sorted([k for k in ee_action.keys() if "ee." in k or k.endswith(".pos")])
|
||||
action_tensor = torch.tensor([ee_action.get(k, 0.0) for k in action_keys])
|
||||
action_keys = sorted(ee_action.keys())
|
||||
action_vals = torch.tensor([ee_action[k] for k in action_keys])
|
||||
|
||||
# Unnormalize if we have a normalizer
|
||||
if relative_normalizer is not None:
|
||||
action_tensor = relative_normalizer.unnormalize(action_tensor.unsqueeze(0).unsqueeze(0))
|
||||
action_tensor = action_tensor.squeeze(0).squeeze(0)
|
||||
action_vals = relative_normalizer.unnormalize(action_vals.unsqueeze(0).unsqueeze(0))
|
||||
action_vals = action_vals.squeeze(0).squeeze(0)
|
||||
|
||||
# Convert from relative to absolute
|
||||
absolute_action = convert_from_relative_actions(action_tensor.unsqueeze(0), current_ee_pos)
|
||||
absolute_action = convert_from_relative_actions(action_vals.unsqueeze(0), current_ee_pos)
|
||||
|
||||
# Convert back to dict
|
||||
ee_action = {k: float(absolute_action[0, i]) for i, k in enumerate(action_keys)}
|
||||
|
||||
Reference in New Issue
Block a user