fix(stats): fixing stats computation for depth frames

This commit is contained in:
CarolinePascal
2026-06-25 16:39:37 +02:00
parent a42596b917
commit 0e39bae335
2 changed files with 36 additions and 15 deletions
+12 -6
View File
@@ -226,19 +226,25 @@ def load_image_as_numpy(
Args:
fpath (str | Path): Path to the image file.
dtype (np.dtype): The desired data type of the output array. If floating,
pixels are scaled to [0, 1].
pixels are scaled to [0, 1]. Only used for RGB images.
channel_first (bool): If True, converts the image to (C, H, W) format.
Otherwise, it remains in (H, W, C) format.
Returns:
np.ndarray: The image as a numpy array.
"""
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
is_depth = fpath.endswith(".tiff") or fpath.endswith(".tif")
if is_depth:
# Preserve the native depth dtype (uint16 -> "I;16", float32 -> "F").
img = PILImage.open(fpath)
img_array = np.array(img)
else:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
img_array = img_array[np.newaxis, ...] if img_array.ndim == 2 else np.transpose(img_array, (2, 0, 1))
return img_array
+24 -9
View File
@@ -35,7 +35,9 @@ from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
def mock_load_image_as_numpy(path, dtype, channel_first):
    """Stand-in for ``load_image_as_numpy`` that returns an all-ones 32x32 image.

    Args:
        path: Image path. Only inspected for the substring ``"depth"``; nothing is read from disk.
        dtype: NumPy dtype of the returned array.
        channel_first: If True the result is laid out as (C, 32, 32), otherwise as (32, 32, C).

    Returns:
        np.ndarray: Array of ones with 1 channel when ``"depth"`` appears in ``path`` and
        3 channels otherwise.
    """
    # Paths containing "depth" get a single channel; every other path gets three (RGB).
    channels = 1 if "depth" in str(path) else 3
    shape = (channels, 32, 32) if channel_first else (32, 32, channels)
    return np.ones(shape, dtype=dtype)
@pytest.fixture
@@ -168,22 +170,30 @@ def test_get_feature_stats_single_value():
def test_compute_episode_stats():
    """Check per-feature counts, mean shapes and mean values for RGB, depth and state inputs."""
    depth_key = "observation.images.depth"
    n_frames = 100

    episode_data = {
        OBS_IMAGE: [f"image_{i}.jpg" for i in range(n_frames)],
        depth_key: [f"depth_{i}.tiff" for i in range(n_frames)],
        OBS_STATE: np.random.rand(n_frames, 10),
    }
    features = {
        OBS_IMAGE: {"dtype": "image"},
        depth_key: {"dtype": "image", "info": {"is_depth_map": True}},
        OBS_STATE: {"dtype": "numeric"},
    }

    # Image loading is patched with the all-ones mock so no file is read from disk.
    with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
        stats = compute_episode_stats(episode_data, features)

    # Every feature must be reported, each over the full set of frames.
    for key in (OBS_IMAGE, depth_key, OBS_STATE):
        assert key in stats
        assert stats[key]["count"].item() == n_frames

    # RGB: one mean per channel, expected at 1/255 for all-ones input.
    rgb_mean = stats[OBS_IMAGE]["mean"]
    assert rgb_mean.shape == (3, 1, 1)
    np.testing.assert_allclose(rgb_mean, 1 / 255)

    # Depth keeps a single channel and is not rescaled by 255 (raw stored units).
    depth_mean = stats[depth_key]["mean"]
    assert depth_mean.shape == (1, 1, 1)
    assert depth_mean.item() == 1.0
def test_assert_type_and_shape_valid():
@@ -618,25 +628,30 @@ def test_compute_episode_stats_with_custom_quantiles():
def test_compute_episode_stats_with_image_data():
"""Test quantile computation with image features."""
image_paths = [f"image_{i}.jpg" for i in range(50)]
depth_paths = [f"depth_{i}.tiff" for i in range(50)]
episode_data = {
"observation.image": image_paths,
"observation.images.depth": depth_paths,
"action": np.random.normal(0, 1, (50, 5)),
}
features = {
"observation.image": {"dtype": "image"},
"observation.images.depth": {"dtype": "image", "info": {"is_depth_map": True}},
"action": {"dtype": "float32", "shape": (5,)},
}
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features)
# Image quantiles should be normalized and have correct shape
assert "q01" in stats["observation.image"]
assert "q50" in stats["observation.image"]
assert "q99" in stats["observation.image"]
assert stats["observation.image"]["q01"].shape == (3, 1, 1)
assert stats["observation.image"]["q50"].shape == (3, 1, 1)
assert stats["observation.image"]["q99"].shape == (3, 1, 1)
# RGB image quantiles should be normalized and per-channel.
for q in ("q01", "q50", "q99"):
assert stats["observation.image"][q].shape == (3, 1, 1)
# Depth quantiles are single-channel and kept in raw (un-normalized) units.
for q in ("q01", "q50", "q99"):
assert stats["observation.images.depth"][q].shape == (1, 1, 1)
assert stats["observation.images.depth"]["q50"].item() == 1.0
np.testing.assert_allclose(stats["observation.image"]["q50"], 1 / 255)
# Action quantiles should have correct shape
assert stats["action"]["q01"].shape == (5,)