mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 12:47:18 +00:00
test(dtype): fixing stats computation typing tests
This commit is contained in:
@@ -242,12 +242,12 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
path = image_paths[idx]
|
||||
# we load RGB images as uint8 to reduce memory usage
|
||||
# we load RGB images as uint8 to reduce memory usage; depth keeps its native dtype
|
||||
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
|
||||
img = auto_downsample_height_width(img)
|
||||
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=img.dtype)
|
||||
|
||||
images[i] = img
|
||||
|
||||
|
||||
@@ -35,9 +35,11 @@ from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
channels = 1 if "depth" in str(path) else 3
|
||||
shape = (channels, 32, 32) if channel_first else (32, 32, channels)
|
||||
return np.ones(shape, dtype=dtype)
|
||||
is_depth = "depth" in str(path)
|
||||
channels = 1 if is_depth else 3
|
||||
out_dtype = np.uint16 if is_depth else dtype
|
||||
arr = np.arange(channels * 32 * 32, dtype=out_dtype).reshape(channels, 32, 32)
|
||||
return arr if channel_first else arr.transpose(1, 2, 0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -190,10 +192,13 @@ def test_compute_episode_stats():
|
||||
assert stats[depth_key]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
# Depth keeps a single channel and is not rescaled by 255 (raw stored units).
|
||||
assert stats[depth_key]["mean"].shape == (1, 1, 1)
|
||||
assert stats[depth_key]["mean"].item() == 1.0
|
||||
np.testing.assert_allclose(stats[OBS_IMAGE]["mean"], 1 / 255)
|
||||
# Depth keeps raw values: max far exceeds 255, proving no /255 and no uint8 downcast.
|
||||
assert stats[depth_key]["min"].item() == 0.0
|
||||
assert stats[depth_key]["max"].item() == 1023.0
|
||||
# RGB is normalized to [0, 1].
|
||||
np.testing.assert_allclose(stats[OBS_IMAGE]["min"], 0.0)
|
||||
np.testing.assert_allclose(stats[OBS_IMAGE]["max"], 1.0)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
@@ -650,8 +655,9 @@ def test_compute_episode_stats_with_image_data():
|
||||
# Depth quantiles are single-channel and kept in raw (un-normalized) units.
|
||||
for q in ("q01", "q50", "q99"):
|
||||
assert stats["observation.images.depth"][q].shape == (1, 1, 1)
|
||||
assert stats["observation.images.depth"]["q50"].item() == 1.0
|
||||
np.testing.assert_allclose(stats["observation.image"]["q50"], 1 / 255)
|
||||
# Depth max stays in raw units (not /255, not uint8-capped); RGB is normalized.
|
||||
assert stats["observation.images.depth"]["max"].item() == 1023.0
|
||||
np.testing.assert_allclose(stats["observation.image"]["max"], 1.0)
|
||||
|
||||
# Action quantiles should have correct shape
|
||||
assert stats["action"]["q01"].shape == (5,)
|
||||
|
||||
Reference in New Issue
Block a user