mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(processor): phone examples (#1921)
* fix(processor): phone examples * chore(processor): simplify gripper in phone example kinematic chain --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -89,6 +89,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, EnvTransi
|
||||
max_ee_step_m=0.20,
|
||||
max_ee_twist_step_rad=0.50,
|
||||
),
|
||||
GripperVelocityToJoint(),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=identity_transition,
|
||||
@@ -102,10 +103,6 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[EnvTransition, RobotAction
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
speed_factor=20.0,
|
||||
),
|
||||
],
|
||||
to_transition=identity_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
@@ -127,13 +124,6 @@ action_ee = aggregate_pipeline_dataset_features(
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Get gripper pos action features
|
||||
gripper = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_ee_to_joints_processor,
|
||||
initial_features=create_initial_features(action=robot.action_features, observation={}),
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
# Build dataset ee observation features
|
||||
observation_ee = aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
@@ -141,7 +131,7 @@ observation_ee = aggregate_pipeline_dataset_features(
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
dataset_features = combine_feature_dicts(action_ee, gripper, observation_ee)
|
||||
dataset_features = combine_feature_dicts(action_ee, observation_ee)
|
||||
|
||||
print("All dataset features: ", dataset_features)
|
||||
|
||||
|
||||
@@ -684,7 +684,11 @@ def hw_to_dataset_features(
|
||||
dict: A LeRobot features dictionary.
|
||||
"""
|
||||
features = {}
|
||||
joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float}
|
||||
joint_fts = {
|
||||
key: ftype
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
|
||||
if joint_fts and prefix == "action":
|
||||
@@ -736,7 +740,7 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
frame[key] = values[key]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@@ -827,7 +827,7 @@ class ObservationProcessorStep(ProcessorStep, ABC):
|
||||
"""An abstract `ProcessorStep` that specifically targets the observation in a transition."""
|
||||
|
||||
@abstractmethod
|
||||
def observation(self, observation) -> dict[str, Any]:
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Processes an observation dictionary. Subclasses must implement this method.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -307,7 +307,6 @@ def record_loop(
|
||||
obs_transition = robot_observation_processor(obs)
|
||||
|
||||
if policy is not None or dataset is not None:
|
||||
# TODO(Steven): We might be able to get rid of this
|
||||
observation_frame = build_dataset_frame(
|
||||
dataset.features, obs_transition[TransitionKey.OBSERVATION], prefix="observation"
|
||||
)
|
||||
@@ -357,26 +356,26 @@ def record_loop(
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
# IMPORTANT: action_pipeline.to_output must return a dict suitable for robot.send_action()
|
||||
if policy is not None and policy_transition is not None:
|
||||
action_values = policy_transition[TransitionKey.ACTION]
|
||||
robot_action_to_send = robot_action_processor(policy_transition)
|
||||
else:
|
||||
action_values = teleop_transition[TransitionKey.ACTION]
|
||||
robot_action_to_send = robot_action_processor(teleop_transition)
|
||||
|
||||
# Send action to robot
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
# TODO(pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
sent_action = robot.send_action(robot_action_to_send)
|
||||
# TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot.
|
||||
_sent_action = robot.send_action(robot_action_to_send)
|
||||
|
||||
# Write to dataset
|
||||
if dataset is not None:
|
||||
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
|
||||
action_frame = build_dataset_frame(dataset.features, action_values, prefix="action")
|
||||
frame = {**observation_frame, **action_frame}
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_transition.get(TransitionKey.OBSERVATION), action=robot_action_to_send
|
||||
)
|
||||
log_rerun_data(observation=obs_transition[TransitionKey.OBSERVATION], action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_STATE
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
ComplementaryDataProcessorStep,
|
||||
@@ -92,13 +92,14 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
# Current pose from FK on measured joints
|
||||
t_curr = self.kinematics.forward_kinematics(q)
|
||||
|
||||
enabled = bool(new_action.pop("enabled", 0))
|
||||
tx = float(new_action.pop("target_x", 0.0))
|
||||
ty = float(new_action.pop("target_y", 0.0))
|
||||
tz = float(new_action.pop("target_z", 0.0))
|
||||
wx = float(new_action.pop("target_wx", 0.0))
|
||||
wy = float(new_action.pop("target_wy", 0.0))
|
||||
wz = float(new_action.pop("target_wz", 0.0))
|
||||
enabled = bool(new_action.pop("enabled"))
|
||||
tx = float(new_action.pop("target_x"))
|
||||
ty = float(new_action.pop("target_y"))
|
||||
tz = float(new_action.pop("target_z"))
|
||||
wx = float(new_action.pop("target_wx"))
|
||||
wy = float(new_action.pop("target_wy"))
|
||||
wz = float(new_action.pop("target_wz"))
|
||||
gripper_vel = float(new_action.pop("gripper_vel"))
|
||||
|
||||
desired = None
|
||||
|
||||
@@ -140,6 +141,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
new_action["ee.wx"] = float(tw[0])
|
||||
new_action["ee.wy"] = float(tw[1])
|
||||
new_action["ee.wz"] = float(tw[2])
|
||||
new_action["ee.gripper_vel"] = gripper_vel
|
||||
|
||||
self._prev_enabled = enabled
|
||||
return new_action
|
||||
@@ -160,6 +162,7 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
features[PipelineFeatureType.ACTION].pop("target_wx", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wy", None)
|
||||
features[PipelineFeatureType.ACTION].pop("target_wz", None)
|
||||
features[PipelineFeatureType.ACTION].pop("gripper_vel", None)
|
||||
|
||||
features[PipelineFeatureType.ACTION]["ee.x"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.y"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
@@ -167,6 +170,9 @@ class EEReferenceAndDelta(RobotActionProcessorStep):
|
||||
features[PipelineFeatureType.ACTION]["ee.wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["ee.gripper_vel"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
|
||||
@@ -194,12 +200,13 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
|
||||
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
|
||||
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
x = action.get("ee.x", None)
|
||||
y = action.get("ee.y", None)
|
||||
z = action.get("ee.z", None)
|
||||
wx = action.get("ee.wx", None)
|
||||
wy = action.get("ee.wy", None)
|
||||
wz = action.get("ee.wz", None)
|
||||
x = action.get("ee.x")
|
||||
y = action.get("ee.y")
|
||||
z = action.get("ee.z")
|
||||
wx = action.get("ee.wx")
|
||||
wy = action.get("ee.wy")
|
||||
wz = action.get("ee.wz")
|
||||
# TODO(Steven): ee.gripper_vel does not need to be bounded
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
raise ValueError(
|
||||
@@ -273,15 +280,18 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get("ee.x", None)
|
||||
y = act.get("ee.y", None)
|
||||
z = act.get("ee.z", None)
|
||||
wx = act.get("ee.wx", None)
|
||||
wy = act.get("ee.wy", None)
|
||||
wz = act.get("ee.wz", None)
|
||||
x = act.pop("ee.x")
|
||||
y = act.pop("ee.y")
|
||||
z = act.pop("ee.z")
|
||||
wx = act.pop("ee.wx")
|
||||
wy = act.pop("ee.wy")
|
||||
wz = act.pop("ee.wz")
|
||||
gripper_pos = act.pop("ee.gripper_pos")
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return new_transition
|
||||
if None in (x, y, z, wx, wy, wz, gripper_pos):
|
||||
raise ValueError(
|
||||
"Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action"
|
||||
)
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
@@ -306,13 +316,12 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
self.q_curr = q_target
|
||||
|
||||
new_act = dict(act)
|
||||
# TODO: This is sentitive to order of motor_names = q_target mapping
|
||||
for i, name in enumerate(self.motor_names):
|
||||
if name == "gripper":
|
||||
# TODO(pepijn): Investigate if this is correct
|
||||
# Do we want an observation key in the action field?
|
||||
new_act["gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
if name != "gripper":
|
||||
new_act[f"{name}.pos"] = float(q_target[i])
|
||||
else:
|
||||
new_act["gripper.pos"] = float(gripper_pos)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
@@ -321,9 +330,13 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.x", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.y", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.z", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wx", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wy", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.wz", None)
|
||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_pos", None)
|
||||
for name in self.motor_names:
|
||||
features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
@@ -338,7 +351,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
|
||||
@ProcessorStepRegistry.register("gripper_velocity_to_joint")
|
||||
@dataclass
|
||||
class GripperVelocityToJoint(ProcessorStep):
|
||||
class GripperVelocityToJoint(RobotActionProcessorStep):
|
||||
"""
|
||||
Converts a gripper velocity command into a target gripper joint position.
|
||||
|
||||
@@ -354,66 +367,46 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay).
|
||||
"""
|
||||
|
||||
motor_names: list[str]
|
||||
speed_factor: float = 20.0
|
||||
clip_min: float = 0.0
|
||||
clip_max: float = 100.0
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
def action(self, action: RobotAction) -> RobotAction:
|
||||
complementary_data = self.transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if not isinstance(act, dict):
|
||||
raise ValueError(f"Action should be a RobotAction type got {type(act)}")
|
||||
gripper_vel = action.pop("ee.gripper_vel")
|
||||
|
||||
if "gripper" not in act:
|
||||
raise ValueError("Required action key 'gripper' not found in transition")
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
if "raw_joint_positions" not in complementary_data:
|
||||
raise ValueError(
|
||||
f"Required motor name 'gripper' not found in self.motor_names={self.motor_names}"
|
||||
"raw_joint_positions is not in complementary data and is required for GripperVelocityToJoint"
|
||||
)
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
# 0: open, 1: close, 2: stay
|
||||
# We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
gripper_action = act.get("gripper", 1.0)
|
||||
gripper_action = gripper_action - 1.0
|
||||
gripper_action *= self.clip_max
|
||||
act["gripper"] = gripper_action
|
||||
curr_gripper_pos = complementary_data.get("raw_joint_positions").get("gripper")
|
||||
|
||||
# Get current gripper position from complementary data
|
||||
raw = comp.get("raw_joint_positions") or {}
|
||||
curr_pos = float(raw.get("gripper"))
|
||||
# TODO(Michel,Adil): Fix this logic
|
||||
# if self.discrete_gripper:
|
||||
# # Discrete gripper actions are in [0, 1, 2]
|
||||
# # 0: open, 1: close, 2: stay
|
||||
# # We need to shift them to [-1, 0, 1] and then scale them to clip_max
|
||||
# gripper_action = gripper_vel
|
||||
# gripper_action *= self.clip_max
|
||||
# action["ee.gripper_pos"] = gripper_action
|
||||
|
||||
# Compute desired gripper velocity
|
||||
u = float(act.get("gripper", 0.0))
|
||||
delta = u * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_pos + delta, self.clip_min, self.clip_max))
|
||||
# Compute desired gripper position
|
||||
delta = gripper_vel * float(self.speed_factor)
|
||||
gripper_pos = float(np.clip(curr_gripper_pos + delta, self.clip_min, self.clip_max))
|
||||
action["ee.gripper_pos"] = gripper_pos
|
||||
|
||||
new_act = dict(act)
|
||||
new_act["gripper.pos"] = gripper_pos
|
||||
new_act.pop("gripper", None)
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
return action
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
features[PipelineFeatureType.ACTION].pop("gripper", None)
|
||||
features[PipelineFeatureType.ACTION]["gripper.pos"] = PolicyFeature(
|
||||
features[PipelineFeatureType.ACTION].pop("ee.gripper_vel")
|
||||
features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.gripper.pos"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
@@ -429,36 +422,42 @@ class ForwardKinematicsJointsToEE(ObservationProcessorStep):
|
||||
|
||||
Attributes:
|
||||
kinematics: The robot's kinematic model.
|
||||
motor_names: A list of motor names whose joint positions are used for FK.
|
||||
"""
|
||||
|
||||
kinematics: RobotKinematics
|
||||
motor_names: list[str]
|
||||
|
||||
def observation(self, observation: dict) -> dict:
|
||||
print("observation in step", observation)
|
||||
if not all(f"{n}.pos" in observation for n in self.motor_names):
|
||||
raise ValueError(f"Missing required joint positions for motors: {self.motor_names}")
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
motor_joint_values = [observation.get(f"{n}.pos") for n in self.motor_names]
|
||||
|
||||
q = np.array([observation[f"{n}.pos"] for n in self.motor_names], dtype=float)
|
||||
q = np.array(motor_joint_values, dtype=float)
|
||||
t = self.kinematics.forward_kinematics(q)
|
||||
pos = t[:3, 3]
|
||||
tw = Rotation.from_matrix(t[:3, :3]).as_rotvec()
|
||||
|
||||
observation[f"{OBS_STATE}.ee.x"] = float(pos[0])
|
||||
observation[f"{OBS_STATE}.ee.y"] = float(pos[1])
|
||||
observation[f"{OBS_STATE}.ee.z"] = float(pos[2])
|
||||
observation[f"{OBS_STATE}.ee.wx"] = float(tw[0])
|
||||
observation[f"{OBS_STATE}.ee.wy"] = float(tw[1])
|
||||
observation[f"{OBS_STATE}.ee.wz"] = float(tw[2])
|
||||
gripper_pos = observation.get("gripper.pos")
|
||||
|
||||
for n in self.motor_names:
|
||||
observation.pop(f"{n}.pos")
|
||||
|
||||
observation["ee.x"] = float(pos[0])
|
||||
observation["ee.y"] = float(pos[1])
|
||||
observation["ee.z"] = float(pos[2])
|
||||
observation["ee.wx"] = float(tw[0])
|
||||
observation["ee.wy"] = float(tw[1])
|
||||
observation["ee.wz"] = float(tw[2])
|
||||
observation["ee.gripper_pos"] = float(gripper_pos)
|
||||
return observation
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
# We only use the ee pose in the dataset, so we don't need the joint positions
|
||||
for n in self.motor_names:
|
||||
features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None)
|
||||
# We specify the dataset features of this step that we want to be stored in the dataset
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz"]:
|
||||
features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.{k}"] = PolicyFeature(
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]:
|
||||
features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature(
|
||||
type=FeatureType.STATE, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
@@ -38,6 +38,7 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
to determine the correct button mappings for the gripper.
|
||||
"""
|
||||
|
||||
# TODO(Steven): Gripper vel could be output of phone_teleop directly
|
||||
platform: PhoneOS
|
||||
_enabled_prev: bool = field(default=False, init=False, repr=False)
|
||||
|
||||
@@ -67,11 +68,11 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
|
||||
# Map certain inputs to certain actions
|
||||
if self.platform == PhoneOS.IOS:
|
||||
gripper = float(inputs.get("a3", 0.0))
|
||||
gripper_vel = float(inputs.get("a3", 0.0))
|
||||
else:
|
||||
a = float(inputs.get("reservedButtonA", 0.0))
|
||||
b = float(inputs.get("reservedButtonB", 0.0))
|
||||
gripper = (
|
||||
gripper_vel = (
|
||||
a - b
|
||||
) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed
|
||||
|
||||
@@ -83,7 +84,7 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
action["target_wx"] = rotvec[1] if enabled else 0.0
|
||||
action["target_wy"] = rotvec[0] if enabled else 0.0
|
||||
action["target_wz"] = -rotvec[2] if enabled else 0.0
|
||||
action["gripper"] = gripper # Still send gripper action when disabled
|
||||
action["gripper_vel"] = gripper_vel # Still send gripper action when disabled
|
||||
return action
|
||||
|
||||
def transform_features(
|
||||
@@ -101,5 +102,7 @@ class MapPhoneActionToRobotAction(RobotActionProcessorStep):
|
||||
features[PipelineFeatureType.ACTION]["target_wx"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wy"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["target_wz"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
features[PipelineFeatureType.ACTION]["gripper_vel"] = PolicyFeature(
|
||||
type=FeatureType.ACTION, shape=(1,)
|
||||
)
|
||||
return features
|
||||
|
||||
Reference in New Issue
Block a user