mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-29 23:49:43 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 99c0d93b34 | |||
| c62784e14c | |||
| cc6a2cac43 |
@@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
def generate_model_card(
|
def generate_model_card(
|
||||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||||
) -> ModelCard:
|
) -> ModelCard:
|
||||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
base_model_mapping = {
|
||||||
|
"smolvla": "lerobot/smolvla_base",
|
||||||
|
"pi0": "lerobot/pi0_base",
|
||||||
|
"pi05": "lerobot/pi05_base",
|
||||||
|
"pi0_fast": "lerobot/pi0fast-base",
|
||||||
|
"xvla": "lerobot/xvla-base",
|
||||||
|
}
|
||||||
|
|
||||||
card_data = ModelCardData(
|
card_data = ModelCardData(
|
||||||
license=license or "apache-2.0",
|
license=license or "apache-2.0",
|
||||||
@@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||||
model_name=model_type,
|
model_name=model_type,
|
||||||
datasets=dataset_repo_id,
|
datasets=dataset_repo_id,
|
||||||
base_model=base_model,
|
base_model=base_model_mapping(model_type, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
template_card = (
|
template_card = (
|
||||||
|
|||||||
@@ -51,10 +51,7 @@ class BiRebotB601Follower(Robot):
|
|||||||
max_relative_target=config.left_arm_config.max_relative_target,
|
max_relative_target=config.left_arm_config.max_relative_target,
|
||||||
cameras=config.left_arm_config.cameras,
|
cameras=config.left_arm_config.cameras,
|
||||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||||
control_mode=config.left_arm_config.control_mode,
|
|
||||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||||
mit_kp=config.left_arm_config.mit_kp,
|
|
||||||
mit_kd=config.left_arm_config.mit_kd,
|
|
||||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||||
joint_limits=config.left_arm_config.joint_limits,
|
joint_limits=config.left_arm_config.joint_limits,
|
||||||
)
|
)
|
||||||
@@ -69,10 +66,7 @@ class BiRebotB601Follower(Robot):
|
|||||||
max_relative_target=config.right_arm_config.max_relative_target,
|
max_relative_target=config.right_arm_config.max_relative_target,
|
||||||
cameras=config.right_arm_config.cameras,
|
cameras=config.right_arm_config.cameras,
|
||||||
motor_can_ids=config.right_arm_config.motor_can_ids,
|
motor_can_ids=config.right_arm_config.motor_can_ids,
|
||||||
control_mode=config.right_arm_config.control_mode,
|
|
||||||
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
|
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
|
||||||
mit_kp=config.right_arm_config.mit_kp,
|
|
||||||
mit_kd=config.right_arm_config.mit_kd,
|
|
||||||
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
|
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
|
||||||
joint_limits=config.right_arm_config.joint_limits,
|
joint_limits=config.right_arm_config.joint_limits,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,24 +65,10 @@ class RebotB601FollowerConfig:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Control mode for the arm joints (the gripper always runs in FORCE_POS):
|
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
|
||||||
# "pos_vel" - position control with velocity limit (firmware PID gains)
|
# applied to every joint; a list provides one value per joint (in motor order).
|
||||||
# "mit" - full impedance control with caller-supplied kp/kd
|
|
||||||
control_mode: str = "pos_vel"
|
|
||||||
|
|
||||||
# Target velocity for joints in POS_VEL mode, or velocity feedforward for joints
|
|
||||||
# in MIT mode, in degrees/s. Scalar applies to every joint; a list gives one
|
|
||||||
# value per joint (in motor order).
|
|
||||||
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
|
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
|
||||||
|
|
||||||
# MIT-mode position stiffness (Nm/rad). Scalar applies to every arm joint; a
|
|
||||||
# list gives one value per joint (in motor order). Ignored when control_mode
|
|
||||||
# is "pos_vel". The gripper entry is unused (gripper stays in FORCE_POS).
|
|
||||||
mit_kp: float | list[float] = 100.0
|
|
||||||
|
|
||||||
# MIT-mode velocity damping (Nm·s/rad). Same shape conventions as ``mit_kp``.
|
|
||||||
mit_kd: float | list[float] = 3.0
|
|
||||||
|
|
||||||
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
|
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
|
||||||
gripper_torque_ratio: float = 0.1
|
gripper_torque_ratio: float = 0.1
|
||||||
|
|
||||||
|
|||||||
@@ -38,8 +38,7 @@ else:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Joint always controlled in FORCE_POS mode; every other joint runs in the mode
|
# Joint controlled in FORCE_POS mode; every other joint runs in POS_VEL mode.
|
||||||
# selected by ``RebotB601FollowerConfig.control_mode`` (POS_VEL or MIT).
|
|
||||||
GRIPPER_MOTOR = "gripper"
|
GRIPPER_MOTOR = "gripper"
|
||||||
# Per-joint Damiao motor models for the B601-DM (passed to motorbridge).
|
# Per-joint Damiao motor models for the B601-DM (passed to motorbridge).
|
||||||
MOTOR_MODELS = {
|
MOTOR_MODELS = {
|
||||||
@@ -169,22 +168,12 @@ class RebotB601Follower(Robot):
|
|||||||
self._save_calibration()
|
self._save_calibration()
|
||||||
print(f"Calibration saved to {self.calibration_fpath}")
|
print(f"Calibration saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
def _arm_mode(self):
|
|
||||||
"""MotorBridge mode used for the arm joints (gripper always uses FORCE_POS)."""
|
|
||||||
mode = self.config.control_mode
|
|
||||||
if mode == "pos_vel":
|
|
||||||
return MotorBridgeMode.POS_VEL
|
|
||||||
if mode == "mit":
|
|
||||||
return MotorBridgeMode.MIT
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported control_mode '{mode}'. Use 'pos_vel' or 'mit'."
|
|
||||||
)
|
|
||||||
|
|
||||||
def configure(self) -> None:
|
def configure(self) -> None:
|
||||||
self.bus.enable_all()
|
self.bus.enable_all()
|
||||||
arm_mode = self._arm_mode()
|
|
||||||
for motor_name, motor in self.motors.items():
|
for motor_name, motor in self.motors.items():
|
||||||
target_mode = MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else arm_mode
|
target_mode = (
|
||||||
|
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
|
||||||
|
)
|
||||||
for attempt in range(_ENSURE_MODE_RETRIES + 1):
|
for attempt in range(_ENSURE_MODE_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
motor.ensure_mode(target_mode)
|
motor.ensure_mode(target_mode)
|
||||||
@@ -263,7 +252,6 @@ class RebotB601Follower(Robot):
|
|||||||
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
|
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
|
||||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||||
|
|
||||||
use_mit = self.config.control_mode == "mit"
|
|
||||||
for motor_name, position_deg in goal_pos.items():
|
for motor_name, position_deg in goal_pos.items():
|
||||||
motor = self.motors.get(motor_name)
|
motor = self.motors.get(motor_name)
|
||||||
if motor is None:
|
if motor is None:
|
||||||
@@ -278,10 +266,6 @@ class RebotB601Follower(Robot):
|
|||||||
vel_rad = math.radians(vel_deg_s)
|
vel_rad = math.radians(vel_deg_s)
|
||||||
if motor_name == GRIPPER_MOTOR:
|
if motor_name == GRIPPER_MOTOR:
|
||||||
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
|
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
|
||||||
elif use_mit:
|
|
||||||
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
|
|
||||||
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
|
|
||||||
motor.send_mit(pos_rad, vel_rad, kp, kd, 0.0)
|
|
||||||
else:
|
else:
|
||||||
motor.send_pos_vel(pos_rad, vel_rad)
|
motor.send_pos_vel(pos_rad, vel_rad)
|
||||||
|
|
||||||
|
|||||||
@@ -73,14 +73,17 @@ _Writes checkpoints to `outputs/train/<desired_policy_repo_id>/checkpoints/`._
|
|||||||
### Evaluate the policy/run inference
|
### Evaluate the policy/run inference
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
--robot.type=so100_follower \
|
--strategy.type=base \
|
||||||
--dataset.repo_id=<hf_user>/eval_<dataset> \
|
--robot.type=so101_follower \
|
||||||
|
--robot.port=/dev/ttyACM0 \
|
||||||
|
--robot.cameras="{ up: {type: opencv, index_or_path: /dev/video1, width: 640, height: 480, fps: 30}, side: {type: opencv, index_or_path: /dev/video5, width: 640, height: 480, fps: 30}}" \
|
||||||
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
--policy.path=<hf_user>/<desired_policy_repo_id> \
|
||||||
--episodes=10
|
--task="Put lego brick into the transparent box" \
|
||||||
|
--duration=60
|
||||||
```
|
```
|
||||||
|
|
||||||
Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint.
|
If you want to record a dataset while testing the policy use `--dataset.repo_id=<hf_user>/eval_dataset_name` it is important to use the prefix **eval\_**. For the policy path use the policy from the Hugging Face Hub or a local one. Skipping duration will make the policy run indefinitely.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user