use observation instead of obs

This commit is contained in:
Pepijn
2025-09-12 12:18:00 +02:00
parent 2005a28a00
commit 6bdcd460e0
@@ -435,22 +435,23 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
kinematics: RobotKinematics
motor_names: list[str]
def observation(self, obs: dict) -> dict:
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
def observation(self, observation: dict) -> dict:
print("observation in step", observation)
if not all(f"{n}.pos" in observation for n in self.motor_names):
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
q = np.array([observation[f"{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
return obs
observation[f"{OBS_STATE}.ee.x"] = float(pos[0])
observation[f"{OBS_STATE}.ee.y"] = float(pos[1])
observation[f"{OBS_STATE}.ee.z"] = float(pos[2])
observation[f"{OBS_STATE}.ee.wx"] = float(tw[0])
observation[f"{OBS_STATE}.ee.wy"] = float(tw[1])
observation[f"{OBS_STATE}.ee.wz"] = float(tw[2])
return observation
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]