mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
add policy from local path
This commit is contained in:
@@ -143,6 +143,7 @@ DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
|||||||
def load_holosoma_policy(
|
def load_holosoma_policy(
|
||||||
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
|
||||||
policy_name: str = "fastsac",
|
policy_name: str = "fastsac",
|
||||||
|
local_path: str | None = None
|
||||||
) -> ort.InferenceSession:
|
) -> ort.InferenceSession:
|
||||||
"""Load Holosoma 29-DOF locomotion policy from Hugging Face Hub.
|
"""Load Holosoma 29-DOF locomotion policy from Hugging Face Hub.
|
||||||
|
|
||||||
@@ -150,6 +151,9 @@ def load_holosoma_policy(
|
|||||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||||
policy_name: Policy variant to load ("fastsac" or "ppo").
|
policy_name: Policy variant to load ("fastsac" or "ppo").
|
||||||
"""
|
"""
|
||||||
|
if local_path is not None:
|
||||||
|
logger.info(f"Loading policy from local path: {local_path}")
|
||||||
|
policy_path = local_path
|
||||||
filename_map = {
|
filename_map = {
|
||||||
"fastsac": "fastsac_g1_29dof.onnx",
|
"fastsac": "fastsac_g1_29dof.onnx",
|
||||||
"ppo": "ppo_g1_29dof.onnx",
|
"ppo": "ppo_g1_29dof.onnx",
|
||||||
@@ -220,14 +224,6 @@ class HolosomaLocomotionController:
|
|||||||
logger.info(f" Observation dim: 100, Action dim: 29")
|
logger.info(f" Observation dim: 100, Action dim: 29")
|
||||||
logger.info(f" Missing joints (G1 23-DOF): {MISSING_JOINTS}")
|
logger.info(f" Missing joints (G1 23-DOF): {MISSING_JOINTS}")
|
||||||
|
|
||||||
def _transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):
|
|
||||||
"""Transform IMU data from torso to pelvis frame."""
|
|
||||||
RzWaist = R.from_euler("z", waist_yaw).as_matrix()
|
|
||||||
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
|
|
||||||
R_pelvis = np.dot(R_torso, RzWaist.T)
|
|
||||||
w = np.dot(RzWaist, imu_omega) - np.array([0, 0, waist_yaw_omega])
|
|
||||||
return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w
|
|
||||||
|
|
||||||
def holosoma_locomotion_run(self):
|
def holosoma_locomotion_run(self):
|
||||||
"""29-DOF whole-body locomotion policy loop - controls ALL 29 joints."""
|
"""29-DOF whole-body locomotion policy loop - controls ALL 29 joints."""
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
@@ -467,6 +463,12 @@ if __name__ == "__main__":
|
|||||||
choices=["fastsac", "ppo"],
|
choices=["fastsac", "ppo"],
|
||||||
help="Policy variant to load (default: fastsac)",
|
help="Policy variant to load (default: fastsac)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--local-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to local ONNX file (overrides --repo-id and --policy)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Load policy
|
# Load policy
|
||||||
|
|||||||
@@ -148,14 +148,6 @@ class UnitreeRLLocomotionController:
|
|||||||
logger.info("UnitreeRLLocomotionController initialized")
|
logger.info("UnitreeRLLocomotionController initialized")
|
||||||
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
|
logger.info(" Observation dim: 47, Action dim: 12 (legs only)")
|
||||||
|
|
||||||
def _transform_imu_data(self, waist_yaw, waist_yaw_omega, imu_quat, imu_omega):
|
|
||||||
"""Transform IMU data from torso to pelvis frame."""
|
|
||||||
RzWaist = R.from_euler("z", waist_yaw).as_matrix()
|
|
||||||
R_torso = R.from_quat([imu_quat[1], imu_quat[2], imu_quat[3], imu_quat[0]]).as_matrix()
|
|
||||||
R_pelvis = np.dot(R_torso, RzWaist.T)
|
|
||||||
w = np.dot(RzWaist, imu_omega) - np.array([0, 0, waist_yaw_omega])
|
|
||||||
return R.from_matrix(R_pelvis).as_quat()[[3, 0, 1, 2]], w
|
|
||||||
|
|
||||||
def locomotion_run(self):
|
def locomotion_run(self):
|
||||||
"""12-DOF legs-only locomotion policy loop."""
|
"""12-DOF legs-only locomotion policy loop."""
|
||||||
self.counter += 1
|
self.counter += 1
|
||||||
@@ -194,11 +186,6 @@ class UnitreeRLLocomotionController:
|
|||||||
quat = robot_state.imu_state.quaternion
|
quat = robot_state.imu_state.quaternion
|
||||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||||
|
|
||||||
# Transform IMU from torso to pelvis frame
|
|
||||||
waist_yaw = robot_state.motor_state[12].q
|
|
||||||
waist_yaw_omega = robot_state.motor_state[12].dq
|
|
||||||
quat, ang_vel = self._transform_imu_data(waist_yaw, waist_yaw_omega, quat, ang_vel)
|
|
||||||
|
|
||||||
# Scale observations
|
# Scale observations
|
||||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||||
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
||||||
|
|||||||
Reference in New Issue
Block a user