From 622e63abc59ec2a973d60a37f2a5b7e356f8d21f Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Thu, 25 Jun 2026 17:16:42 +0200 Subject: [PATCH] test(dtype): fixing stats computation typing tests --- src/lerobot/datasets/compute_stats.py | 4 ++-- tests/datasets/test_compute_stats.py | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index 4b55750a6..88f7ea226 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -242,12 +242,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray: images = None for i, idx in enumerate(sampled_indices): path = image_paths[idx] - # we load RGB images as uint8 to reduce memory usage + # we load RGB images as uint8 to reduce memory usage; depth keeps its native dtype img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True) img = auto_downsample_height_width(img) if images is None: - images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) + images = np.empty((len(sampled_indices), *img.shape), dtype=img.dtype) images[i] = img diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 01af381d2..9f399b85c 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -35,9 +35,11 @@ from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def mock_load_image_as_numpy(path, dtype, channel_first): - channels = 1 if "depth" in str(path) else 3 - shape = (channels, 32, 32) if channel_first else (32, 32, channels) - return np.ones(shape, dtype=dtype) + is_depth = "depth" in str(path) + channels = 1 if is_depth else 3 + out_dtype = np.uint16 if is_depth else dtype + arr = np.arange(channels * 32 * 32, dtype=out_dtype).reshape(channels, 32, 32) + return arr if channel_first else arr.transpose(1, 2, 0) @pytest.fixture @@ -190,10 +192,13 @@ def test_compute_episode_stats(): assert stats[depth_key]["count"].item() == 100 assert stats[OBS_STATE]["count"].item() == 100 assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1) - # Depth keeps a single channel and is not rescaled by 255 (raw stored units). assert stats[depth_key]["mean"].shape == (1, 1, 1) - assert stats[depth_key]["mean"].item() == 1.0 - np.testing.assert_allclose(stats[OBS_IMAGE]["mean"], 1 / 255) + # Depth keeps raw values: max far exceeds 255, proving no /255 and no uint8 downcast. + assert stats[depth_key]["min"].item() == 0.0 + assert stats[depth_key]["max"].item() == 1023.0 + # RGB is normalized to [0, 1]. + np.testing.assert_allclose(stats[OBS_IMAGE]["min"], 0.0) + np.testing.assert_allclose(stats[OBS_IMAGE]["max"], 1.0) def test_assert_type_and_shape_valid(): @@ -650,8 +655,9 @@ def test_compute_episode_stats_with_image_data(): # Depth quantiles are single-channel and kept in raw (un-normalized) units. for q in ("q01", "q50", "q99"): assert stats["observation.images.depth"][q].shape == (1, 1, 1) - assert stats["observation.images.depth"]["q50"].item() == 1.0 - np.testing.assert_allclose(stats["observation.image"]["q50"], 1 / 255) + # Depth max stays in raw units (not /255, not uint8-capped); RGB is normalized. + assert stats["observation.images.depth"]["max"].item() == 1023.0 + np.testing.assert_allclose(stats["observation.image"]["max"], 1.0) # Action quantiles should have correct shape assert stats["action"]["q01"].shape == (5,)