diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 67e9f9e1e..3874d711c 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -435,22 +435,23 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep): kinematics: RobotKinematics motor_names: list[str] - def observation(self, obs: dict) -> dict: - if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names): + def observation(self, observation: dict) -> dict: + print("observation in step", observation) + if not all(f"{n}.pos" in observation for n in self.motor_names): raise ValueError(f"Missing required joint positions for motors: {self.motor_names}") - q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float) + q = np.array([observation[f"{n}.pos"] for n in self.motor_names], dtype=float) t = self.kinematics.forward_kinematics(q) pos = t[:3, 3] tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() - obs[f"{OBS_STATE}.ee.x"] = float(pos[0]) - obs[f"{OBS_STATE}.ee.y"] = float(pos[1]) - obs[f"{OBS_STATE}.ee.z"] = float(pos[2]) - obs[f"{OBS_STATE}.ee.wx"] = float(tw[0]) - obs[f"{OBS_STATE}.ee.wy"] = float(tw[1]) - obs[f"{OBS_STATE}.ee.wz"] = float(tw[2]) - return obs + observation[f"{OBS_STATE}.ee.x"] = float(pos[0]) + observation[f"{OBS_STATE}.ee.y"] = float(pos[1]) + observation[f"{OBS_STATE}.ee.z"] = float(pos[2]) + observation[f"{OBS_STATE}.ee.wx"] = float(tw[0]) + observation[f"{OBS_STATE}.ee.wy"] = float(tw[1]) + observation[f"{OBS_STATE}.ee.wz"] = float(tw[2]) + return observation def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]