refactor(constants, processor): standardize action and observation keys across multiple files (#1808)

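- Added new constants to constants.py: TRUNCATED ("next.truncated") and DONE ("next.done") for the truncated and done states, plus OBS_LANGUAGE_TOKENS and OBS_LANGUAGE_ATTENTION_MASK for the tokenized-language observation keys.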
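- Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the shared constants from constants.py (both the existing ones such as ACTION and REWARD and the newly added ones) instead of hard-coded string literals, and introduced module-level key names (DISCRETE_PENALTY_KEY, TELEOP_ACTION_KEY) for the complementary-data entries, for improved readability and maintainability.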
This commit is contained in:
Adil Zouitine
2025-08-31 22:53:13 +02:00
committed by GitHub
parent 574a708950
commit 08fb310eaa
6 changed files with 123 additions and 114 deletions
+5
View File
@@ -24,6 +24,11 @@ OBS_IMAGES = "observation.images"
OBS_LANGUAGE = "observation.language" OBS_LANGUAGE = "observation.language"
ACTION = "action" ACTION = "action"
REWARD = "next.reward" REWARD = "next.reward"
TRUNCATED = "next.truncated"
DONE = "next.done"
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
ROBOTS = "robots" ROBOTS = "robots"
ROBOT_TYPE = "robot_type" ROBOT_TYPE = "robot_type"
+10 -9
View File
@@ -15,6 +15,7 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor.pipeline import RobotProcessor from lerobot.processor.pipeline import RobotProcessor
@@ -59,26 +60,26 @@ def aggregate_pipeline_dataset_features(
# Go over every feature from the pipeline and merge: # Go over every feature from the pipeline and merge:
for full_key, ty in all_features.items(): for full_key, ty in all_features.items():
if full_key.startswith("action."): if full_key.startswith(f"{ACTION}."):
# action.<feat> # action.<feat>
if not keep(full_key): if not keep(full_key):
continue continue
name = full_key[len("action.") :] name = full_key[len(f"{ACTION}.") :]
hw.setdefault("action", {})[name] = ty hw.setdefault(ACTION, {})[name] = ty
elif full_key.startswith("observation.state."): elif full_key.startswith(f"{OBS_STATE}."):
# observation.state.<feat> # observation.state.<feat>
if not keep(full_key): if not keep(full_key):
continue continue
name = full_key[len("observation.state.") :] name = full_key[len(f"{OBS_STATE}.") :]
hw.setdefault("observation", {})[name] = ty hw.setdefault("observation", {})[name] = ty
elif full_key.startswith("observation.images."): elif full_key.startswith(f"{OBS_IMAGES}."):
# observation.images.<cam> # observation.images.<cam>
# images obey ONLY the use_videos flag, not patterns # images obey ONLY the use_videos flag, not patterns
if not use_videos: if not use_videos:
continue continue
name = full_key[len("observation.images.") :] name = full_key[len(f"{OBS_IMAGES}.") :]
hw.setdefault("observation", {})[name] = ty hw.setdefault("observation", {})[name] = ty
else: else:
@@ -86,8 +87,8 @@ def aggregate_pipeline_dataset_features(
continue continue
out: dict[str, dict] = {} out: dict[str, dict] = {}
if "action" in hw: if ACTION in hw:
out.update(hw_to_dataset_features(hw["action"], "action", use_videos)) out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
if "observation" in hw: if "observation" in hw:
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos)) out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
+18 -17
View File
@@ -24,6 +24,8 @@ import numpy as np
import torch import torch
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
from .pipeline import EnvTransition, TransitionKey from .pipeline import EnvTransition, TransitionKey
@@ -82,11 +84,11 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
for k, v in action.items(): for k, v in action.items():
# Check if the value is a type that should not be converted to a tensor. # Check if the value is a type that should not be converted to a tensor.
if isinstance(v, (Rotation, dict)): if isinstance(v, (Rotation, dict)):
act_dict[f"action.{k}"] = v act_dict[f"{ACTION}.{k}"] = v
continue continue
arr = np.array(v) if np.isscalar(v) else v arr = np.array(v) if np.isscalar(v) else v
act_dict[f"action.{k}"] = _to_tensor(arr) act_dict[f"{ACTION}.{k}"] = _to_tensor(arr)
return make_obs_act_transition(act=act_dict) return make_obs_act_transition(act=act_dict)
@@ -101,10 +103,10 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
obs_dict: dict[str, Any] = {} obs_dict: dict[str, Any] = {}
for k, v in state.items(): for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"observation.state.{k}"] = _to_tensor(arr) obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr)
for cam, img in images.items(): for cam, img in images.items():
obs_dict[f"observation.images.{cam}"] = img obs_dict[f"{OBS_IMAGES}.{cam}"] = img
return make_obs_act_transition(obs=obs_dict) return make_obs_act_transition(obs=obs_dict)
@@ -120,8 +122,8 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
return out return out
for k, v in action_dict.items(): for k, v in action_dict.items():
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")): if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
out_key = k[len("action.") :] # Strip the 'action.' prefix. out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
out[out_key] = float(v) out[out_key] = float(v)
return out return out
@@ -152,9 +154,9 @@ def to_dataset_frame(
- info dict - info dict
- *_is_pad flags and task from complementary_data - *_is_pad flags and task from complementary_data
""" """
action_names = features.get("action", {}).get("names", []) action_names = features.get(ACTION, {}).get("names", [])
obs_state_names = features.get("observation.state", {}).get("names", []) obs_state_names = features.get(OBS_STATE, {}).get("names", [])
image_keys = [k for k in features if k.startswith("observation.images.")] image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition: def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
out = deepcopy(base) out = deepcopy(base)
@@ -198,21 +200,20 @@ def to_dataset_frame(
# Observation.state vector # Observation.state vector
if obs_state_names: if obs_state_names:
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names] vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch["observation.state"] = np.asarray(vals, dtype=np.float32) batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Action vector # Action vector
if action_names: if action_names:
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names] vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch["action"] = np.asarray(vals, dtype=np.float32) batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Next.* fields
if tr.get(TransitionKey.REWARD) is not None: if tr.get(TransitionKey.REWARD) is not None:
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD]) batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
if tr.get(TransitionKey.DONE) is not None: if tr.get(TransitionKey.DONE) is not None:
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE]) batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
if tr.get(TransitionKey.TRUNCATED) is not None: if tr.get(TransitionKey.TRUNCATED) is not None:
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED]) batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
# Complementary data flags and task # Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
+12 -9
View File
@@ -8,6 +8,7 @@ import torch
import torchvision.transforms.functional as F # noqa: N812 import torchvision.transforms.functional as F # noqa: N812
from lerobot.configs.types import PolicyFeature from lerobot.configs.types import PolicyFeature
from lerobot.constants import ACTION
from lerobot.processor.pipeline import ( from lerobot.processor.pipeline import (
ComplementaryDataProcessor, ComplementaryDataProcessor,
EnvTransition, EnvTransition,
@@ -22,6 +23,8 @@ from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents from lerobot.teleoperators.utils import TeleopEvents
GRIPPER_KEY = "gripper" GRIPPER_KEY = "gripper"
DISCRETE_PENALTY_KEY = "discrete_penalty"
TELEOP_ACTION_KEY = "teleop_action"
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data") @ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
@@ -33,7 +36,7 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
def complementary_data(self, complementary_data: dict) -> dict: def complementary_data(self, complementary_data: dict) -> dict:
new_complementary_data = dict(complementary_data) new_complementary_data = dict(complementary_data)
new_complementary_data["teleop_action"] = self.teleop_device.get_action() new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
return new_complementary_data return new_complementary_data
@@ -141,7 +144,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
if current_gripper_pos is None: if current_gripper_pos is None:
return complementary_data return complementary_data
gripper_action = action[f"action.{GRIPPER_KEY}.pos"] gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
gripper_action_normalized = gripper_action / self.max_gripper_pos gripper_action_normalized = gripper_action / self.max_gripper_pos
# Normalize gripper state and action # Normalize gripper state and action
@@ -156,7 +159,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
# Create new complementary data with penalty info # Create new complementary data with penalty info
new_complementary_data = dict(complementary_data) new_complementary_data = dict(complementary_data)
new_complementary_data["discrete_penalty"] = gripper_penalty new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
return new_complementary_data return new_complementary_data
@@ -187,7 +190,7 @@ class InterventionActionProcessor(ProcessorStep):
# Get intervention signals from complementary data # Get intervention signals from complementary data
info = transition.get(TransitionKey.INFO, {}) info = transition.get(TransitionKey.INFO, {})
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
teleop_action = complementary_data.get("teleop_action", {}) teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
success = info.get(TeleopEvents.SUCCESS, False) success = info.get(TeleopEvents.SUCCESS, False)
@@ -200,12 +203,12 @@ class InterventionActionProcessor(ProcessorStep):
if isinstance(teleop_action, dict): if isinstance(teleop_action, dict):
# Convert teleop_action dict to tensor format # Convert teleop_action dict to tensor format
action_list = [ action_list = [
teleop_action.get("action.delta_x", 0.0), teleop_action.get(f"{ACTION}.delta_x", 0.0),
teleop_action.get("action.delta_y", 0.0), teleop_action.get(f"{ACTION}.delta_y", 0.0),
teleop_action.get("action.delta_z", 0.0), teleop_action.get(f"{ACTION}.delta_z", 0.0),
] ]
if self.use_gripper: if self.use_gripper:
action_list.append(teleop_action.get("gripper", 1.0)) action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
elif isinstance(teleop_action, np.ndarray): elif isinstance(teleop_action, np.ndarray):
action_list = teleop_action.tolist() action_list = teleop_action.tolist()
else: else:
@@ -229,7 +232,7 @@ class InterventionActionProcessor(ProcessorStep):
# Update complementary data with teleop action # Update complementary data with teleop action
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION) complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
return new_transition return new_transition
+9 -11
View File
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import OBS_LANGUAGE from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.processor.pipeline import ( from lerobot.processor.pipeline import (
EnvTransition, EnvTransition,
ObservationProcessor, ObservationProcessor,
@@ -156,10 +156,8 @@ class TokenizerProcessor(ObservationProcessor):
new_observation = dict(observation) new_observation = dict(observation)
# Add tokenized data to observation # Add tokenized data to observation
new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to( new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
dtype=torch.bool
)
return new_observation return new_observation
@@ -239,13 +237,13 @@ class TokenizerProcessor(ObservationProcessor):
""" """
# Add features for tokenized output if they don't exist # Add features for tokenized output if they don't exist
# Standard tokenizer output includes tokens and attention_mask # Standard tokenizer output includes tokens and attention_mask
tokens_key = f"{OBS_LANGUAGE}.tokens"
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
if tokens_key not in features: if OBS_LANGUAGE_TOKENS not in features:
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
if attention_mask_key not in features: if OBS_LANGUAGE_ATTENTION_MASK not in features:
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,)) features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
type=FeatureType.LANGUAGE, shape=(self.max_length,)
)
return features return features
@@ -20,6 +20,7 @@ import numpy as np
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.model.kinematics import RobotKinematics from lerobot.model.kinematics import RobotKinematics
from lerobot.processor.pipeline import ( from lerobot.processor.pipeline import (
ActionProcessor, ActionProcessor,
@@ -81,13 +82,13 @@ class EEReferenceAndDelta(ActionProcessor):
# Current pose from FK on measured joints # Current pose from FK on measured joints
t_curr = self.kinematics.forward_kinematics(q) t_curr = self.kinematics.forward_kinematics(q)
enabled = bool(new_action.pop("action.enabled", 0)) enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
tx = float(new_action.pop("action.target_x", 0.0)) tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
ty = float(new_action.pop("action.target_y", 0.0)) ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
tz = float(new_action.pop("action.target_z", 0.0)) tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
wx = float(new_action.pop("action.target_wx", 0.0)) wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
wy = float(new_action.pop("action.target_wy", 0.0)) wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
wz = float(new_action.pop("action.target_wz", 0.0)) wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
desired = None desired = None
@@ -123,12 +124,12 @@ class EEReferenceAndDelta(ActionProcessor):
# Write action fields # Write action fields
pos = desired[:3, 3] pos = desired[:3, 3]
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
new_action["action.ee.x"] = float(pos[0]) new_action[f"{ACTION}.ee.x"] = float(pos[0])
new_action["action.ee.y"] = float(pos[1]) new_action[f"{ACTION}.ee.y"] = float(pos[1])
new_action["action.ee.z"] = float(pos[2]) new_action[f"{ACTION}.ee.z"] = float(pos[2])
new_action["action.ee.wx"] = float(tw[0]) new_action[f"{ACTION}.ee.wx"] = float(tw[0])
new_action["action.ee.wy"] = float(tw[1]) new_action[f"{ACTION}.ee.wy"] = float(tw[1])
new_action["action.ee.wz"] = float(tw[2]) new_action[f"{ACTION}.ee.wz"] = float(tw[2])
self._prev_enabled = enabled self._prev_enabled = enabled
return new_action return new_action
@@ -139,20 +140,20 @@ class EEReferenceAndDelta(ActionProcessor):
self._command_when_disabled = None self._command_when_disabled = None
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop("action.enabled", None) features.pop(f"{ACTION}.enabled", None)
features.pop("action.target_x", None) features.pop(f"{ACTION}.target_x", None)
features.pop("action.target_y", None) features.pop(f"{ACTION}.target_y", None)
features.pop("action.target_z", None) features.pop(f"{ACTION}.target_z", None)
features.pop("action.target_wx", None) features.pop(f"{ACTION}.target_wx", None)
features.pop("action.target_wy", None) features.pop(f"{ACTION}.target_wy", None)
features.pop("action.target_wz", None) features.pop(f"{ACTION}.target_wz", None)
features["action.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features return features
@@ -180,12 +181,12 @@ class EEBoundsAndSafety(ActionProcessor):
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False) _last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, act: dict) -> dict: def action(self, act: dict) -> dict:
x = act.get("action.ee.x", None) x = act.get(f"{ACTION}.ee.x", None)
y = act.get("action.ee.y", None) y = act.get(f"{ACTION}.ee.y", None)
z = act.get("action.ee.z", None) z = act.get(f"{ACTION}.ee.z", None)
wx = act.get("action.ee.wx", None) wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get("action.ee.wy", None) wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get("action.ee.wz", None) wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz): if None in (x, y, z, wx, wy, wz):
return act return act
@@ -207,12 +208,12 @@ class EEBoundsAndSafety(ActionProcessor):
self._last_pos = pos self._last_pos = pos
self._last_twist = twist self._last_twist = twist
act["action.ee.x"] = float(pos[0]) act[f"{ACTION}.ee.x"] = float(pos[0])
act["action.ee.y"] = float(pos[1]) act[f"{ACTION}.ee.y"] = float(pos[1])
act["action.ee.z"] = float(pos[2]) act[f"{ACTION}.ee.z"] = float(pos[2])
act["action.ee.wx"] = float(twist[0]) act[f"{ACTION}.ee.wx"] = float(twist[0])
act["action.ee.wy"] = float(twist[1]) act[f"{ACTION}.ee.wy"] = float(twist[1])
act["action.ee.wz"] = float(twist[2]) act[f"{ACTION}.ee.wz"] = float(twist[2])
return act return act
def reset(self): def reset(self):
@@ -250,12 +251,12 @@ class InverseKinematicsEEToJoints(ProcessorStep):
act = transition.get(TransitionKey.ACTION) or {} act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
x = act.get("action.ee.x", None) x = act.get(f"{ACTION}.ee.x", None)
y = act.get("action.ee.y", None) y = act.get(f"{ACTION}.ee.y", None)
z = act.get("action.ee.z", None) z = act.get(f"{ACTION}.ee.z", None)
wx = act.get("action.ee.wx", None) wx = act.get(f"{ACTION}.ee.wx", None)
wy = act.get("action.ee.wy", None) wy = act.get(f"{ACTION}.ee.wy", None)
wz = act.get("action.ee.wz", None) wz = act.get(f"{ACTION}.ee.wz", None)
if None in (x, y, z, wx, wy, wz): if None in (x, y, z, wx, wy, wz):
return transition return transition
@@ -285,19 +286,19 @@ class InverseKinematicsEEToJoints(ProcessorStep):
new_act = dict(act) new_act = dict(act)
for i, name in enumerate(self.motor_names): for i, name in enumerate(self.motor_names):
if name == "gripper": if name == "gripper":
new_act["observation.state.gripper.pos"] = float(raw["gripper"]) new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
else: else:
new_act[f"action.{name}.pos"] = float(q_target[i]) new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
transition[TransitionKey.ACTION] = new_act transition[TransitionKey.ACTION] = new_act
if not self.initial_guess_current_joints: if not self.initial_guess_current_joints:
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
return transition return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features["observation.state.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
for name in self.motor_names: for name in self.motor_names:
features[f"action.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features return features
@@ -333,12 +334,12 @@ class GripperVelocityToJoint(ProcessorStep):
act = transition.get(TransitionKey.ACTION) or {} act = transition.get(TransitionKey.ACTION) or {}
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {} comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
if "action.gripper" not in act: if f"{ACTION}.gripper" not in act:
return transition return transition
if "gripper" not in self.motor_names: if "gripper" not in self.motor_names:
new_act = dict(act) new_act = dict(act)
new_act.pop("action.gripper", None) new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act transition[TransitionKey.ACTION] = new_act
return transition return transition
@@ -346,32 +347,32 @@ class GripperVelocityToJoint(ProcessorStep):
# Discrete gripper actions are in [0, 1, 2] # Discrete gripper actions are in [0, 1, 2]
# 0: open, 1: close, 2: stay # 0: open, 1: close, 2: stay
# We need to shift them to [-1, 0, 1] and then scale them to clip_max # We need to shift them to [-1, 0, 1] and then scale them to clip_max
gripper_action = act.get("action.gripper", 1.0) gripper_action = act.get(f"{ACTION}.gripper", 1.0)
gripper_action = gripper_action - 1.0 gripper_action = gripper_action - 1.0
gripper_action *= self.clip_max gripper_action *= self.clip_max
act["action.gripper"] = gripper_action act[f"{ACTION}.gripper"] = gripper_action
# Get current gripper position from complementary data # Get current gripper position from complementary data
raw = comp.get("raw_joint_positions") or {} raw = comp.get("raw_joint_positions") or {}
curr_pos = float(raw.get("gripper")) curr_pos = float(raw.get("gripper"))
# Compute desired gripper velocity # Compute desired gripper velocity
u = float(act.get("action.gripper", 0.0)) u = float(act.get(f"{ACTION}.gripper", 0.0))
delta = u * float(self.speed_factor) delta = u * float(self.speed_factor)
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max)) gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
new_act = dict(act) new_act = dict(act)
new_act["action.gripper.pos"] = gripper_pos new_act[f"{ACTION}.gripper.pos"] = gripper_pos
new_act.pop("action.gripper", None) new_act.pop(f"{ACTION}.gripper", None)
transition[TransitionKey.ACTION] = new_act transition[TransitionKey.ACTION] = new_act
obs["observation.state.gripper.pos"] = curr_pos obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
transition[TransitionKey.OBSERVATION] = obs transition[TransitionKey.OBSERVATION] = obs
return transition return transition
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
features.pop("action.gripper", None) features.pop(f"{ACTION}.gripper", None)
features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features return features
@@ -396,26 +397,26 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
motor_names: list[str] motor_names: list[str]
def observation(self, obs: dict) -> dict: def observation(self, obs: dict) -> dict:
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names): if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
return obs return obs
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float) q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
t = self.kinematics.forward_kinematics(q) t = self.kinematics.forward_kinematics(q)
pos = t[:3, 3] pos = t[:3, 3]
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
obs["observation.state.ee.x"] = float(pos[0]) obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
obs["observation.state.ee.y"] = float(pos[1]) obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
obs["observation.state.ee.z"] = float(pos[2]) obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
obs["observation.state.ee.wx"] = float(tw[0]) obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
obs["observation.state.ee.wy"] = float(tw[1]) obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
obs["observation.state.ee.wz"] = float(tw[2]) obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
return obs return obs
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
# We specify the dataset features of this step that we want to be stored in the dataset # We specify the dataset features of this step that we want to be stored in the dataset
for k in ["x", "y", "z", "wx", "wy", "wz"]: for k in ["x", "y", "z", "wx", "wy", "wz"]:
features[f"observation.state.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),) features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
return features return features