diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 98e1813c4..683c5ff0e 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -24,6 +24,11 @@ OBS_IMAGES = "observation.images" OBS_LANGUAGE = "observation.language" ACTION = "action" REWARD = "next.reward" +TRUNCATED = "next.truncated" +DONE = "next.done" + +OBS_LANGUAGE_TOKENS = "observation.language.tokens" +OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask" ROBOTS = "robots" ROBOT_TYPE = "robot_type" diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index fef75b407..c22c6264f 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -15,6 +15,7 @@ from collections.abc import Sequence from typing import Any +from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor.pipeline import RobotProcessor @@ -59,26 +60,26 @@ def aggregate_pipeline_dataset_features( # Go over every feature from the pipeline and merge: for full_key, ty in all_features.items(): - if full_key.startswith("action."): + if full_key.startswith(f"{ACTION}."): # action. if not keep(full_key): continue - name = full_key[len("action.") :] - hw.setdefault("action", {})[name] = ty + name = full_key[len(f"{ACTION}.") :] + hw.setdefault(ACTION, {})[name] = ty - elif full_key.startswith("observation.state."): + elif full_key.startswith(f"{OBS_STATE}."): # observation.state. if not keep(full_key): continue - name = full_key[len("observation.state.") :] + name = full_key[len(f"{OBS_STATE}.") :] hw.setdefault("observation", {})[name] = ty - elif full_key.startswith("observation.images."): + elif full_key.startswith(f"{OBS_IMAGES}."): # observation.images. # images obey ONLY the use_videos flag, not patterns if not use_videos: continue - name = full_key[len("observation.images.") :] + name = full_key[len(f"{OBS_IMAGES}.") :] hw.setdefault("observation", {})[name] = ty else: @@ -86,8 +87,8 @@ def aggregate_pipeline_dataset_features( continue out: dict[str, dict] = {} - if "action" in hw: - out.update(hw_to_dataset_features(hw["action"], "action", use_videos)) + if ACTION in hw: + out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos)) if "observation" in hw: out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos)) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 3a8f8b109..c9218a650 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -24,6 +24,8 @@ import numpy as np import torch from scipy.spatial.transform import Rotation +from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED + from .pipeline import EnvTransition, TransitionKey @@ -82,11 +84,11 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition: for k, v in action.items(): # Check if the value is a type that should not be converted to a tensor. if isinstance(v, (Rotation, dict)): - act_dict[f"action.{k}"] = v + act_dict[f"{ACTION}.{k}"] = v continue arr = np.array(v) if np.isscalar(v) else v - act_dict[f"action.{k}"] = _to_tensor(arr) + act_dict[f"{ACTION}.{k}"] = _to_tensor(arr) return make_obs_act_transition(act=act_dict) @@ -101,10 +103,10 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio obs_dict: dict[str, Any] = {} for k, v in state.items(): arr = np.array(v) if np.isscalar(v) else v - obs_dict[f"observation.state.{k}"] = _to_tensor(arr) + obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr) for cam, img in images.items(): - obs_dict[f"observation.images.{cam}"] = img + obs_dict[f"{OBS_IMAGES}.{cam}"] = img return make_obs_act_transition(obs=obs_dict) @@ -120,8 +122,8 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]: return out for k, v in action_dict.items(): - if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")): - out_key = k[len("action.") :] # Strip the 'action.' prefix. + if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")): + out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix. out[out_key] = float(v) return out @@ -152,9 +154,9 @@ def to_dataset_frame( - info dict - *_is_pad flags and task from complementary_data """ - action_names = features.get("action", {}).get("names", []) - obs_state_names = features.get("observation.state", {}).get("names", []) - image_keys = [k for k in features if k.startswith("observation.images.")] + action_names = features.get(ACTION, {}).get("names", []) + obs_state_names = features.get(OBS_STATE, {}).get("names", []) + image_keys = [k for k in features if k.startswith(OBS_IMAGES)] def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition: out = deepcopy(base) @@ -198,21 +200,20 @@ def to_dataset_frame( # Observation.state vector if obs_state_names: - vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names] - batch["observation.state"] = np.asarray(vals, dtype=np.float32) + vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names] + batch[OBS_STATE] = np.asarray(vals, dtype=np.float32) # Action vector if action_names: - vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names] - batch["action"] = np.asarray(vals, dtype=np.float32) + vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names] + batch[ACTION] = np.asarray(vals, dtype=np.float32) - # Next.* fields if tr.get(TransitionKey.REWARD) is not None: - batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD]) + batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD]) if tr.get(TransitionKey.DONE) is not None: - batch["next.done"] = _from_tensor(tr[TransitionKey.DONE]) + batch[DONE] = _from_tensor(tr[TransitionKey.DONE]) if tr.get(TransitionKey.TRUNCATED) is not None: - batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED]) + batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED]) # Complementary data flags and task comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index c1f4569ed..c75e40fff 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -8,6 +8,7 @@ import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.configs.types import PolicyFeature +from lerobot.constants import ACTION from lerobot.processor.pipeline import ( ComplementaryDataProcessor, EnvTransition, @@ -22,6 +23,8 @@ from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents GRIPPER_KEY = "gripper" +DISCRETE_PENALTY_KEY = "discrete_penalty" +TELEOP_ACTION_KEY = "teleop_action" @ProcessorStepRegistry.register("add_teleop_action_as_complementary_data") @@ -33,7 +36,7 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor): def complementary_data(self, complementary_data: dict) -> dict: new_complementary_data = dict(complementary_data) - new_complementary_data["teleop_action"] = self.teleop_device.get_action() + new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() return new_complementary_data @@ -141,7 +144,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor): if current_gripper_pos is None: return complementary_data - gripper_action = action[f"action.{GRIPPER_KEY}.pos"] + gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"] gripper_action_normalized = gripper_action / self.max_gripper_pos # Normalize gripper state and action @@ -156,7 +159,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor): # Create new complementary data with penalty info new_complementary_data = dict(complementary_data) - new_complementary_data["discrete_penalty"] = gripper_penalty + new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty return new_complementary_data @@ -187,7 +190,7 @@ class InterventionActionProcessor(ProcessorStep): # Get intervention signals from complementary data info = transition.get(TransitionKey.INFO, {}) complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - teleop_action = complementary_data.get("teleop_action", {}) + teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {}) is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) success = info.get(TeleopEvents.SUCCESS, False) @@ -200,12 +203,12 @@ class InterventionActionProcessor(ProcessorStep): if isinstance(teleop_action, dict): # Convert teleop_action dict to tensor format action_list = [ - teleop_action.get("action.delta_x", 0.0), - teleop_action.get("action.delta_y", 0.0), - teleop_action.get("action.delta_z", 0.0), + teleop_action.get(f"{ACTION}.delta_x", 0.0), + teleop_action.get(f"{ACTION}.delta_y", 0.0), + teleop_action.get(f"{ACTION}.delta_z", 0.0), ] if self.use_gripper: - action_list.append(teleop_action.get("gripper", 1.0)) + action_list.append(teleop_action.get(GRIPPER_KEY, 1.0)) elif isinstance(teleop_action, np.ndarray): action_list = teleop_action.tolist() else: @@ -229,7 +232,7 @@ class InterventionActionProcessor(ProcessorStep): # Update complementary data with teleop action complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) - complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION) + complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION) new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data return new_transition diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index d2c04e44c..6a6698a38 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.constants import OBS_LANGUAGE +from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS from lerobot.processor.pipeline import ( EnvTransition, ObservationProcessor, @@ -156,10 +156,8 @@ class TokenizerProcessor(ObservationProcessor): new_observation = dict(observation) # Add tokenized data to observation - new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] - new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to( - dtype=torch.bool - ) + new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] + new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) return new_observation @@ -239,13 +237,13 @@ class TokenizerProcessor(ObservationProcessor): """ # Add features for tokenized output if they don't exist # Standard tokenizer output includes tokens and attention_mask - tokens_key = f"{OBS_LANGUAGE}.tokens" - attention_mask_key = f"{OBS_LANGUAGE}.attention_mask" - if tokens_key not in features: - features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) + if OBS_LANGUAGE_TOKENS not in features: + features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) - if attention_mask_key not in features: - features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) + if OBS_LANGUAGE_ATTENTION_MASK not in features: + features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) return features diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 7c6c73a4d..39bab604f 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -20,6 +20,7 @@ import numpy as np from scipy.spatial.transform import Rotation from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.model.kinematics import RobotKinematics from lerobot.processor.pipeline import ( ActionProcessor, @@ -81,13 +82,13 @@ class EEReferenceAndDelta(ActionProcessor): # Current pose from FK on measured joints t_curr = self.kinematics.forward_kinematics(q) - enabled = bool(new_action.pop("action.enabled", 0)) - tx = float(new_action.pop("action.target_x", 0.0)) - ty = float(new_action.pop("action.target_y", 0.0)) - tz = float(new_action.pop("action.target_z", 0.0)) - wx = float(new_action.pop("action.target_wx", 0.0)) - wy = float(new_action.pop("action.target_wy", 0.0)) - wz = float(new_action.pop("action.target_wz", 0.0)) + enabled = bool(new_action.pop(f"{ACTION}.enabled", 0)) + tx = float(new_action.pop(f"{ACTION}.target_x", 0.0)) + ty = float(new_action.pop(f"{ACTION}.target_y", 0.0)) + tz = float(new_action.pop(f"{ACTION}.target_z", 0.0)) + wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0)) + wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0)) + wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0)) desired = None @@ -123,12 +124,12 @@ class EEReferenceAndDelta(ActionProcessor): # Write action fields pos = desired[:3, 3] tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() - new_action["action.ee.x"] = float(pos[0]) - new_action["action.ee.y"] = float(pos[1]) - new_action["action.ee.z"] = float(pos[2]) - new_action["action.ee.wx"] = float(tw[0]) - new_action["action.ee.wy"] = float(tw[1]) - new_action["action.ee.wz"] = float(tw[2]) + new_action[f"{ACTION}.ee.x"] = float(pos[0]) + new_action[f"{ACTION}.ee.y"] = float(pos[1]) + new_action[f"{ACTION}.ee.z"] = float(pos[2]) + new_action[f"{ACTION}.ee.wx"] = float(tw[0]) + new_action[f"{ACTION}.ee.wy"] = float(tw[1]) + new_action[f"{ACTION}.ee.wz"] = float(tw[2]) self._prev_enabled = enabled return new_action @@ -139,20 +140,20 @@ class EEReferenceAndDelta(ActionProcessor): self._command_when_disabled = None def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features.pop("action.enabled", None) - features.pop("action.target_x", None) - features.pop("action.target_y", None) - features.pop("action.target_z", None) - features.pop("action.target_wx", None) - features.pop("action.target_wy", None) - features.pop("action.target_wz", None) + features.pop(f"{ACTION}.enabled", None) + features.pop(f"{ACTION}.target_x", None) + features.pop(f"{ACTION}.target_y", None) + features.pop(f"{ACTION}.target_z", None) + features.pop(f"{ACTION}.target_wx", None) + features.pop(f"{ACTION}.target_wy", None) + features.pop(f"{ACTION}.target_wz", None) - features["action.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -180,12 +181,12 @@ class EEBoundsAndSafety(ActionProcessor): _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) def action(self, act: dict) -> dict: - x = act.get("action.ee.x", None) - y = act.get("action.ee.y", None) - z = act.get("action.ee.z", None) - wx = act.get("action.ee.wx", None) - wy = act.get("action.ee.wy", None) - wz = act.get("action.ee.wz", None) + x = act.get(f"{ACTION}.ee.x", None) + y = act.get(f"{ACTION}.ee.y", None) + z = act.get(f"{ACTION}.ee.z", None) + wx = act.get(f"{ACTION}.ee.wx", None) + wy = act.get(f"{ACTION}.ee.wy", None) + wz = act.get(f"{ACTION}.ee.wz", None) if None in (x, y, z, wx, wy, wz): return act @@ -207,12 +208,12 @@ class EEBoundsAndSafety(ActionProcessor): self._last_pos = pos self._last_twist = twist - act["action.ee.x"] = float(pos[0]) - act["action.ee.y"] = float(pos[1]) - act["action.ee.z"] = float(pos[2]) - act["action.ee.wx"] = float(twist[0]) - act["action.ee.wy"] = float(twist[1]) - act["action.ee.wz"] = float(twist[2]) + act[f"{ACTION}.ee.x"] = float(pos[0]) + act[f"{ACTION}.ee.y"] = float(pos[1]) + act[f"{ACTION}.ee.z"] = float(pos[2]) + act[f"{ACTION}.ee.wx"] = float(twist[0]) + act[f"{ACTION}.ee.wy"] = float(twist[1]) + act[f"{ACTION}.ee.wz"] = float(twist[2]) return act def reset(self): @@ -250,12 +251,12 @@ class InverseKinematicsEEToJoints(ProcessorStep): act = transition.get(TransitionKey.ACTION) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} - x = act.get("action.ee.x", None) - y = act.get("action.ee.y", None) - z = act.get("action.ee.z", None) - wx = act.get("action.ee.wx", None) - wy = act.get("action.ee.wy", None) - wz = act.get("action.ee.wz", None) + x = act.get(f"{ACTION}.ee.x", None) + y = act.get(f"{ACTION}.ee.y", None) + z = act.get(f"{ACTION}.ee.z", None) + wx = act.get(f"{ACTION}.ee.wx", None) + wy = act.get(f"{ACTION}.ee.wy", None) + wz = act.get(f"{ACTION}.ee.wz", None) if None in (x, y, z, wx, wy, wz): return transition @@ -285,19 +286,19 @@ class InverseKinematicsEEToJoints(ProcessorStep): new_act = dict(act) for i, name in enumerate(self.motor_names): if name == "gripper": - new_act["observation.state.gripper.pos"] = float(raw["gripper"]) + new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"]) else: - new_act[f"action.{name}.pos"] = float(q_target[i]) + new_act[f"{ACTION}.{name}.pos"] = float(q_target[i]) transition[TransitionKey.ACTION] = new_act if not self.initial_guess_current_joints: transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target return transition def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features["observation.state.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) - features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) for name in self.motor_names: - features[f"action.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -333,12 +334,12 @@ class GripperVelocityToJoint(ProcessorStep): act = transition.get(TransitionKey.ACTION) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} - if "action.gripper" not in act: + if f"{ACTION}.gripper" not in act: return transition if "gripper" not in self.motor_names: new_act = dict(act) - new_act.pop("action.gripper", None) + new_act.pop(f"{ACTION}.gripper", None) transition[TransitionKey.ACTION] = new_act return transition @@ -346,32 +347,32 @@ class GripperVelocityToJoint(ProcessorStep): # Discrete gripper actions are in [0, 1, 2] # 0: open, 1: close, 2: stay # We need to shift them to [-1, 0, 1] and then scale them to clip_max - gripper_action = act.get("action.gripper", 1.0) + gripper_action = act.get(f"{ACTION}.gripper", 1.0) gripper_action = gripper_action - 1.0 gripper_action *= self.clip_max - act["action.gripper"] = gripper_action + act[f"{ACTION}.gripper"] = gripper_action # Get current gripper position from complementary data raw = comp.get("raw_joint_positions") or {} curr_pos = float(raw.get("gripper")) # Compute desired gripper velocity - u = float(act.get("action.gripper", 0.0)) + u = float(act.get(f"{ACTION}.gripper", 0.0)) delta = u * float(self.speed_factor) gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max)) new_act = dict(act) - new_act["action.gripper.pos"] = gripper_pos - new_act.pop("action.gripper", None) + new_act[f"{ACTION}.gripper.pos"] = gripper_pos + new_act.pop(f"{ACTION}.gripper", None) transition[TransitionKey.ACTION] = new_act - obs["observation.state.gripper.pos"] = curr_pos + obs[f"{OBS_STATE}.gripper.pos"] = curr_pos transition[TransitionKey.OBSERVATION] = obs return transition def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features.pop("action.gripper", None) - features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features.pop(f"{ACTION}.gripper", None) + features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features @@ -396,26 +397,26 @@ class ForwardKinematicsJointsToEE(ObservationProcessor): motor_names: list[str] def observation(self, obs: dict) -> dict: - if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names): + if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names): return obs - q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float) + q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float) t = self.kinematics.forward_kinematics(q) pos = t[:3, 3] tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() - obs["observation.state.ee.x"] = float(pos[0]) - obs["observation.state.ee.y"] = float(pos[1]) - obs["observation.state.ee.z"] = float(pos[2]) - obs["observation.state.ee.wx"] = float(tw[0]) - obs["observation.state.ee.wy"] = float(tw[1]) - obs["observation.state.ee.wz"] = float(tw[2]) + obs[f"{OBS_STATE}.ee.x"] = float(pos[0]) + obs[f"{OBS_STATE}.ee.y"] = float(pos[1]) + obs[f"{OBS_STATE}.ee.z"] = float(pos[2]) + obs[f"{OBS_STATE}.ee.wx"] = float(tw[0]) + obs[f"{OBS_STATE}.ee.wy"] = float(tw[1]) + obs[f"{OBS_STATE}.ee.wz"] = float(tw[2]) return obs def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: # We specify the dataset features of this step that we want to be stored in the dataset for k in ["x", "y", "z", "wx", "wy", "wz"]: - features[f"observation.state.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) + features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) return features