diff --git a/examples/openarms/evaluate_relative.py b/examples/openarms/evaluate_relative.py index 7d89ab49c..3c708c0e8 100644 --- a/examples/openarms/evaluate_relative.py +++ b/examples/openarms/evaluate_relative.py @@ -26,8 +26,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.policies.factory import make_policy, make_pre_post_processors +from lerobot.policies.utils import make_robot_action as tensor_to_robot_action from lerobot.processor import make_default_processors -from lerobot.processor.core import RobotAction from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig from lerobot.robots.openarms.openarms_follower import OpenArmsFollower from lerobot.utils.constants import ACTION, OBS_STR @@ -62,16 +62,6 @@ CAMERA_CONFIG = { } -def make_robot_action(action_values: dict, features: dict) -> RobotAction: - robot_action = {} - for key in features: - if key.startswith(ACTION + "."): - action_key = key.removeprefix(ACTION + ".") - if action_key in action_values: - robot_action[action_key] = action_values[action_key] - return robot_action - - def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool]: """Load normalizer and relative_state setting from checkpoint.""" model_path = Path(model_path) if isinstance(model_path, str) else model_path @@ -146,8 +136,8 @@ def inference_loop_relative( if isinstance(state_tensor, torch.Tensor): observation_frame[state_key] = convert_state_to_relative(state_tensor) - # Policy inference (outputs normalized relative actions) - action_values = predict_action( + # Policy inference (outputs action tensor) + action_tensor = predict_action( observation=observation_frame, policy=policy, device=device, @@ -158,17 +148,20 @@ def inference_loop_relative( robot_type=robot.robot_type, ) - # Unnormalize actions + # Unnormalize relative actions if normalizer exists if relative_normalizer is not None: - action_keys = [k for k in action_values.keys() if not k.startswith("task")] - action_tensor = torch.tensor([[action_values[k] for k in action_keys]]) - action_tensor = action_tensor.unsqueeze(1) - action_unnorm = relative_normalizer.unnormalize(action_tensor) - for i, k in enumerate(action_keys): - action_values[k] = action_unnorm[0, 0, i].item() + # action_tensor shape: [1, action_dim] or [action_dim] + if action_tensor.dim() == 1: + action_tensor = action_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, action_dim] + elif action_tensor.dim() == 2: + action_tensor = action_tensor.unsqueeze(1) # [batch, 1, action_dim] + action_tensor = relative_normalizer.unnormalize(action_tensor) + action_tensor = action_tensor.squeeze(1) # back to [batch, action_dim] - # Convert to absolute - relative_action = make_robot_action(action_values, dataset.features) + # Convert tensor to dict + relative_action = tensor_to_robot_action(action_tensor, dataset.features) + + # Convert relative to absolute absolute_action = convert_from_relative_actions_dict(relative_action, current_pos) robot.send_action(absolute_action)