mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
refactor(constants, processor): standardize action and observation keys across multiple files (#1808)
- Added new constants for truncated and done states in constants.py.
- Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability.
This commit is contained in:
@@ -24,6 +24,11 @@ OBS_IMAGES = "observation.images"
|
|||||||
OBS_LANGUAGE = "observation.language"
|
OBS_LANGUAGE = "observation.language"
|
||||||
ACTION = "action"
|
ACTION = "action"
|
||||||
REWARD = "next.reward"
|
REWARD = "next.reward"
|
||||||
|
TRUNCATED = "next.truncated"
|
||||||
|
DONE = "next.done"
|
||||||
|
|
||||||
|
OBS_LANGUAGE_TOKENS = "observation.language.tokens"
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK = "observation.language.attention_mask"
|
||||||
|
|
||||||
ROBOTS = "robots"
|
ROBOTS = "robots"
|
||||||
ROBOT_TYPE = "robot_type"
|
ROBOT_TYPE = "robot_type"
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.datasets.utils import hw_to_dataset_features
|
from lerobot.datasets.utils import hw_to_dataset_features
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
from lerobot.processor.pipeline import RobotProcessor
|
||||||
|
|
||||||
@@ -59,26 +60,26 @@ def aggregate_pipeline_dataset_features(
|
|||||||
|
|
||||||
# Go over every feature from the pipeline and merge:
|
# Go over every feature from the pipeline and merge:
|
||||||
for full_key, ty in all_features.items():
|
for full_key, ty in all_features.items():
|
||||||
if full_key.startswith("action."):
|
if full_key.startswith(f"{ACTION}."):
|
||||||
# action.<feat>
|
# action.<feat>
|
||||||
if not keep(full_key):
|
if not keep(full_key):
|
||||||
continue
|
continue
|
||||||
name = full_key[len("action.") :]
|
name = full_key[len(f"{ACTION}.") :]
|
||||||
hw.setdefault("action", {})[name] = ty
|
hw.setdefault(ACTION, {})[name] = ty
|
||||||
|
|
||||||
elif full_key.startswith("observation.state."):
|
elif full_key.startswith(f"{OBS_STATE}."):
|
||||||
# observation.state.<feat>
|
# observation.state.<feat>
|
||||||
if not keep(full_key):
|
if not keep(full_key):
|
||||||
continue
|
continue
|
||||||
name = full_key[len("observation.state.") :]
|
name = full_key[len(f"{OBS_STATE}.") :]
|
||||||
hw.setdefault("observation", {})[name] = ty
|
hw.setdefault("observation", {})[name] = ty
|
||||||
|
|
||||||
elif full_key.startswith("observation.images."):
|
elif full_key.startswith(f"{OBS_IMAGES}."):
|
||||||
# observation.images.<cam>
|
# observation.images.<cam>
|
||||||
# images obey ONLY the use_videos flag, not patterns
|
# images obey ONLY the use_videos flag, not patterns
|
||||||
if not use_videos:
|
if not use_videos:
|
||||||
continue
|
continue
|
||||||
name = full_key[len("observation.images.") :]
|
name = full_key[len(f"{OBS_IMAGES}.") :]
|
||||||
hw.setdefault("observation", {})[name] = ty
|
hw.setdefault("observation", {})[name] = ty
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -86,8 +87,8 @@ def aggregate_pipeline_dataset_features(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
out: dict[str, dict] = {}
|
out: dict[str, dict] = {}
|
||||||
if "action" in hw:
|
if ACTION in hw:
|
||||||
out.update(hw_to_dataset_features(hw["action"], "action", use_videos))
|
out.update(hw_to_dataset_features(hw[ACTION], ACTION, use_videos))
|
||||||
if "observation" in hw:
|
if "observation" in hw:
|
||||||
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
out.update(hw_to_dataset_features(hw["observation"], "observation", use_videos))
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from scipy.spatial.transform import Rotation
|
from scipy.spatial.transform import Rotation
|
||||||
|
|
||||||
|
from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED
|
||||||
|
|
||||||
from .pipeline import EnvTransition, TransitionKey
|
from .pipeline import EnvTransition, TransitionKey
|
||||||
|
|
||||||
|
|
||||||
@@ -82,11 +84,11 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
|
|||||||
for k, v in action.items():
|
for k, v in action.items():
|
||||||
# Check if the value is a type that should not be converted to a tensor.
|
# Check if the value is a type that should not be converted to a tensor.
|
||||||
if isinstance(v, (Rotation, dict)):
|
if isinstance(v, (Rotation, dict)):
|
||||||
act_dict[f"action.{k}"] = v
|
act_dict[f"{ACTION}.{k}"] = v
|
||||||
continue
|
continue
|
||||||
|
|
||||||
arr = np.array(v) if np.isscalar(v) else v
|
arr = np.array(v) if np.isscalar(v) else v
|
||||||
act_dict[f"action.{k}"] = _to_tensor(arr)
|
act_dict[f"{ACTION}.{k}"] = _to_tensor(arr)
|
||||||
|
|
||||||
return make_obs_act_transition(act=act_dict)
|
return make_obs_act_transition(act=act_dict)
|
||||||
|
|
||||||
@@ -101,10 +103,10 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
|
|||||||
obs_dict: dict[str, Any] = {}
|
obs_dict: dict[str, Any] = {}
|
||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
arr = np.array(v) if np.isscalar(v) else v
|
arr = np.array(v) if np.isscalar(v) else v
|
||||||
obs_dict[f"observation.state.{k}"] = _to_tensor(arr)
|
obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr)
|
||||||
|
|
||||||
for cam, img in images.items():
|
for cam, img in images.items():
|
||||||
obs_dict[f"observation.images.{cam}"] = img
|
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
|
||||||
|
|
||||||
return make_obs_act_transition(obs=obs_dict)
|
return make_obs_act_transition(obs=obs_dict)
|
||||||
|
|
||||||
@@ -120,8 +122,8 @@ def to_output_robot_action(transition: EnvTransition) -> dict[str, Any]:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
for k, v in action_dict.items():
|
for k, v in action_dict.items():
|
||||||
if isinstance(k, str) and k.startswith("action.") and k.endswith((".pos", ".vel")):
|
if isinstance(k, str) and k.startswith(f"{ACTION}.") and k.endswith((".pos", ".vel")):
|
||||||
out_key = k[len("action.") :] # Strip the 'action.' prefix.
|
out_key = k[len(f"{ACTION}.") :] # Strip the 'action.' prefix.
|
||||||
out[out_key] = float(v)
|
out[out_key] = float(v)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@@ -152,9 +154,9 @@ def to_dataset_frame(
|
|||||||
- info dict
|
- info dict
|
||||||
- *_is_pad flags and task from complementary_data
|
- *_is_pad flags and task from complementary_data
|
||||||
"""
|
"""
|
||||||
action_names = features.get("action", {}).get("names", [])
|
action_names = features.get(ACTION, {}).get("names", [])
|
||||||
obs_state_names = features.get("observation.state", {}).get("names", [])
|
obs_state_names = features.get(OBS_STATE, {}).get("names", [])
|
||||||
image_keys = [k for k in features if k.startswith("observation.images.")]
|
image_keys = [k for k in features if k.startswith(OBS_IMAGES)]
|
||||||
|
|
||||||
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
def _merge(base: EnvTransition, other: EnvTransition) -> EnvTransition:
|
||||||
out = deepcopy(base)
|
out = deepcopy(base)
|
||||||
@@ -198,21 +200,20 @@ def to_dataset_frame(
|
|||||||
|
|
||||||
# Observation.state vector
|
# Observation.state vector
|
||||||
if obs_state_names:
|
if obs_state_names:
|
||||||
vals = [_from_tensor(obs.get(f"observation.state.{n}", 0.0)) for n in obs_state_names]
|
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
|
||||||
batch["observation.state"] = np.asarray(vals, dtype=np.float32)
|
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
|
||||||
|
|
||||||
# Action vector
|
# Action vector
|
||||||
if action_names:
|
if action_names:
|
||||||
vals = [_from_tensor(act.get(f"action.{n}", 0.0)) for n in action_names]
|
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
|
||||||
batch["action"] = np.asarray(vals, dtype=np.float32)
|
batch[ACTION] = np.asarray(vals, dtype=np.float32)
|
||||||
|
|
||||||
# Next.* fields
|
|
||||||
if tr.get(TransitionKey.REWARD) is not None:
|
if tr.get(TransitionKey.REWARD) is not None:
|
||||||
batch["next.reward"] = _from_tensor(tr[TransitionKey.REWARD])
|
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
|
||||||
if tr.get(TransitionKey.DONE) is not None:
|
if tr.get(TransitionKey.DONE) is not None:
|
||||||
batch["next.done"] = _from_tensor(tr[TransitionKey.DONE])
|
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
|
||||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||||
batch["next.truncated"] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||||
|
|
||||||
# Complementary data flags and task
|
# Complementary data flags and task
|
||||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch
|
|||||||
import torchvision.transforms.functional as F # noqa: N812
|
import torchvision.transforms.functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
|
from lerobot.constants import ACTION
|
||||||
from lerobot.processor.pipeline import (
|
from lerobot.processor.pipeline import (
|
||||||
ComplementaryDataProcessor,
|
ComplementaryDataProcessor,
|
||||||
EnvTransition,
|
EnvTransition,
|
||||||
@@ -22,6 +23,8 @@ from lerobot.teleoperators.teleoperator import Teleoperator
|
|||||||
from lerobot.teleoperators.utils import TeleopEvents
|
from lerobot.teleoperators.utils import TeleopEvents
|
||||||
|
|
||||||
GRIPPER_KEY = "gripper"
|
GRIPPER_KEY = "gripper"
|
||||||
|
DISCRETE_PENALTY_KEY = "discrete_penalty"
|
||||||
|
TELEOP_ACTION_KEY = "teleop_action"
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data")
|
||||||
@@ -33,7 +36,7 @@ class AddTeleopActionAsComplimentaryData(ComplementaryDataProcessor):
|
|||||||
|
|
||||||
def complementary_data(self, complementary_data: dict) -> dict:
|
def complementary_data(self, complementary_data: dict) -> dict:
|
||||||
new_complementary_data = dict(complementary_data)
|
new_complementary_data = dict(complementary_data)
|
||||||
new_complementary_data["teleop_action"] = self.teleop_device.get_action()
|
new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action()
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
|
|
||||||
@@ -141,7 +144,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
|||||||
if current_gripper_pos is None:
|
if current_gripper_pos is None:
|
||||||
return complementary_data
|
return complementary_data
|
||||||
|
|
||||||
gripper_action = action[f"action.{GRIPPER_KEY}.pos"]
|
gripper_action = action[f"{ACTION}.{GRIPPER_KEY}.pos"]
|
||||||
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
gripper_action_normalized = gripper_action / self.max_gripper_pos
|
||||||
|
|
||||||
# Normalize gripper state and action
|
# Normalize gripper state and action
|
||||||
@@ -156,7 +159,7 @@ class GripperPenaltyProcessor(ComplementaryDataProcessor):
|
|||||||
|
|
||||||
# Create new complementary data with penalty info
|
# Create new complementary data with penalty info
|
||||||
new_complementary_data = dict(complementary_data)
|
new_complementary_data = dict(complementary_data)
|
||||||
new_complementary_data["discrete_penalty"] = gripper_penalty
|
new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty
|
||||||
|
|
||||||
return new_complementary_data
|
return new_complementary_data
|
||||||
|
|
||||||
@@ -187,7 +190,7 @@ class InterventionActionProcessor(ProcessorStep):
|
|||||||
# Get intervention signals from complementary data
|
# Get intervention signals from complementary data
|
||||||
info = transition.get(TransitionKey.INFO, {})
|
info = transition.get(TransitionKey.INFO, {})
|
||||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
teleop_action = complementary_data.get("teleop_action", {})
|
teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {})
|
||||||
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False)
|
||||||
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False)
|
||||||
success = info.get(TeleopEvents.SUCCESS, False)
|
success = info.get(TeleopEvents.SUCCESS, False)
|
||||||
@@ -200,12 +203,12 @@ class InterventionActionProcessor(ProcessorStep):
|
|||||||
if isinstance(teleop_action, dict):
|
if isinstance(teleop_action, dict):
|
||||||
# Convert teleop_action dict to tensor format
|
# Convert teleop_action dict to tensor format
|
||||||
action_list = [
|
action_list = [
|
||||||
teleop_action.get("action.delta_x", 0.0),
|
teleop_action.get(f"{ACTION}.delta_x", 0.0),
|
||||||
teleop_action.get("action.delta_y", 0.0),
|
teleop_action.get(f"{ACTION}.delta_y", 0.0),
|
||||||
teleop_action.get("action.delta_z", 0.0),
|
teleop_action.get(f"{ACTION}.delta_z", 0.0),
|
||||||
]
|
]
|
||||||
if self.use_gripper:
|
if self.use_gripper:
|
||||||
action_list.append(teleop_action.get("gripper", 1.0))
|
action_list.append(teleop_action.get(GRIPPER_KEY, 1.0))
|
||||||
elif isinstance(teleop_action, np.ndarray):
|
elif isinstance(teleop_action, np.ndarray):
|
||||||
action_list = teleop_action.tolist()
|
action_list = teleop_action.tolist()
|
||||||
else:
|
else:
|
||||||
@@ -229,7 +232,7 @@ class InterventionActionProcessor(ProcessorStep):
|
|||||||
|
|
||||||
# Update complementary data with teleop action
|
# Update complementary data with teleop action
|
||||||
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||||
complementary_data["teleop_action"] = new_transition.get(TransitionKey.ACTION)
|
complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION)
|
||||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data
|
||||||
|
|
||||||
return new_transition
|
return new_transition
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.constants import OBS_LANGUAGE
|
from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||||
from lerobot.processor.pipeline import (
|
from lerobot.processor.pipeline import (
|
||||||
EnvTransition,
|
EnvTransition,
|
||||||
ObservationProcessor,
|
ObservationProcessor,
|
||||||
@@ -156,10 +156,8 @@ class TokenizerProcessor(ObservationProcessor):
|
|||||||
new_observation = dict(observation)
|
new_observation = dict(observation)
|
||||||
|
|
||||||
# Add tokenized data to observation
|
# Add tokenized data to observation
|
||||||
new_observation[f"{OBS_LANGUAGE}.tokens"] = tokenized_prompt["input_ids"]
|
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||||
new_observation[f"{OBS_LANGUAGE}.attention_mask"] = tokenized_prompt["attention_mask"].to(
|
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||||
dtype=torch.bool
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_observation
|
return new_observation
|
||||||
|
|
||||||
@@ -239,13 +237,13 @@ class TokenizerProcessor(ObservationProcessor):
|
|||||||
"""
|
"""
|
||||||
# Add features for tokenized output if they don't exist
|
# Add features for tokenized output if they don't exist
|
||||||
# Standard tokenizer output includes tokens and attention_mask
|
# Standard tokenizer output includes tokens and attention_mask
|
||||||
tokens_key = f"{OBS_LANGUAGE}.tokens"
|
|
||||||
attention_mask_key = f"{OBS_LANGUAGE}.attention_mask"
|
|
||||||
|
|
||||||
if tokens_key not in features:
|
if OBS_LANGUAGE_TOKENS not in features:
|
||||||
features[tokens_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
features[OBS_LANGUAGE_TOKENS] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
||||||
|
|
||||||
if attention_mask_key not in features:
|
if OBS_LANGUAGE_ATTENTION_MASK not in features:
|
||||||
features[attention_mask_key] = PolicyFeature(type=FeatureType.LANGUAGE, shape=(self.max_length,))
|
features[OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature(
|
||||||
|
type=FeatureType.LANGUAGE, shape=(self.max_length,)
|
||||||
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import numpy as np
|
|||||||
from scipy.spatial.transform import Rotation
|
from scipy.spatial.transform import Rotation
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
|
from lerobot.constants import ACTION, OBS_STATE
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.processor.pipeline import (
|
from lerobot.processor.pipeline import (
|
||||||
ActionProcessor,
|
ActionProcessor,
|
||||||
@@ -81,13 +82,13 @@ class EEReferenceAndDelta(ActionProcessor):
|
|||||||
# Current pose from FK on measured joints
|
# Current pose from FK on measured joints
|
||||||
t_curr = self.kinematics.forward_kinematics(q)
|
t_curr = self.kinematics.forward_kinematics(q)
|
||||||
|
|
||||||
enabled = bool(new_action.pop("action.enabled", 0))
|
enabled = bool(new_action.pop(f"{ACTION}.enabled", 0))
|
||||||
tx = float(new_action.pop("action.target_x", 0.0))
|
tx = float(new_action.pop(f"{ACTION}.target_x", 0.0))
|
||||||
ty = float(new_action.pop("action.target_y", 0.0))
|
ty = float(new_action.pop(f"{ACTION}.target_y", 0.0))
|
||||||
tz = float(new_action.pop("action.target_z", 0.0))
|
tz = float(new_action.pop(f"{ACTION}.target_z", 0.0))
|
||||||
wx = float(new_action.pop("action.target_wx", 0.0))
|
wx = float(new_action.pop(f"{ACTION}.target_wx", 0.0))
|
||||||
wy = float(new_action.pop("action.target_wy", 0.0))
|
wy = float(new_action.pop(f"{ACTION}.target_wy", 0.0))
|
||||||
wz = float(new_action.pop("action.target_wz", 0.0))
|
wz = float(new_action.pop(f"{ACTION}.target_wz", 0.0))
|
||||||
|
|
||||||
desired = None
|
desired = None
|
||||||
|
|
||||||
@@ -123,12 +124,12 @@ class EEReferenceAndDelta(ActionProcessor):
|
|||||||
# Write action fields
|
# Write action fields
|
||||||
pos = desired[:3, 3]
|
pos = desired[:3, 3]
|
||||||
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec()
|
||||||
new_action["action.ee.x"] = float(pos[0])
|
new_action[f"{ACTION}.ee.x"] = float(pos[0])
|
||||||
new_action["action.ee.y"] = float(pos[1])
|
new_action[f"{ACTION}.ee.y"] = float(pos[1])
|
||||||
new_action["action.ee.z"] = float(pos[2])
|
new_action[f"{ACTION}.ee.z"] = float(pos[2])
|
||||||
new_action["action.ee.wx"] = float(tw[0])
|
new_action[f"{ACTION}.ee.wx"] = float(tw[0])
|
||||||
new_action["action.ee.wy"] = float(tw[1])
|
new_action[f"{ACTION}.ee.wy"] = float(tw[1])
|
||||||
new_action["action.ee.wz"] = float(tw[2])
|
new_action[f"{ACTION}.ee.wz"] = float(tw[2])
|
||||||
|
|
||||||
self._prev_enabled = enabled
|
self._prev_enabled = enabled
|
||||||
return new_action
|
return new_action
|
||||||
@@ -139,20 +140,20 @@ class EEReferenceAndDelta(ActionProcessor):
|
|||||||
self._command_when_disabled = None
|
self._command_when_disabled = None
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
features.pop("action.enabled", None)
|
features.pop(f"{ACTION}.enabled", None)
|
||||||
features.pop("action.target_x", None)
|
features.pop(f"{ACTION}.target_x", None)
|
||||||
features.pop("action.target_y", None)
|
features.pop(f"{ACTION}.target_y", None)
|
||||||
features.pop("action.target_z", None)
|
features.pop(f"{ACTION}.target_z", None)
|
||||||
features.pop("action.target_wx", None)
|
features.pop(f"{ACTION}.target_wx", None)
|
||||||
features.pop("action.target_wy", None)
|
features.pop(f"{ACTION}.target_wy", None)
|
||||||
features.pop("action.target_wz", None)
|
features.pop(f"{ACTION}.target_wz", None)
|
||||||
|
|
||||||
features["action.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.x"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.y"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.z"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.wx"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.wy"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.ee.wz"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@@ -180,12 +181,12 @@ class EEBoundsAndSafety(ActionProcessor):
|
|||||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
def action(self, act: dict) -> dict:
|
def action(self, act: dict) -> dict:
|
||||||
x = act.get("action.ee.x", None)
|
x = act.get(f"{ACTION}.ee.x", None)
|
||||||
y = act.get("action.ee.y", None)
|
y = act.get(f"{ACTION}.ee.y", None)
|
||||||
z = act.get("action.ee.z", None)
|
z = act.get(f"{ACTION}.ee.z", None)
|
||||||
wx = act.get("action.ee.wx", None)
|
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||||
wy = act.get("action.ee.wy", None)
|
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||||
wz = act.get("action.ee.wz", None)
|
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||||
|
|
||||||
if None in (x, y, z, wx, wy, wz):
|
if None in (x, y, z, wx, wy, wz):
|
||||||
return act
|
return act
|
||||||
@@ -207,12 +208,12 @@ class EEBoundsAndSafety(ActionProcessor):
|
|||||||
self._last_pos = pos
|
self._last_pos = pos
|
||||||
self._last_twist = twist
|
self._last_twist = twist
|
||||||
|
|
||||||
act["action.ee.x"] = float(pos[0])
|
act[f"{ACTION}.ee.x"] = float(pos[0])
|
||||||
act["action.ee.y"] = float(pos[1])
|
act[f"{ACTION}.ee.y"] = float(pos[1])
|
||||||
act["action.ee.z"] = float(pos[2])
|
act[f"{ACTION}.ee.z"] = float(pos[2])
|
||||||
act["action.ee.wx"] = float(twist[0])
|
act[f"{ACTION}.ee.wx"] = float(twist[0])
|
||||||
act["action.ee.wy"] = float(twist[1])
|
act[f"{ACTION}.ee.wy"] = float(twist[1])
|
||||||
act["action.ee.wz"] = float(twist[2])
|
act[f"{ACTION}.ee.wz"] = float(twist[2])
|
||||||
return act
|
return act
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -250,12 +251,12 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
act = transition.get(TransitionKey.ACTION) or {}
|
act = transition.get(TransitionKey.ACTION) or {}
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
x = act.get("action.ee.x", None)
|
x = act.get(f"{ACTION}.ee.x", None)
|
||||||
y = act.get("action.ee.y", None)
|
y = act.get(f"{ACTION}.ee.y", None)
|
||||||
z = act.get("action.ee.z", None)
|
z = act.get(f"{ACTION}.ee.z", None)
|
||||||
wx = act.get("action.ee.wx", None)
|
wx = act.get(f"{ACTION}.ee.wx", None)
|
||||||
wy = act.get("action.ee.wy", None)
|
wy = act.get(f"{ACTION}.ee.wy", None)
|
||||||
wz = act.get("action.ee.wz", None)
|
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||||
|
|
||||||
if None in (x, y, z, wx, wy, wz):
|
if None in (x, y, z, wx, wy, wz):
|
||||||
return transition
|
return transition
|
||||||
@@ -285,19 +286,19 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
new_act = dict(act)
|
new_act = dict(act)
|
||||||
for i, name in enumerate(self.motor_names):
|
for i, name in enumerate(self.motor_names):
|
||||||
if name == "gripper":
|
if name == "gripper":
|
||||||
new_act["observation.state.gripper.pos"] = float(raw["gripper"])
|
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||||
else:
|
else:
|
||||||
new_act[f"action.{name}.pos"] = float(q_target[i])
|
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||||
transition[TransitionKey.ACTION] = new_act
|
transition[TransitionKey.ACTION] = new_act
|
||||||
if not self.initial_guess_current_joints:
|
if not self.initial_guess_current_joints:
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
features["observation.state.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
for name in self.motor_names:
|
for name in self.motor_names:
|
||||||
features[f"action.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.{name}.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
@@ -333,12 +334,12 @@ class GripperVelocityToJoint(ProcessorStep):
|
|||||||
act = transition.get(TransitionKey.ACTION) or {}
|
act = transition.get(TransitionKey.ACTION) or {}
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
if "action.gripper" not in act:
|
if f"{ACTION}.gripper" not in act:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
if "gripper" not in self.motor_names:
|
if "gripper" not in self.motor_names:
|
||||||
new_act = dict(act)
|
new_act = dict(act)
|
||||||
new_act.pop("action.gripper", None)
|
new_act.pop(f"{ACTION}.gripper", None)
|
||||||
transition[TransitionKey.ACTION] = new_act
|
transition[TransitionKey.ACTION] = new_act
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
@@ -346,32 +347,32 @@ class GripperVelocityToJoint(ProcessorStep):
|
|||||||
# Discrete gripper actions are in [0, 1, 2]
|
# Discrete gripper actions are in [0, 1, 2]
|
||||||
# 0: open, 1: close, 2: stay
|
# 0: open, 1: close, 2: stay
|
||||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||||
gripper_action = act.get("action.gripper", 1.0)
|
gripper_action = act.get(f"{ACTION}.gripper", 1.0)
|
||||||
gripper_action = gripper_action - 1.0
|
gripper_action = gripper_action - 1.0
|
||||||
gripper_action *= self.clip_max
|
gripper_action *= self.clip_max
|
||||||
act["action.gripper"] = gripper_action
|
act[f"{ACTION}.gripper"] = gripper_action
|
||||||
|
|
||||||
# Get current gripper position from complementary data
|
# Get current gripper position from complementary data
|
||||||
raw = comp.get("raw_joint_positions") or {}
|
raw = comp.get("raw_joint_positions") or {}
|
||||||
curr_pos = float(raw.get("gripper"))
|
curr_pos = float(raw.get("gripper"))
|
||||||
|
|
||||||
# Compute desired gripper velocity
|
# Compute desired gripper velocity
|
||||||
u = float(act.get("action.gripper", 0.0))
|
u = float(act.get(f"{ACTION}.gripper", 0.0))
|
||||||
delta = u * float(self.speed_factor)
|
delta = u * float(self.speed_factor)
|
||||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||||
|
|
||||||
new_act = dict(act)
|
new_act = dict(act)
|
||||||
new_act["action.gripper.pos"] = gripper_pos
|
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||||
new_act.pop("action.gripper", None)
|
new_act.pop(f"{ACTION}.gripper", None)
|
||||||
transition[TransitionKey.ACTION] = new_act
|
transition[TransitionKey.ACTION] = new_act
|
||||||
|
|
||||||
obs["observation.state.gripper.pos"] = curr_pos
|
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||||
transition[TransitionKey.OBSERVATION] = obs
|
transition[TransitionKey.OBSERVATION] = obs
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
features.pop("action.gripper", None)
|
features.pop(f"{ACTION}.gripper", None)
|
||||||
features["action.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{ACTION}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
@@ -396,26 +397,26 @@ class ForwardKinematicsJointsToEE(ObservationProcessor):
|
|||||||
motor_names: list[str]
|
motor_names: list[str]
|
||||||
|
|
||||||
def observation(self, obs: dict) -> dict:
|
def observation(self, obs: dict) -> dict:
|
||||||
if not all(f"observation.state.{n}.pos" in obs for n in self.motor_names):
|
if not all(f"{OBS_STATE}.{n}.pos" in obs for n in self.motor_names):
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
q = np.array([obs[f"observation.state.{n}.pos"] for n in self.motor_names], dtype=float)
|
q = np.array([obs[f"{OBS_STATE}.{n}.pos"] for n in self.motor_names], dtype=float)
|
||||||
t = self.kinematics.forward_kinematics(q)
|
t = self.kinematics.forward_kinematics(q)
|
||||||
pos = t[:3, 3]
|
pos = t[:3, 3]
|
||||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||||
|
|
||||||
obs["observation.state.ee.x"] = float(pos[0])
|
obs[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||||
obs["observation.state.ee.y"] = float(pos[1])
|
obs[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||||
obs["observation.state.ee.z"] = float(pos[2])
|
obs[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||||
obs["observation.state.ee.wx"] = float(tw[0])
|
obs[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||||
obs["observation.state.ee.wy"] = float(tw[1])
|
obs[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||||
obs["observation.state.ee.wz"] = float(tw[2])
|
obs[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||||
features[f"observation.state.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{OBS_STATE}.ee.{k}"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user