mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)
This commit is contained in:
@@ -251,8 +251,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
initial_guess_current_joints: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
new_transition = transition.copy()
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
x = act.get(f"{ACTION}.ee.x", None)
|
||||
y = act.get(f"{ACTION}.ee.y", None)
|
||||
@@ -262,7 +263,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||
|
||||
if None in (x, y, z, wx, wy, wz):
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
# Get joint positions from complimentary data
|
||||
raw = comp.get("raw_joint_positions", None)
|
||||
@@ -292,10 +293,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
||||
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||
else:
|
||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
if not self.initial_guess_current_joints:
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return transition
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||
@@ -333,18 +334,19 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
discrete_gripper: bool = False
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
obs = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = transition.get(TransitionKey.ACTION) or {}
|
||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
if f"{ACTION}.gripper" not in act:
|
||||
return transition
|
||||
return new_transition
|
||||
|
||||
if "gripper" not in self.motor_names:
|
||||
new_act = dict(act)
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
return transition
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
return new_transition
|
||||
|
||||
if self.discrete_gripper:
|
||||
# Discrete gripper actions are in [0, 1, 2]
|
||||
@@ -367,11 +369,11 @@ class GripperVelocityToJoint(ProcessorStep):
|
||||
new_act = dict(act)
|
||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||
new_act.pop(f"{ACTION}.gripper", None)
|
||||
transition[TransitionKey.ACTION] = new_act
|
||||
new_transition[TransitionKey.ACTION] = new_act
|
||||
|
||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||
transition[TransitionKey.OBSERVATION] = obs
|
||||
return transition
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
features.pop(f"{ACTION}.gripper", None)
|
||||
|
||||
Reference in New Issue
Block a user