mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 15:09:51 +00:00
refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844)
This commit is contained in:
@@ -282,15 +282,16 @@ class RewardClassifierProcessor(ProcessorStep):
|
|||||||
self.reward_classifier.eval()
|
self.reward_classifier.eval()
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
observation = transition.get(TransitionKey.OBSERVATION)
|
new_transition = transition.copy()
|
||||||
|
observation = new_transition.get(TransitionKey.OBSERVATION)
|
||||||
if observation is None or self.reward_classifier is None:
|
if observation is None or self.reward_classifier is None:
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
# Extract images from observation
|
# Extract images from observation
|
||||||
images = {key: value for key, value in observation.items() if "image" in key}
|
images = {key: value for key, value in observation.items() if "image" in key}
|
||||||
|
|
||||||
if not images:
|
if not images:
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
# Run reward classifier
|
# Run reward classifier
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -300,8 +301,8 @@ class RewardClassifierProcessor(ProcessorStep):
|
|||||||
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
classifier_frequency = 1 / (time.perf_counter() - start_time)
|
||||||
|
|
||||||
# Calculate reward and termination
|
# Calculate reward and termination
|
||||||
reward = transition.get(TransitionKey.REWARD, 0.0)
|
reward = new_transition.get(TransitionKey.REWARD, 0.0)
|
||||||
terminated = transition.get(TransitionKey.DONE, False)
|
terminated = new_transition.get(TransitionKey.DONE, False)
|
||||||
|
|
||||||
if math.isclose(success, 1, abs_tol=1e-2):
|
if math.isclose(success, 1, abs_tol=1e-2):
|
||||||
reward = self.success_reward
|
reward = self.success_reward
|
||||||
@@ -309,7 +310,6 @@ class RewardClassifierProcessor(ProcessorStep):
|
|||||||
terminated = True
|
terminated = True
|
||||||
|
|
||||||
# Update transition
|
# Update transition
|
||||||
new_transition = transition.copy()
|
|
||||||
new_transition[TransitionKey.REWARD] = reward
|
new_transition[TransitionKey.REWARD] = reward
|
||||||
new_transition[TransitionKey.DONE] = terminated
|
new_transition[TransitionKey.DONE] = terminated
|
||||||
|
|
||||||
|
|||||||
@@ -251,8 +251,9 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
initial_guess_current_joints: bool = True
|
initial_guess_current_joints: bool = True
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
act = transition.get(TransitionKey.ACTION) or {}
|
new_transition = transition.copy()
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||||
|
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
x = act.get(f"{ACTION}.ee.x", None)
|
x = act.get(f"{ACTION}.ee.x", None)
|
||||||
y = act.get(f"{ACTION}.ee.y", None)
|
y = act.get(f"{ACTION}.ee.y", None)
|
||||||
@@ -262,7 +263,7 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
wz = act.get(f"{ACTION}.ee.wz", None)
|
wz = act.get(f"{ACTION}.ee.wz", None)
|
||||||
|
|
||||||
if None in (x, y, z, wx, wy, wz):
|
if None in (x, y, z, wx, wy, wz):
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
# Get joint positions from complementary data
|
# Get joint positions from complementary data
|
||||||
raw = comp.get("raw_joint_positions", None)
|
raw = comp.get("raw_joint_positions", None)
|
||||||
@@ -292,10 +293,10 @@ class InverseKinematicsEEToJoints(ProcessorStep):
|
|||||||
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
new_act[f"{OBS_STATE}.gripper.pos"] = float(raw["gripper"])
|
||||||
else:
|
else:
|
||||||
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
new_act[f"{ACTION}.{name}.pos"] = float(q_target[i])
|
||||||
transition[TransitionKey.ACTION] = new_act
|
new_transition[TransitionKey.ACTION] = new_act
|
||||||
if not self.initial_guess_current_joints:
|
if not self.initial_guess_current_joints:
|
||||||
transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
new_transition[TransitionKey.COMPLEMENTARY_DATA]["reference_joint_positions"] = q_target
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
features[f"{OBS_STATE}.gripper.pos"] = (PolicyFeature(type=FeatureType.ACTION, shape=(1,)),)
|
||||||
@@ -333,18 +334,19 @@ class GripperVelocityToJoint(ProcessorStep):
|
|||||||
discrete_gripper: bool = False
|
discrete_gripper: bool = False
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition.get(TransitionKey.OBSERVATION) or {}
|
new_transition = transition.copy()
|
||||||
act = transition.get(TransitionKey.ACTION) or {}
|
obs = new_transition.get(TransitionKey.OBSERVATION) or {}
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
act = new_transition.get(TransitionKey.ACTION) or {}
|
||||||
|
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||||
|
|
||||||
if f"{ACTION}.gripper" not in act:
|
if f"{ACTION}.gripper" not in act:
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
if "gripper" not in self.motor_names:
|
if "gripper" not in self.motor_names:
|
||||||
new_act = dict(act)
|
new_act = dict(act)
|
||||||
new_act.pop(f"{ACTION}.gripper", None)
|
new_act.pop(f"{ACTION}.gripper", None)
|
||||||
transition[TransitionKey.ACTION] = new_act
|
new_transition[TransitionKey.ACTION] = new_act
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
if self.discrete_gripper:
|
if self.discrete_gripper:
|
||||||
# Discrete gripper actions are in [0, 1, 2]
|
# Discrete gripper actions are in [0, 1, 2]
|
||||||
@@ -367,11 +369,11 @@ class GripperVelocityToJoint(ProcessorStep):
|
|||||||
new_act = dict(act)
|
new_act = dict(act)
|
||||||
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
new_act[f"{ACTION}.gripper.pos"] = gripper_pos
|
||||||
new_act.pop(f"{ACTION}.gripper", None)
|
new_act.pop(f"{ACTION}.gripper", None)
|
||||||
transition[TransitionKey.ACTION] = new_act
|
new_transition[TransitionKey.ACTION] = new_act
|
||||||
|
|
||||||
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
obs[f"{OBS_STATE}.gripper.pos"] = curr_pos
|
||||||
transition[TransitionKey.OBSERVATION] = obs
|
new_transition[TransitionKey.OBSERVATION] = obs
|
||||||
return transition
|
return new_transition
|
||||||
|
|
||||||
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||||
features.pop(f"{ACTION}.gripper", None)
|
features.pop(f"{ACTION}.gripper", None)
|
||||||
|
|||||||
Reference in New Issue
Block a user