fix(robocasa): align env state/action order to openpi/robocasa convention

LeRobot's RoboCasaEnv used a divergent flat state/action layout vs the
robocasa package (robocasa.utils.env_utils.convert_action) and the openpi
robocasa pipeline. This scrambles I/O when using openpi-convention checkpoints
(e.g. the JAX->PyTorch->LeRobot converted pi05 robocasa model: CloseFridge
20% -> 60% once both orders match openpi).

- convert_action: ee_pos(3)+ee_rot(3)+gripper(1)+base_motion(4)+control_mode(1)
- observation.state: ee_pos_rel(3)+ee_rot_rel(4)+base_pos(3)+base_rot(4)+gripper(2)

Matches openpi examples/robocasa/main.py + RobocasaInputs ordering.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-05 13:47:43 +02:00
parent de7ba67556
commit aca02ff24c
+13 -10
View File
@@ -33,8 +33,8 @@ logger = logging.getLogger(__name__)
# Dimensions for the flat action/state vectors used by the LeRobot wrapper.
# These correspond to the PandaOmron robot in RoboCasa365.
OBS_STATE_DIM = 16 # base_pos(3) + base_quat(4) + ee_pos_rel(3) + ee_quat_rel(4) + gripper_qpos(2)
ACTION_DIM = 12 # base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
OBS_STATE_DIM = 16 # ee_pos_rel(3) + ee_quat_rel(4) + base_pos(3) + base_quat(4) + gripper_qpos(2)
ACTION_DIM = 12 # ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
@@ -101,14 +101,15 @@ def _resolve_tasks(task: str) -> tuple[list[str], str | None]:
def convert_action(flat_action: np.ndarray) -> dict[str, Any]:
"""Split a flat (12,) action vector into a RoboCasa action dict.
Layout: base_motion(4) + control_mode(1) + ee_pos(3) + ee_rot(3) + gripper(1)
Layout (openpi / robocasa.utils.env_utils.convert_action order):
ee_pos(3) + ee_rot(3) + gripper(1) + base_motion(4) + control_mode(1)
"""
return {
"action.base_motion": flat_action[0:4],
"action.control_mode": flat_action[4:5],
"action.end_effector_position": flat_action[5:8],
"action.end_effector_rotation": flat_action[8:11],
"action.gripper_close": flat_action[11:12],
"action.end_effector_position": flat_action[0:3],
"action.end_effector_rotation": flat_action[3:6],
"action.gripper_close": flat_action[6:7],
"action.base_motion": flat_action[7:11],
"action.control_mode": flat_action[11:12],
}
@@ -230,12 +231,14 @@ class RoboCasaEnv(gym.Env):
return {"pixels": images}
# `state.*` keys come from PandaOmronKeyConverter inside the wrapper.
# openpi state order: ee first, then base, then gripper (matches the
# openpi robocasa pipeline / examples/robocasa/main.py state layout).
agent_pos = np.concatenate(
[
raw_obs.get("state.base_position", np.zeros(3)),
raw_obs.get("state.base_rotation", np.zeros(4)),
raw_obs.get("state.end_effector_position_relative", np.zeros(3)),
raw_obs.get("state.end_effector_rotation_relative", np.zeros(4)),
raw_obs.get("state.base_position", np.zeros(3)),
raw_obs.get("state.base_rotation", np.zeros(4)),
raw_obs.get("state.gripper_qpos", np.zeros(2)),
],
axis=-1,