mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
tests(unit): adapting and extending depth tests to units manipulations
This commit is contained in:
@@ -32,6 +32,7 @@ from lerobot.configs.video import (
|
||||
)
|
||||
from lerobot.datasets.depth_utils import dequantize_depth, quantize_depth
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
@@ -245,3 +246,83 @@ class TestFeatureFileRouting:
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
|
||||
class TestDepthUnitMetadata:
|
||||
"""The depth unit is inferred once from dtype, stored in ``info``, and drives stats + reads."""
|
||||
|
||||
NUM_FRAMES = 4
|
||||
|
||||
def _record(self, root, features_factory, depth_dtype, value, use_videos):
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = features_factory(camera_features=DUMMY_CAMERA_FEATURES_WITH_DEPTH, use_videos=use_videos)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
streaming_encoding=use_videos,
|
||||
)
|
||||
for _ in range(self.NUM_FRAMES):
|
||||
frame: dict = {"task": "test"}
|
||||
for key, ft in dataset.meta.features.items():
|
||||
if key in DEFAULT_FEATURES:
|
||||
continue
|
||||
if key in dataset.meta.depth_keys:
|
||||
frame[key] = np.full(ft["shape"], value, dtype=depth_dtype)
|
||||
elif key in dataset.meta.camera_keys:
|
||||
frame[key] = np.random.randint(0, 256, ft["shape"], dtype=np.uint8)
|
||||
else:
|
||||
frame[key] = np.zeros(ft["shape"], dtype=np.float32)
|
||||
dataset.add_frame(frame)
|
||||
return dataset
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
("depth_dtype", "value", "expected_unit"),
|
||||
[(np.float32, 2.0, DEPTH_METER_UNIT), (np.uint16, 2000, DEPTH_MILLIMETER_UNIT)],
|
||||
)
|
||||
def test_recorded_unit_inferred_persisted_and_kept_in_stats(
|
||||
self, tmp_path, features_factory, use_videos, depth_dtype, value, expected_unit
|
||||
):
|
||||
"""Unit is inferred from the first frame's dtype, drives stats (raw, never canonicalized), and survives a reload."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = self._record(tmp_path / "ds", features_factory, depth_dtype, value, use_videos)
|
||||
assert dataset.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
|
||||
dataset.save_episode()
|
||||
mean = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["mean"]).reshape(-1)[0])
|
||||
np.testing.assert_allclose(mean, value, rtol=0.05)
|
||||
dataset.finalize()
|
||||
|
||||
reloaded = LeRobotDataset(repo_id=DUMMY_REPO_ID, root=tmp_path / "ds")
|
||||
assert reloaded.meta.features[DEPTH_KEY]["info"]["depth_unit"] == expected_unit
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
("output_unit", "expected"),
|
||||
[(DEPTH_MILLIMETER_UNIT, 2000.0), (DEPTH_METER_UNIT, 2.0)],
|
||||
)
|
||||
def test_read_honors_output_unit_for_frames_and_stats(
|
||||
self, tmp_path, features_factory, use_videos, output_unit, expected
|
||||
):
|
||||
"""Reloading with a ``depth_output_unit`` converts metre frames (image mode) and rescales stats while preserving count."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = self._record(tmp_path / "ds", features_factory, np.float32, 2.0, use_videos=use_videos)
|
||||
dataset.save_episode()
|
||||
count = float(np.asarray(dataset.meta.stats[DEPTH_KEY]["count"]).reshape(-1)[0])
|
||||
dataset.finalize()
|
||||
|
||||
read_dataset = LeRobotDataset(
|
||||
repo_id=DUMMY_REPO_ID, root=tmp_path / "ds", depth_output_unit=output_unit
|
||||
)
|
||||
stats = read_dataset.meta.stats[DEPTH_KEY]
|
||||
np.testing.assert_allclose(float(np.asarray(stats["mean"]).reshape(-1)[0]), expected, rtol=0.05)
|
||||
np.testing.assert_allclose(float(np.asarray(stats["count"]).reshape(-1)[0]), count)
|
||||
|
||||
if not use_videos:
|
||||
depth = read_dataset[0][DEPTH_KEY]
|
||||
assert torch.allclose(depth, torch.full_like(depth, expected))
|
||||
|
||||
Vendored
+8
@@ -27,6 +27,7 @@ import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.depth_utils import infer_depth_unit
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import flatten_dict, hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -535,6 +536,13 @@ def lerobot_dataset_factory(
|
||||
chunks_size=chunks_size,
|
||||
**info_kwargs,
|
||||
)
|
||||
# This synthetic path skips add_frame, so record the depth unit the writer would
|
||||
# have stored (dummy depth is uint16) to keep ``depth_unit`` present in info.json.
|
||||
# Reassign a fresh info dict to avoid mutating the shared feature constants.
|
||||
for ft in info.features.values():
|
||||
ft_info = ft.get("info")
|
||||
if ft_info is not None and ft_info.get("is_depth_map") and "depth_unit" not in ft_info:
|
||||
ft["info"] = {**ft_info, "depth_unit": infer_depth_unit(np.uint16)}
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info.features)
|
||||
if tasks is None:
|
||||
|
||||
Reference in New Issue
Block a user