mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
fix(smolvla2): use RobotConfig.max_relative_target, drop --max_action_norm
The hand-rolled action-norm safety clip duplicated what every ``RobotConfig`` already exposes — ``max_relative_target`` — and did so at the wrong layer: after postprocessing but before ``send_action``, instead of inside the robot driver, where every other lerobot entry point puts it. The norm clip also rejected entire actions instead of clipping per-motor relative motion, so a single rogue joint would kill the whole tick.

Replace it with ``--robot.max_relative_target``: a string parsed as either a bare float (uniform per-motor cap) or a JSON object mapping motor name → cap. The value is passed through to ``RobotConfig(max_relative_target=...)`` at robot construction; the driver's ``send_action`` then clips each commanded joint position relative to the current measured one before issuing it on the bus — the same behaviour ``lerobot-record`` ships.

Also bump the ``--chunk_hz`` default from ``4.0`` to ``1.0``. One new chunk per second is what the trained checkpoint can comfortably keep up with on common hardware, and it gives smoother motion than sub-second chunk regenerations (there is no RTC interpolation between chunks yet).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -180,14 +180,19 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--max_action_norm",
|
"--robot.max_relative_target",
|
||||||
dest="max_action_norm",
|
dest="robot_max_relative_target",
|
||||||
type=float,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help=(
|
help=(
|
||||||
"Safety clip: reject any individual action whose L2 norm "
|
"Safety clip on per-motor relative motion, passed through to "
|
||||||
"exceeds this value. Default ``None`` = no clipping. Useful "
|
"``RobotConfig.max_relative_target``. Accepts either a float "
|
||||||
"as a kill-switch when bringing up a new robot/task pair."
|
"(applied to every motor — e.g. ``5.0`` degrees) or a JSON "
|
||||||
|
"object mapping motor names to caps "
|
||||||
|
"(e.g. ``'{\"shoulder_pan\": 5, \"gripper\": 30}'``). The "
|
||||||
|
"robot driver clips each commanded position relative to the "
|
||||||
|
"current measured position before sending — same kill-switch "
|
||||||
|
"``lerobot-record`` uses. Default ``None`` = no clipping."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
@@ -213,7 +218,16 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
help="Pocket-tts voice name (or path to a .wav for cloning).",
|
help="Pocket-tts voice name (or path to a .wav for cloning).",
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--chunk_hz", type=float, default=4.0, help="Action-chunk generation rate."
|
"--chunk_hz",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help=(
|
||||||
|
"Action-chunk generation rate (Hz). Default ``1.0`` — one "
|
||||||
|
"new chunk per second. Lower = less inference cost / "
|
||||||
|
"smoother behaviour but longer reaction time to changes. "
|
||||||
|
"Higher = fresher actions / more inference cost; cap at "
|
||||||
|
"~1/(forward-pass latency)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
|
"--ctrl_hz", type=float, default=50.0, help="Action dispatch rate."
|
||||||
@@ -427,12 +441,16 @@ def _build_robot(
|
|||||||
robot_port: str | None,
|
robot_port: str | None,
|
||||||
robot_id: str | None,
|
robot_id: str | None,
|
||||||
robot_cameras_json: str | None,
|
robot_cameras_json: str | None,
|
||||||
|
robot_max_relative_target: str | None,
|
||||||
):
|
):
|
||||||
"""Build and connect a robot from CLI args.
|
"""Build and connect a robot from CLI args.
|
||||||
|
|
||||||
Mirrors how ``lerobot-record`` builds a robot but takes the args
|
Mirrors how ``lerobot-record`` builds a robot but takes the args
|
||||||
flat from argparse instead of through draccus, so the runtime
|
flat from argparse instead of through draccus, so the runtime
|
||||||
keeps its plain ``--key=value`` CLI surface.
|
keeps its plain ``--key=value`` CLI surface. ``max_relative_target``
|
||||||
|
is passed through to the RobotConfig — the driver itself clips each
|
||||||
|
commanded joint position relative to the current measured one
|
||||||
|
before issuing it on the bus.
|
||||||
"""
|
"""
|
||||||
import json # noqa: PLC0415
|
import json # noqa: PLC0415
|
||||||
|
|
||||||
@@ -454,6 +472,21 @@ def _build_robot(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"--robot.cameras must be a JSON object, got {robot_cameras_json!r}: {exc}"
|
f"--robot.cameras must be a JSON object, got {robot_cameras_json!r}: {exc}"
|
||||||
) from exc
|
) from exc
|
||||||
|
if robot_max_relative_target:
|
||||||
|
# Accept either a bare float (uniform cap) or a JSON object
|
||||||
|
# (per-motor cap). Matches ``RobotConfig.max_relative_target``'s
|
||||||
|
# ``float | dict[str, float] | None`` shape.
|
||||||
|
s = robot_max_relative_target.strip()
|
||||||
|
try:
|
||||||
|
if s.startswith("{"):
|
||||||
|
kwargs["max_relative_target"] = json.loads(s)
|
||||||
|
else:
|
||||||
|
kwargs["max_relative_target"] = float(s)
|
||||||
|
except (json.JSONDecodeError, ValueError) as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"--robot.max_relative_target must be a float or JSON dict, "
|
||||||
|
f"got {robot_max_relative_target!r}: {exc}"
|
||||||
|
) from exc
|
||||||
cfg = cls(**kwargs)
|
cfg = cls(**kwargs)
|
||||||
robot = make_robot_from_config(cfg)
|
robot = make_robot_from_config(cfg)
|
||||||
robot.connect()
|
robot.connect()
|
||||||
@@ -518,15 +551,15 @@ def _build_robot_action_executor(
|
|||||||
robot,
|
robot,
|
||||||
postprocessor: Any,
|
postprocessor: Any,
|
||||||
ds_meta: Any,
|
ds_meta: Any,
|
||||||
max_action_norm: float | None,
|
|
||||||
) -> Callable[[Any], None]:
|
) -> Callable[[Any], None]:
|
||||||
"""Closure that postprocesses an action and dispatches to the robot.
|
"""Closure that postprocesses an action and dispatches to the robot.
|
||||||
|
|
||||||
Mirrors ``lerobot-record``'s ``predict_action`` tail: postprocess
|
Mirrors ``lerobot-record``'s ``predict_action`` tail: postprocess
|
||||||
(denormalise) → ``make_robot_action`` (tensor → ``{joint: value}``
|
(denormalise) → ``make_robot_action`` (tensor → ``{joint: value}``
|
||||||
dict) → ``robot.send_action(...)``. Optional safety clip on the
|
dict) → ``robot.send_action(...)``. Safety clipping happens *inside*
|
||||||
action's L2 norm acts as a kill switch when bringing up a new
|
``robot.send_action`` via the driver's ``max_relative_target``
|
||||||
robot/task pair.
|
cap (passed in at ``RobotConfig`` construction time) — same place
|
||||||
|
``lerobot-record`` enforces it.
|
||||||
"""
|
"""
|
||||||
import torch # noqa: PLC0415
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
@@ -537,15 +570,6 @@ def _build_robot_action_executor(
|
|||||||
if postprocessor is not None:
|
if postprocessor is not None:
|
||||||
action = postprocessor(action)
|
action = postprocessor(action)
|
||||||
if isinstance(action, torch.Tensor):
|
if isinstance(action, torch.Tensor):
|
||||||
if max_action_norm is not None:
|
|
||||||
norm = float(action.float().norm().item())
|
|
||||||
if norm > max_action_norm:
|
|
||||||
logger.warning(
|
|
||||||
"action norm %.3f > max_action_norm=%.3f — "
|
|
||||||
"rejecting tick",
|
|
||||||
norm, max_action_norm,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
if action.ndim > 1 and action.shape[0] == 1:
|
if action.ndim > 1 and action.shape[0] == 1:
|
||||||
action = action.squeeze(0)
|
action = action.squeeze(0)
|
||||||
action_dict = make_robot_action(action, ds_meta.features)
|
action_dict = make_robot_action(action, ds_meta.features)
|
||||||
@@ -745,6 +769,7 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
robot_port=args.robot_port,
|
robot_port=args.robot_port,
|
||||||
robot_id=args.robot_id,
|
robot_id=args.robot_id,
|
||||||
robot_cameras_json=args.robot_cameras,
|
robot_cameras_json=args.robot_cameras,
|
||||||
|
robot_max_relative_target=args.robot_max_relative_target,
|
||||||
)
|
)
|
||||||
observation_provider = _build_robot_observation_provider(
|
observation_provider = _build_robot_observation_provider(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
@@ -756,7 +781,6 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
robot=robot,
|
robot=robot,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
ds_meta=ds_meta,
|
ds_meta=ds_meta,
|
||||||
max_action_norm=args.max_action_norm,
|
|
||||||
)
|
)
|
||||||
elif args.dataset_repo_id is not None:
|
elif args.dataset_repo_id is not None:
|
||||||
print(
|
print(
|
||||||
|
|||||||
Reference in New Issue
Block a user