mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
fix return action
This commit is contained in:
@@ -26,8 +26,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.utils import make_robot_action as tensor_to_robot_action
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.processor.core import RobotAction
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -62,16 +62,6 @@ CAMERA_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
def make_robot_action(action_values: dict, features: dict) -> RobotAction:
|
||||
robot_action = {}
|
||||
for key in features:
|
||||
if key.startswith(ACTION + "."):
|
||||
action_key = key.removeprefix(ACTION + ".")
|
||||
if action_key in action_values:
|
||||
robot_action[action_key] = action_values[action_key]
|
||||
return robot_action
|
||||
|
||||
|
||||
def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool]:
|
||||
"""Load normalizer and relative_state setting from checkpoint."""
|
||||
model_path = Path(model_path) if isinstance(model_path, str) else model_path
|
||||
@@ -146,8 +136,8 @@ def inference_loop_relative(
|
||||
if isinstance(state_tensor, torch.Tensor):
|
||||
observation_frame[state_key] = convert_state_to_relative(state_tensor)
|
||||
|
||||
# Policy inference (outputs normalized relative actions)
|
||||
action_values = predict_action(
|
||||
# Policy inference (outputs action tensor)
|
||||
action_tensor = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=device,
|
||||
@@ -158,17 +148,20 @@ def inference_loop_relative(
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
# Unnormalize actions
|
||||
# Unnormalize relative actions if normalizer exists
|
||||
if relative_normalizer is not None:
|
||||
action_keys = [k for k in action_values.keys() if not k.startswith("task")]
|
||||
action_tensor = torch.tensor([[action_values[k] for k in action_keys]])
|
||||
action_tensor = action_tensor.unsqueeze(1)
|
||||
action_unnorm = relative_normalizer.unnormalize(action_tensor)
|
||||
for i, k in enumerate(action_keys):
|
||||
action_values[k] = action_unnorm[0, 0, i].item()
|
||||
# action_tensor shape: [1, action_dim] or [action_dim]
|
||||
if action_tensor.dim() == 1:
|
||||
action_tensor = action_tensor.unsqueeze(0).unsqueeze(0) # [1, 1, action_dim]
|
||||
elif action_tensor.dim() == 2:
|
||||
action_tensor = action_tensor.unsqueeze(1) # [batch, 1, action_dim]
|
||||
action_tensor = relative_normalizer.unnormalize(action_tensor)
|
||||
action_tensor = action_tensor.squeeze(1) # back to [batch, action_dim]
|
||||
|
||||
# Convert to absolute
|
||||
relative_action = make_robot_action(action_values, dataset.features)
|
||||
# Convert tensor to dict
|
||||
relative_action = tensor_to_robot_action(action_tensor, dataset.features)
|
||||
|
||||
# Convert relative to absolute
|
||||
absolute_action = convert_from_relative_actions_dict(relative_action, current_pos)
|
||||
|
||||
robot.send_action(absolute_action)
|
||||
|
||||
Reference in New Issue
Block a user