mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 07:07:08 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 30b79a56f7 | |||
| 036d99bf74 | |||
| eba1d1bd0c |
@@ -134,6 +134,9 @@ lerobot-train \
|
||||
> [!TIP]
|
||||
> This is purely a decode-time presentation choice — it does **not** alter the stored video or its metadata, so the same dataset can be read as `mm` or `m` without re-encoding. It has no effect on datasets without depth cameras.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Depth statistics in `meta/stats.json` are always computed in **millimetres**, regardless of the raw frame dtype.
|
||||
|
||||
---
|
||||
|
||||
## Persistence in dataset metadata
|
||||
|
||||
@@ -22,6 +22,7 @@ import numpy as np
|
||||
from lerobot.processor import RelativeActionsProcessorStep
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from .depth_utils import MM_PER_METRE
|
||||
from .io_utils import load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
@@ -508,8 +509,8 @@ def compute_episode_stats(
|
||||
Note:
|
||||
For 'image'/'video' features, stats are computed per channel and kept with a
|
||||
leading channel axis (e.g. shape (3, 1, 1) for RGB). RGB stats are divided by
|
||||
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) skip
|
||||
this rescaling and remain in their stored units.
|
||||
255 to land in [0, 1]; depth maps (features flagged with ``is_depth_map``) are
|
||||
instead canonicalized to millimetres regardless of the raw frame unit.
|
||||
"""
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
@@ -533,9 +534,14 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = (
|
||||
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
)
|
||||
if (features[key].get("info") or {}).get("is_depth_map", False):
|
||||
# Depth stats are canonically stored in millimetres; metre (float) depth is
|
||||
# scaled up, integer (millimetre) depth is left as-is.
|
||||
normalization_factor = (
|
||||
1.0 / MM_PER_METRE if np.issubdtype(ep_ft_array.dtype, np.floating) else 1.0
|
||||
)
|
||||
else:
|
||||
normalization_factor = 255.0
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
|
||||
@@ -39,7 +39,7 @@ from lerobot.configs.video import (
|
||||
from .image_writer import squeeze_single_channel
|
||||
from .pyav_utils import write_u16_plane
|
||||
|
||||
_MM_PER_METRE = 1000.0
|
||||
MM_PER_METRE = 1000.0
|
||||
_UINT16_MAX = 65535
|
||||
|
||||
|
||||
@@ -126,12 +126,12 @@ def quantize_depth(
|
||||
|
||||
# Convert depth_min, depth_max, and shift to the resolved input unit.
|
||||
depth_min_u = (
|
||||
np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * _MM_PER_METRE)
|
||||
np.float32(depth_min) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_min * MM_PER_METRE)
|
||||
)
|
||||
depth_max_u = (
|
||||
np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * _MM_PER_METRE)
|
||||
np.float32(depth_max) if resolved_unit == DEPTH_METER_UNIT else np.float32(depth_max * MM_PER_METRE)
|
||||
)
|
||||
shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == DEPTH_METER_UNIT else np.float32(shift * MM_PER_METRE)
|
||||
|
||||
# Normalization and quantization is performed in the resolved input unit.
|
||||
if use_log:
|
||||
@@ -236,7 +236,7 @@ def dequantize_depth(
|
||||
|
||||
# mm path: round + clamp in float32, skipping the uint16 round-trip
|
||||
# when returning a tensor (torch.uint16 is poorly supported).
|
||||
buf.mul_(_MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
buf.mul_(MM_PER_METRE).round_().clamp_(0.0, _UINT16_MAX)
|
||||
if output_tensor:
|
||||
return buf
|
||||
return buf.cpu().numpy().astype(np.uint16, copy=False)
|
||||
@@ -259,7 +259,7 @@ def dequantize_depth(
|
||||
if output_unit == DEPTH_METER_UNIT:
|
||||
return torch.from_numpy(buf) if output_tensor else buf
|
||||
|
||||
np.multiply(buf, _MM_PER_METRE, out=buf)
|
||||
np.multiply(buf, MM_PER_METRE, out=buf)
|
||||
np.rint(buf, out=buf)
|
||||
np.clip(buf, 0.0, _UINT16_MAX, out=buf)
|
||||
if output_tensor:
|
||||
|
||||
@@ -47,7 +47,7 @@ from lerobot.configs import (
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
from .depth_utils import quantize_depth
|
||||
from .depth_utils import MM_PER_METRE, quantize_depth
|
||||
from .pyav_utils import get_pix_fmt_channels
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -848,6 +848,9 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
# Depth stats are canonically stored in millimetres; metre (float) depth is scaled up.
|
||||
if self.is_depth and np.issubdtype(frame_data.dtype, np.floating):
|
||||
img_for_stats = img_for_stats * MM_PER_METRE
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import builtins
|
||||
import dataclasses
|
||||
@@ -21,7 +19,7 @@ import os
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import TYPE_CHECKING, TypedDict, TypeVar, Unpack
|
||||
from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
@@ -40,13 +38,10 @@ from .utils import log_model_loading_keys
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
|
||||
def _build_card_context(
|
||||
cfg: TrainPipelineConfig | None,
|
||||
dataset_meta: LeRobotDatasetMetadata | None,
|
||||
dataset_repo_id: str | None,
|
||||
input_features: dict | None,
|
||||
output_features: dict | None,
|
||||
) -> dict:
|
||||
@@ -77,16 +72,30 @@ def _build_card_context(
|
||||
"lerobot_version": __version__,
|
||||
}
|
||||
|
||||
if dataset_meta is not None:
|
||||
context["dataset"] = {
|
||||
"repo_id": dataset_meta.repo_id,
|
||||
"episodes": dataset_meta.total_episodes,
|
||||
"frames": dataset_meta.total_frames,
|
||||
"fps": dataset_meta.fps,
|
||||
"tasks": [str(task) for task in dataset_meta.tasks.index],
|
||||
}
|
||||
context["robot_type"] = dataset_meta.robot_type
|
||||
context["cameras"] = [key.split(".")[-1] for key in dataset_meta.camera_keys]
|
||||
if dataset_repo_id:
|
||||
dataset_cfg = getattr(cfg, "dataset", None)
|
||||
try:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
|
||||
meta = LeRobotDatasetMetadata(
|
||||
dataset_repo_id,
|
||||
root=getattr(dataset_cfg, "root", None),
|
||||
revision=getattr(dataset_cfg, "revision", None),
|
||||
)
|
||||
context["dataset"] = {
|
||||
"repo_id": dataset_repo_id,
|
||||
"episodes": meta.total_episodes,
|
||||
"frames": meta.total_frames,
|
||||
"fps": meta.fps,
|
||||
"tasks": [str(task) for task in meta.tasks.index],
|
||||
}
|
||||
context["robot_type"] = meta.robot_type
|
||||
context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys]
|
||||
except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push
|
||||
logging.warning(
|
||||
f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be "
|
||||
f"omitted from the model card. ({e})"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@@ -295,7 +304,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
cfg: TrainPipelineConfig,
|
||||
peft_model=None,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
dataset_meta: LeRobotDatasetMetadata | None = None,
|
||||
):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
@@ -317,12 +325,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self.save_pretrained(saved_path, state_dict=state_dict)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id,
|
||||
self.config.type,
|
||||
self.config.license,
|
||||
self.config.tags,
|
||||
cfg=cfg,
|
||||
dataset_meta=dataset_meta,
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
)
|
||||
card.save(str(saved_path / "README.md"))
|
||||
|
||||
@@ -349,7 +352,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
license: str | None,
|
||||
tags: list[str] | None,
|
||||
cfg: TrainPipelineConfig | None = None,
|
||||
dataset_meta: LeRobotDatasetMetadata | None = None,
|
||||
) -> ModelCard:
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
@@ -370,7 +372,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
)
|
||||
|
||||
context = _build_card_context(
|
||||
cfg, dataset_meta, self.config.input_features, self.config.output_features
|
||||
cfg, dataset_repo_id, self.config.input_features, self.config.output_features
|
||||
)
|
||||
# Used by the template to pre-fill commands and the "Fine-tuned from" line.
|
||||
context["policy_repo_id"] = getattr(self.config, "repo_id", None)
|
||||
@@ -387,7 +389,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self,
|
||||
peft_config=None,
|
||||
peft_cli_overrides: dict | None = None,
|
||||
) -> PreTrainedPolicy:
|
||||
) -> "PreTrainedPolicy":
|
||||
"""
|
||||
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
|
||||
|
||||
|
||||
@@ -65,13 +65,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
cameras=left_arm_cameras,
|
||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||
control_mode=config.left_arm_config.control_mode,
|
||||
mit_kp=config.left_arm_config.mit_kp,
|
||||
mit_kd=config.left_arm_config.mit_kd,
|
||||
gripper_control_mode=config.left_arm_config.gripper_control_mode,
|
||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
|
||||
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
|
||||
joint_limits=config.left_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
@@ -86,13 +80,7 @@ class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
cameras=config.right_arm_config.cameras,
|
||||
motor_can_ids=config.right_arm_config.motor_can_ids,
|
||||
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
|
||||
control_mode=config.right_arm_config.control_mode,
|
||||
mit_kp=config.right_arm_config.mit_kp,
|
||||
mit_kd=config.right_arm_config.mit_kd,
|
||||
gripper_control_mode=config.right_arm_config.gripper_control_mode,
|
||||
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
|
||||
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
|
||||
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
|
||||
joint_limits=config.right_arm_config.joint_limits,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,33 +65,18 @@ class RebotB601FollowerConfig:
|
||||
}
|
||||
)
|
||||
|
||||
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
|
||||
pos_vel_velocity: float | list[float] = field(
|
||||
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 900.0]
|
||||
)
|
||||
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
|
||||
# applied to every joint; a list provides one value per joint (in motor order).
|
||||
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
|
||||
|
||||
# Arm control: "mit" or "pos_vel".
|
||||
control_mode: str = "mit"
|
||||
|
||||
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
|
||||
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
|
||||
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
# Gripper control: "force_pos" or "mit".
|
||||
gripper_control_mode: str = "force_pos"
|
||||
|
||||
# FORCE_POS only: max grip force, in [0, 1].
|
||||
gripper_torque_ratio: float = 0.07
|
||||
|
||||
# MIT only.
|
||||
gripper_mit_kp: float = 8.0
|
||||
gripper_mit_kd: float = 0.3
|
||||
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
|
||||
gripper_torque_ratio: float = 0.1
|
||||
|
||||
# Soft joint limits (degrees). These are clipped against on every action.
|
||||
joint_limits: dict[str, tuple[float, float]] = field(
|
||||
default_factory=lambda: {
|
||||
"shoulder_pan": (-150.0, 150.0),
|
||||
"shoulder_lift": (-200.0, 1.0),
|
||||
"shoulder_pan": (-145.0, 145.0),
|
||||
"shoulder_lift": (-170.0, 1.0),
|
||||
"elbow_flex": (-200.0, 1.0),
|
||||
"wrist_flex": (-80.0, 90.0),
|
||||
"wrist_yaw": (-90.0, 90.0),
|
||||
|
||||
@@ -174,25 +174,11 @@ class RebotB601Follower(Robot):
|
||||
print(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
if self.config.control_mode not in ("pos_vel", "mit"):
|
||||
raise ValueError(
|
||||
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
|
||||
)
|
||||
if self.config.gripper_control_mode not in ("force_pos", "mit"):
|
||||
raise ValueError(
|
||||
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
|
||||
"Use 'force_pos' or 'mit'."
|
||||
)
|
||||
use_mit = self.config.control_mode == "mit"
|
||||
gripper_use_mit = self.config.gripper_control_mode == "mit"
|
||||
self.bus.enable_all()
|
||||
for motor_name, motor in self.motors.items():
|
||||
if motor_name == GRIPPER_MOTOR:
|
||||
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
|
||||
elif use_mit:
|
||||
target_mode = MotorBridgeMode.MIT
|
||||
else:
|
||||
target_mode = MotorBridgeMode.POS_VEL
|
||||
target_mode = (
|
||||
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
|
||||
)
|
||||
for attempt in range(_ENSURE_MODE_RETRIES + 1):
|
||||
try:
|
||||
motor.ensure_mode(target_mode)
|
||||
@@ -278,34 +264,22 @@ class RebotB601Follower(Robot):
|
||||
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
use_mit = self.config.control_mode == "mit"
|
||||
for motor_name, position_deg in goal_pos.items():
|
||||
motor = self.motors.get(motor_name)
|
||||
if motor is None:
|
||||
continue
|
||||
idx = self.motor_names.index(motor_name)
|
||||
vel_deg_s = (
|
||||
self.config.pos_vel_velocity[idx]
|
||||
if isinstance(self.config.pos_vel_velocity, list)
|
||||
else self.config.pos_vel_velocity
|
||||
)
|
||||
pos_rad = math.radians(position_deg)
|
||||
vel_rad = math.radians(vel_deg_s)
|
||||
if motor_name == GRIPPER_MOTOR:
|
||||
if self.config.gripper_control_mode == "mit":
|
||||
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
|
||||
else:
|
||||
vel_deg_s = (
|
||||
self.config.pos_vel_velocity[idx]
|
||||
if isinstance(self.config.pos_vel_velocity, list)
|
||||
else self.config.pos_vel_velocity
|
||||
)
|
||||
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
|
||||
elif use_mit:
|
||||
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
|
||||
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
|
||||
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
|
||||
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
|
||||
else:
|
||||
vel_deg_s = (
|
||||
self.config.pos_vel_velocity[idx]
|
||||
if isinstance(self.config.pos_vel_velocity, list)
|
||||
else self.config.pos_vel_velocity
|
||||
)
|
||||
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
|
||||
motor.send_pos_vel(pos_rad, vel_rad)
|
||||
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
|
||||
@@ -736,9 +736,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
unwrapped_model = accelerator.unwrap_model(policy)
|
||||
# PEFT only applies when training a policy — reward models use the plain path.
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model, dataset_meta=dataset.meta)
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
else:
|
||||
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict, dataset_meta=dataset.meta)
|
||||
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
|
||||
joint_ranges: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"shoulder_pan": [-150, 150],
|
||||
"shoulder_lift": [-200, 1],
|
||||
"shoulder_lift": [-170, 1],
|
||||
"elbow_flex": [-200, 1],
|
||||
"wrist_flex": [-80, 90],
|
||||
"wrist_yaw": [-90, 90],
|
||||
|
||||
@@ -245,3 +245,44 @@ class TestFeatureFileRouting:
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
# ── 5. Depth stats unit canonicalization (millimetres) ────────────────
|
||||
|
||||
|
||||
class TestDepthStatsUnit:
|
||||
"""Depth stats are always stored in millimetres, regardless of raw frame dtype."""
|
||||
|
||||
NUM_FRAMES = 4
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [False, True])
|
||||
def test_stats_canonicalized_to_mm(self, tmp_path, features_factory, use_videos):
|
||||
"""Float (metre) and integer (millimetre) depth over the same physical range
|
||||
yield identical millimetre-scale stats."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
def _record(depth_dtype, root):
|
||||
features = features_factory(
|
||||
camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos
|
||||
)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=use_videos,
|
||||
)
|
||||
add_frames(dataset, num_frames=self.NUM_FRAMES, depth_dtype=depth_dtype)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
return np.asarray(dataset.meta.stats[DEPTH_KEY]["mean"]).reshape(-1)
|
||||
|
||||
# add_frames ramps float depth over 0.1–10 m and integer depth over 100–10000 mm
|
||||
# (the same physical range), so canonicalized stats must match.
|
||||
mean_m = _record(np.float32, tmp_path / "ds_m")
|
||||
mean_mm = _record(np.uint16, tmp_path / "ds_mm")
|
||||
|
||||
# Float (metre) input is scaled to millimetres, not left in the single-digit metre range.
|
||||
assert mean_m.item() > 50.0
|
||||
np.testing.assert_allclose(mean_m, mean_mm, rtol=0.05)
|
||||
|
||||
Vendored
+12
-7
@@ -49,16 +49,18 @@ from tests.fixtures.constants import (
|
||||
)
|
||||
|
||||
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int, depth_dtype: np.dtype = np.uint16) -> None:
|
||||
"""Append ``num_frames`` synthetic frames to ``dataset``.
|
||||
|
||||
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
|
||||
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
|
||||
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
|
||||
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
|
||||
Generates per-feature payloads from ``dataset.meta``: depth ramps (``depth_dtype``,
|
||||
default ``uint16`` millimetres; pass ``np.float32`` for metres) for keys in
|
||||
``dataset.meta.depth_keys``, uint8 random noise for video/image keys, and float32
|
||||
zeros for everything else. ``DEFAULT_FEATURES`` (timestamp, frame_index, ...) are
|
||||
auto-populated by ``add_frame`` and skipped here.
|
||||
"""
|
||||
video_keys = dataset.meta.video_keys
|
||||
depth_keys = dataset.meta.depth_keys
|
||||
depth_is_float = np.issubdtype(depth_dtype, np.floating)
|
||||
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
|
||||
# encode (HEVC Main 12 hates white noise).
|
||||
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
|
||||
@@ -70,11 +72,14 @@ def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
shape = ft["shape"]
|
||||
if key in depth_keys:
|
||||
h, w, _ = shape
|
||||
# Float depth is expressed in metres, integer depth in millimetres.
|
||||
lo, hi = (0.1, 10.0) if depth_is_float else (100.0, 10_000.0)
|
||||
base = _depth_base_cache.setdefault(
|
||||
(h, w),
|
||||
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
np.linspace(lo, hi, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
)
|
||||
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
step = (0.05 if depth_is_float else 50.0) * i
|
||||
frame[key] = (base + step).clip(0, 65535).astype(depth_dtype)
|
||||
elif key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
|
||||
@@ -91,11 +91,10 @@ def test_get_observation_converts_to_degrees(follower):
|
||||
|
||||
|
||||
def test_send_action_clips_to_joint_limits(follower):
|
||||
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
|
||||
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
|
||||
returned = follower.send_action({"shoulder_pan.pos": 999.0})
|
||||
assert returned["shoulder_pan.pos"] == 150.0
|
||||
# Default control_mode is "mit", so arm joints are driven via send_mit.
|
||||
follower.motors["shoulder_pan"].send_mit.assert_called_once()
|
||||
assert returned["shoulder_pan.pos"] == 145.0
|
||||
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
|
||||
|
||||
|
||||
def test_send_action_routes_gripper_to_force_pos(follower):
|
||||
@@ -104,22 +103,6 @@ def test_send_action_routes_gripper_to_force_pos(follower):
|
||||
follower.motors["gripper"].send_pos_vel.assert_not_called()
|
||||
|
||||
|
||||
def test_gripper_mit_mode_routes_to_send_mit():
|
||||
bus_mock = _make_bus_mock()
|
||||
with (
|
||||
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
|
||||
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
|
||||
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
|
||||
):
|
||||
controller_cls.from_dm_serial.return_value = bus_mock
|
||||
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
|
||||
robot = RebotB601Follower(cfg)
|
||||
robot.connect(calibrate=False)
|
||||
robot.send_action({"gripper.pos": -10.0})
|
||||
robot.motors["gripper"].send_mit.assert_called_once()
|
||||
robot.motors["gripper"].send_force_pos.assert_not_called()
|
||||
|
||||
|
||||
def test_bimanual_prefixes_features():
|
||||
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
||||
cfg = BiRebotB601FollowerConfig(
|
||||
|
||||
Reference in New Issue
Block a user