mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 03:37:12 +00:00
fix(robocasa): align env state/action order to openpi/robocasa convention
LeRobot's RoboCasaEnv used a divergent flat state/action layout vs the robocasa package (robocasa.utils.env_utils.convert_action) and the openpi robocasa pipeline. This scrambles I/O when using openpi-convention checkpoints (e.g. the JAX->PyTorch->LeRobot converted pi05 robocasa model: CloseFridge 20% -> 60% once both orders match openpi). - convert_action: ee_pos(3)+ee_rot(3)+gripper(1)+base_motion(4)+control_mode(1) - observation.state: ee_pos_rel(3)+ee_rot_rel(4)+base_pos(3)+base_rot(4)+gripper(2) Matches openpi examples/robocasa/main.py + RobocasaInputs ordering. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
|
||||
# These correspond to the PandaOmron robot in RoboCasa365.
|
||||
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
|
||||
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
|
||||
@@ -101,14 +101,15 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
|
||||
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
|
||||
"""Split a flat (12,) action vector into a RoboCasa action dict.
|
||||
|
||||
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
|
||||
Layout (openpi / robocasa.utils.env_utils.convert_action order):
|
||||
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
|
||||
"""
|
||||
return {
|
||||
"action.base_motion": flat_action[0:4],
|
||||
"action.control_mode": flat_action[4:5],
|
||||
"action.end_effector_position": flat_action[5:8],
|
||||
"action.end_effector_rotation": flat_action[8:11],
|
||||
"action.gripper_close": flat_action[11:12],
|
||||
"action.end_effector_position": flat_action[0:3],
|
||||
"action.end_effector_rotation": flat_action[3:6],
|
||||
"action.gripper_close": flat_action[6:7],
|
||||
"action.base_motion": flat_action[7:11],
|
||||
"action.control_mode": flat_action[11:12],
|
||||
}
|
||||
|
||||
|
||||
@@ -230,12 +231,14 @@ class RoboCasaEnv(gym.Env):
|
||||
return {"pixels": images}
|
||||
|
||||
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
|
||||
# openpi state order: ee first, then base, then gripper (matches the
|
||||
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
|
||||
agent_pos = np.concatenate(
|
||||
[
|
||||
raw_obs.get("state.base_position", np.zeros(3)),
|
||||
raw_obs.get("state.base_rotation", np.zeros(4)),
|
||||
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
|
||||
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
|
||||
raw_obs.get("state.base_position", np.zeros(3)),
|
||||
raw_obs.get("state.base_rotation", np.zeros(4)),
|
||||
raw_obs.get("state.gripper_qpos", np.zeros(2)),
|
||||
],
|
||||
axis=-1,
|
||||
|
||||
Reference in New Issue
Block a user