feat(umi): simplify to derive_state_from_action and cam0-only

- Remove fix_dataset.py (user fixes dataset at source)
- evaluate.py: replace observation.pose/joints with observation.state
  (8D, derived from action during training, from FK at inference)
- evaluate.py: remove cam1 — training uses only cam0
- docs: rewrite workflow around derive_state_from_action=true,
  updated recompute-stats and training commands with
  relative_exclude_joints for gripper dims

Made-with: Cursor
This commit is contained in:
Pepijn
2026-04-02 15:02:20 +02:00
parent e627d6442e
commit 8455efc474
2 changed files with 303 additions and 212 deletions
+179 -116
View File
@@ -15,29 +15,32 @@
# limitations under the License.
"""
Inference script for a pi0 model trained with **relative EE actions** on an OpenArm robot.
Inference script for a pi0 model trained with UMI-style relative EE actions
on an OpenArm robot (single right arm, one wrist camera).
Single right OpenArm follower with one wrist camera.
Training dataset layout:
observation.images.cam0 [3, 720, 960]
action [x, y, z, ax, ay, az, proximal, distal] (shape 8)
This uses the built-in ``DeriveStateFromActionStep`` (no-op at inference),
``RelativeActionsProcessorStep``, ``AbsoluteActionsProcessorStep``, and
``RelativeStateProcessorStep`` that are already wired into pi0's processor
pipeline.
The model uses ``derive_state_from_action=true``, so observation.state is
derived from the action column during training. At inference the state must
be provided by the robot — this script uses FK to compute the current EE
pose and gripper position, which it exposes as ``observation.state``.
The inference loop:
1. Reads joint positions from the robot (7-DOF arm + gripper).
2. Converts to EE pose via forward kinematics (FK).
This produces ``observation.state`` with the current EE pose.
3. The pi0 preprocessor:
a) ``DeriveStateFromActionStep`` — no-op (state comes from robot).
b) ``RelativeActionsProcessorStep`` caches the raw state.
c) ``RelativeStateProcessorStep`` buffers prev state, stacks, subtracts.
d) ``NormalizerProcessorStep`` normalizes state and actions.
4. pi0 predicts relative action chunk.
5. The pi0 postprocessor:
a) ``UnnormalizerProcessorStep`` unnormalizes.
b) ``AbsoluteActionsProcessorStep`` adds cached stateabsolute EE.
6. IK converts absolute EE → joint targets → robot.
Pipeline:
1. Read arm joints from robot → FK → observation.state [x,y,z,ax,ay,az,prox,dist]
2. Read camera image → observation.images.cam0
3. pi0 preprocessor (loaded from checkpoint):
- DeriveStateFromActionStep: no-op at inference (state from robot)
- RelativeActionsProcessorStep: caches current state
- RelativeStateProcessorStep: buffers prev state, stacks [prev,cur],
subtracts current → velocity info, flattens
- NormalizerProcessorStep: normalizes
4. pi0 predicts relative action chunk (30 steps)
5. pi0 postprocessor: unnormalize, add cached state → absolute EE
6. IK: absolute EE [x,y,z,ax,ay,az] → arm joint targets
7. Gripper [proximal, distal]gripper motor targets
8. Send to robot
Usage:
python evaluate.py
@@ -45,57 +48,177 @@ Usage:
from __future__ import annotations
import numpy as np
from scipy.spatial.transform import Rotation
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.processor import (
RelativeStateProcessorStep,
RobotProcessorPipeline,
make_default_teleop_action_processor,
)
from lerobot.processor.converters import (
observation_to_transition,
robot_action_observation_to_transition,
transition_to_observation,
transition_to_robot_action,
)
from lerobot.processor import RelativeStateProcessorStep
from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig
from lerobot.robots.so_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
FPS = 30
# ---------------------------------------------------------------------------
# Configuration — adapt these to your setup
# ---------------------------------------------------------------------------
FPS = 46
EPISODE_TIME_SEC = 60
TASK_DESCRIPTION = "red cube"
HF_MODEL_ID = "pepijn223/grabette-umi-pi0"
# Latency compensation: skip this many steps from the start of each predicted
# action chunk. Formula: ceil(total_latency_ms / (1000 / FPS)).
# E.g. at 10Hz with ~200ms total system latency: ceil(200 / 100) = 2.
# Latency compensation: skip this many predicted action steps to account for
# camera + inference + execution latency. Formula: ceil(total_ms / (1000/FPS)).
# At 46 FPS (~22ms/step) with ~150ms total latency: ceil(150/22) ≈ 7.
# Start with 0 for a safe first test, then increase to match measured latency.
LATENCY_SKIP_STEPS = 0
# EE feature keys produced by ForwardKinematicsJointsToEE (arm pose only).
# Gripper joints use absolute position control, not EE-relative.
EE_KEYS = ["x", "y", "z", "wx", "wy", "wz"]
URDF_PATH = "src/lerobot/robots/openarm_follower/urdf/openarm_bimanual_pybullet.urdf"
URDF_EE_FRAME = "openarm_right_link7"
URDF_EE_FRAME = "openarm_right_ee_target"
IK_POSITION_WEIGHT = 1.0
IK_ORIENTATION_WEIGHT = 1.0
# ---------------------------------------------------------------------------
# Dataset features for inference
#
# The training dataset has only observation.images.cam0 and action.
# observation.state is derived from action during training
# (derive_state_from_action=true) but must be supplied by the robot at
# inference. We define it here so build_dataset_frame can map FK output
# to the right feature.
# ---------------------------------------------------------------------------
DATASET_FEATURES: dict = {
"observation.state": {
"dtype": "float32",
"shape": [8],
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
},
"observation.images.cam0": {
"dtype": "video",
"shape": [3, 720, 960],
"names": ["channels", "height", "width"],
"info": {
"video.height": 720,
"video.width": 960,
"video.codec": "h264",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"video.fps": FPS,
"video.channels": 3,
"has_audio": False,
},
},
"action": {
"dtype": "float32",
"shape": [8],
"names": ["x", "y", "z", "ax", "ay", "az", "proximal", "distal"],
},
"timestamp": {"dtype": "float32", "shape": [1], "names": None},
"frame_index": {"dtype": "int64", "shape": [1], "names": None},
"episode_index": {"dtype": "int64", "shape": [1], "names": None},
"index": {"dtype": "int64", "shape": [1], "names": None},
"task_index": {"dtype": "int64", "shape": [1], "names": None},
}
# ---------------------------------------------------------------------------
# FK / IK callables
# ---------------------------------------------------------------------------
class JointsToEE:
"""FK: raw robot observation → flat dict matching observation.state names.
Arm joint positions → EE pose [x,y,z,ax,ay,az] via forward kinematics.
Gripper motor positions → [proximal, distal].
Camera images pass through unchanged.
"""
def __init__(self, kinematics: RobotKinematics, arm_motor_names: list[str]):
self.kin = kinematics
self.arm = arm_motor_names
def __call__(self, obs: RobotObservation) -> RobotObservation:
q = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
t = self.kin.forward_kinematics(q)
rot = Rotation.from_matrix(t[:3, :3]).as_rotvec()
out: dict = {
"x": float(t[0, 3]),
"y": float(t[1, 3]),
"z": float(t[2, 3]),
"ax": float(rot[0]),
"ay": float(rot[1]),
"az": float(rot[2]),
"proximal": float(obs["proximal.pos"]),
"distal": float(obs["distal.pos"]),
}
for k, v in obs.items():
if not k.endswith((".pos", ".vel", ".torque")):
out[k] = v
return out
class EEToJoints:
"""IK: policy action dict → motor position dict for the robot.
Reads [x,y,z,ax,ay,az] from the action, runs IK for arm joint targets.
Passes [proximal, distal] as direct gripper position commands.
"""
def __init__(
self,
kinematics: RobotKinematics,
arm_motor_names: list[str],
position_weight: float = 1.0,
orientation_weight: float = 1.0,
):
self.kin = kinematics
self.arm = arm_motor_names
self.pw = position_weight
self.ow = orientation_weight
self.q_curr: np.ndarray | None = None
def __call__(self, args: tuple[RobotAction, RobotObservation]) -> RobotAction:
action, obs = args
q_raw = np.array([float(obs[f"{m}.pos"]) for m in self.arm])
if self.q_curr is None:
self.q_curr = q_raw
t_des = np.eye(4)
t_des[:3, :3] = Rotation.from_rotvec([action["ax"], action["ay"], action["az"]]).as_matrix()
t_des[:3, 3] = [action["x"], action["y"], action["z"]]
q_target = self.kin.inverse_kinematics(
self.q_curr, t_des, position_weight=self.pw, orientation_weight=self.ow
)
self.q_curr = q_target
out: dict = {f"{m}.pos": float(q_target[i]) for i, m in enumerate(self.arm)}
out["proximal.pos"] = float(action["proximal"])
out["distal.pos"] = float(action["distal"])
return out
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
camera_config = {"cam0": OpenCVCameraConfig(index_or_path=0, width=960, height=720, fps=FPS)}
camera_config = {
"cam0": OpenCVCameraConfig(index_or_path=0, width=960, height=720, fps=FPS),
}
robot_config = OpenArmFollowerConfig(
port="can0",
id="right_openarm",
@@ -110,79 +233,20 @@ def main():
policy.config.latency_skip_steps = LATENCY_SKIP_STEPS
arm_motor_names = list(robot.bus.motors.keys())
gripper_motor_names = list(robot.gripper_bus.motors.keys())
kinematics_solver = RobotKinematics(
kinematics = RobotKinematics(
urdf_path=URDF_PATH,
target_frame_name=URDF_EE_FRAME,
joint_names=arm_motor_names,
)
# The policy starts from the robot's current EE pose (via FK below).
# Relative actions are predicted as deltas from that pose, so no manual
# re-centering is needed — the starting point is always the live EE tip.
fk = JointsToEE(kinematics, arm_motor_names)
ik = EEToJoints(kinematics, arm_motor_names, IK_POSITION_WEIGHT, IK_ORIENTATION_WEIGHT)
# FK: joint observation → EE observation (produces observation.state).
# gripper_names=[] means proximal/distal pass through as absolute positions.
robot_joints_to_ee_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
steps=[
ForwardKinematicsJointsToEE(
kinematics=kinematics_solver,
motor_names=arm_motor_names,
gripper_names=[],
)
],
to_transition=observation_to_transition,
to_output=transition_to_observation,
)
# IK: EE action → joint targets. Gripper actions are absolute and pass through.
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
steps=[
InverseKinematicsEEToJoints(
kinematics=kinematics_solver,
motor_names=arm_motor_names,
gripper_names=[],
initial_guess_current_joints=True,
),
],
to_transition=robot_action_observation_to_transition,
to_output=transition_to_robot_action,
)
# OpenArm observations include .vel and .torque per motor; the EE policy
# pipeline only needs .pos (converted to EE by FK) and camera features.
obs_features = {
k: v
for k, v in robot.observation_features.items()
if not (k.endswith(".vel") or k.endswith(".torque"))
}
# A dataset object is needed for its .features and .meta.stats even when
# not recording — record_loop uses them for building observation/action frames.
dataset = LeRobotDataset.create(
repo_id="tmp/openarm_eval_scratch",
fps=FPS,
features=combine_feature_dicts(
aggregate_pipeline_dataset_features(
pipeline=robot_joints_to_ee_processor,
initial_features=create_initial_features(observation=obs_features),
use_videos=True,
),
aggregate_pipeline_dataset_features(
pipeline=make_default_teleop_action_processor(),
initial_features=create_initial_features(
action={
**{f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) for k in EE_KEYS},
**{
f"{g}.pos": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
for g in gripper_motor_names
},
}
),
use_videos=True,
),
),
features=DATASET_FEATURES,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
@@ -195,7 +259,7 @@ def main():
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
)
_relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
relative_state_steps = [s for s in preprocessor.steps if isinstance(s, RelativeStateProcessorStep)]
robot.connect()
@@ -207,7 +271,7 @@ def main():
raise ValueError("Robot is not connected!")
log_say("Starting policy execution")
for step in _relative_state_steps:
for step in relative_state_steps:
step.reset()
record_loop(
@@ -221,9 +285,8 @@ def main():
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
teleop_action_processor=make_default_teleop_action_processor(),
robot_action_processor=robot_ee_to_joints_processor,
robot_observation_processor=robot_joints_to_ee_processor,
robot_action_processor=ik,
robot_observation_processor=fk,
)
finally:
robot.disconnect()