mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 07:07:08 +00:00
test(depth stats): updating tests
This commit is contained in:
@@ -245,3 +245,44 @@ class TestFeatureFileRouting:
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
# ── 5. Depth stats unit canonicalization (millimetres) ────────────────
|
||||
|
||||
|
||||
class TestDepthStatsUnit:
    """Verify that recorded depth statistics are expressed in millimetres for both
    floating-point and integer depth frames."""

    NUM_FRAMES = 4

    @pytest.mark.parametrize("use_videos", [False, True])
    def test_stats_canonicalized_to_mm(self, tmp_path, features_factory, use_videos):
        """Record one episode with float32 depth and one with uint16 depth, then
        check that both produce the same millimetre-scale mean statistic.

        The test is run once with image storage and once with video storage
        (``use_videos``); streaming encoding is enabled only in the video case.
        """
        from lerobot.datasets.lerobot_dataset import LeRobotDataset

        def _depth_mean_for(depth_dtype, root):
            # Build a fresh single-episode dataset under ``root`` whose depth
            # frames use ``depth_dtype`` and return its flattened depth mean.
            create_kwargs = {
                "repo_id": DUMMY_REPO_ID,
                "fps": DEFAULT_FPS,
                "features": features_factory(
                    camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH,
                    use_videos=use_videos,
                ),
                "root": root,
                "use_videos": use_videos,
                "streaming_encoding": use_videos,
            }
            ds = LeRobotDataset.create(**create_kwargs)

            add_frames(ds, num_frames=self.NUM_FRAMES, depth_dtype=depth_dtype)
            ds.save_episode()
            ds.finalize()

            depth_mean = ds.meta.stats[DEPTH_KEY]["mean"]
            return np.asarray(depth_mean).reshape(-1)

        # add_frames generates float depth over 0.1-10 (metres) and integer depth
        # over 100-10000 (millimetres): the same physical range in two units, so
        # the stored statistics are expected to agree once canonicalized.
        metre_mean = _depth_mean_for(np.float32, tmp_path / "ds_m")
        millimetre_mean = _depth_mean_for(np.uint16, tmp_path / "ds_mm")

        # A mean still in metres would be a single-digit value; anything above 50
        # shows the float input was rescaled to millimetres.
        assert metre_mean.item() > 50.0
        np.testing.assert_allclose(metre_mean, millimetre_mean, rtol=0.05)
|
||||
|
||||
Vendored
+12
-7
@@ -49,16 +49,18 @@ from tests.fixtures.constants import (
|
||||
)
|
||||
|
||||
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
def add_frames(dataset: LeRobotDataset, num_frames: int, depth_dtype: np.dtype = np.uint16) -> None:
|
||||
"""Append ``num_frames`` synthetic frames to ``dataset``.
|
||||
|
||||
Generates per-feature payloads from ``dataset.meta``: uint16 depth ramps for
|
||||
keys in ``dataset.meta.depth_keys``, uint8 random noise for video/image keys,
|
||||
and float32 zeros for everything else. ``DEFAULT_FEATURES`` (timestamp,
|
||||
frame_index, ...) are auto-populated by ``add_frame`` and skipped here.
|
||||
Generates per-feature payloads from ``dataset.meta``: depth ramps (``depth_dtype``,
|
||||
default ``uint16`` millimetres; pass ``np.float32`` for metres) for keys in
|
||||
``dataset.meta.depth_keys``, uint8 random noise for video/image keys, and float32
|
||||
zeros for everything else. ``DEFAULT_FEATURES`` (timestamp, frame_index, ...) are
|
||||
auto-populated by ``add_frame`` and skipped here.
|
||||
"""
|
||||
video_keys = dataset.meta.video_keys
|
||||
depth_keys = dataset.meta.depth_keys
|
||||
depth_is_float = np.issubdtype(depth_dtype, np.floating)
|
||||
# Smooth gradient base reused per (H, W) to keep depth frames cheap to
|
||||
# encode (HEVC Main 12 hates white noise).
|
||||
_depth_base_cache: dict[tuple[int, int], np.ndarray] = {}
|
||||
@@ -70,11 +72,14 @@ def add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
shape = ft["shape"]
|
||||
if key in depth_keys:
|
||||
h, w, _ = shape
|
||||
# Float depth is expressed in metres, integer depth in millimetres.
|
||||
lo, hi = (0.1, 10.0) if depth_is_float else (100.0, 10_000.0)
|
||||
base = _depth_base_cache.setdefault(
|
||||
(h, w),
|
||||
np.linspace(100.0, 10_000.0, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
np.linspace(lo, hi, h * w, dtype=np.float32).reshape(h, w, 1),
|
||||
)
|
||||
frame[key] = (base + 50.0 * i).clip(0, 65535).astype(np.uint16)
|
||||
step = (0.05 if depth_is_float else 50.0) * i
|
||||
frame[key] = (base + step).clip(0, 65535).astype(depth_dtype)
|
||||
elif key in video_keys:
|
||||
frame[key] = np.random.randint(0, 256, shape, dtype=np.uint8)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user