test(dtype): fixing stats computation typing tests

This commit is contained in:
CarolinePascal
2026-06-25 17:16:42 +02:00
parent 125819681a
commit 622e63abc5
2 changed files with 16 additions and 10 deletions
+2 -2
View File
@@ -242,12 +242,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
images = None
for i, idx in enumerate(sampled_indices):
path = image_paths[idx]
# we load RGB images as uint8 to reduce memory usage
# we load RGB images as uint8 to reduce memory usage; depth keeps its native dtype
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images = np.empty((len(sampled_indices), *img.shape), dtype=img.dtype)
images[i] = img
+14 -8
View File
@@ -35,9 +35,11 @@ from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def mock_load_image_as_numpy(path, dtype, channel_first):
channels = 1 if "depth" in str(path) else 3
shape = (channels, 32, 32) if channel_first else (32, 32, channels)
return np.ones(shape, dtype=dtype)
is_depth = "depth" in str(path)
channels = 1 if is_depth else 3
out_dtype = np.uint16 if is_depth else dtype
arr = np.arange(channels * 32 * 32, dtype=out_dtype).reshape(channels, 32, 32)
return arr if channel_first else arr.transpose(1, 2, 0)
@pytest.fixture
@@ -190,10 +192,13 @@ def test_compute_episode_stats():
assert stats[depth_key]["count"].item() == 100
assert stats[OBS_STATE]["count"].item() == 100
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
# Depth keeps a single channel and is not rescaled by 255 (raw stored units).
assert stats[depth_key]["mean"].shape == (1, 1, 1)
assert stats[depth_key]["mean"].item() == 1.0
np.testing.assert_allclose(stats[OBS_IMAGE]["mean"], 1 / 255)
# Depth keeps raw values: max far exceeds 255, proving no /255 and no uint8 downcast.
assert stats[depth_key]["min"].item() == 0.0
assert stats[depth_key]["max"].item() == 1023.0
# RGB is normalized to [0, 1].
np.testing.assert_allclose(stats[OBS_IMAGE]["min"], 0.0)
np.testing.assert_allclose(stats[OBS_IMAGE]["max"], 1.0)
def test_assert_type_and_shape_valid():
@@ -650,8 +655,9 @@ def test_compute_episode_stats_with_image_data():
# Depth quantiles are single-channel and kept in raw (un-normalized) units.
for q in ("q01", "q50", "q99"):
assert stats["observation.images.depth"][q].shape == (1, 1, 1)
assert stats["observation.images.depth"]["q50"].item() == 1.0
np.testing.assert_allclose(stats["observation.image"]["q50"], 1 / 255)
# Depth max stays in raw units (not /255, not uint8-capped); RGB is normalized.
assert stats["observation.images.depth"]["max"].item() == 1023.0
np.testing.assert_allclose(stats["observation.image"]["max"], 1.0)
# Action quantiles should have correct shape
assert stats["action"]["q01"].shape == (5,)