fix return action

This commit is contained in:
Pepijn
2026-01-08 14:39:23 +01:00
parent cafb956e15
commit 84f06a86af
+15 -22
View File
@@ -26,8 +26,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.utils import make_robot_action as tensor_to_robot_action
from lerobot.processor import make_default_processors
from lerobot.processor.core import RobotAction
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.utils.constants import ACTION, OBS_STR
@@ -62,16 +62,6 @@ CAMERA_CONFIG = {
}
def make_robot_action(action_values: dict, features: dict) -> RobotAction:
robot_action = {}
for key in features:
if key.startswith(ACTION + "."):
action_key = key.removeprefix(ACTION + ".")
if action_key in action_values:
robot_action[action_key] = action_values[action_key]
return robot_action
def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool]:
"""Load normalizer and relative_state setting from checkpoint."""
model_path = Path(model_path) if isinstance(model_path, str) else model_path
@@ -146,8 +136,8 @@ def inference_loop_relative(
if isinstance(state_tensor, torch.Tensor):
observation_frame[state_key] = convert_state_to_relative(state_tensor)
# Policy inference (outputs normalized relative actions)
action_values = predict_action(
# Policy inference (outputs action tensor)
action_tensor = predict_action(
observation=observation_frame,
policy=policy,
device=device,
@@ -158,17 +148,20 @@ def inference_loop_relative(
robot_type=robot.robot_type,
)
# Unnormalize actions
# Unnormalize relative actions if normalizer exists
if relative_normalizer is not None:
action_keys = [k for k in action_values.keys() if not k.startswith("task")]
action_tensor = torch.tensor([[action_values[k] for k in action_keys]])
action_tensor = action_tensor.unsqueeze(1)
action_unnorm = relative_normalizer.unnormalize(action_tensor)
for i, k in enumerate(action_keys):
action_values[k] = action_unnorm[0, 0, i].item()
# action_tensor shape: [1, action_dim] or [action_dim]
if action_tensor.dim() == 1:
action_tensor = action_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, action_dim]
elif action_tensor.dim() == 2:
action_tensor = action_tensor.unsqueeze(1) # [batch, 1, action_dim]
action_tensor = relative_normalizer.unnormalize(action_tensor)
action_tensor = action_tensor.squeeze(1) # back to [batch, action_dim]
# Convert to absolute
relative_action = make_robot_action(action_values, dataset.features)
# Convert tensor to dict
relative_action = tensor_to_robot_action(action_tensor, dataset.features)
# Convert relative to absolute
absolute_action = convert_from_relative_actions_dict(relative_action, current_pos)
robot.send_action(absolute_action)