feat(depth shape): ensuring depth maps shape is always including the channel

This commit is contained in:
CarolinePascal
2026-05-21 14:25:42 +02:00
parent 3cf5e3c8cb
commit e87933302d
5 changed files with 30 additions and 26 deletions
+9 -7
View File
@@ -55,10 +55,10 @@ IMAGE_FEATURES = {
DEPTH_FEATURES = {
**SIMPLE_FEATURES,
"observation.depth.laptop": {
"observation.images.laptop_depth": {
"dtype": "video",
"shape": (64, 96),
"names": ["height", "width"],
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": {"video.is_depth_map": True},
},
}
@@ -69,11 +69,13 @@ def _make_dummy_stats(features: dict) -> dict:
stats = {}
for key, ft in features.items():
if ft["dtype"] in ("image", "video"):
channels = ft["shape"][-1]
stat_shape = (channels, 1, 1)
stats[key] = {
"max": np.ones((3, 1, 1), dtype=np.float32),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
"min": np.zeros((3, 1, 1), dtype=np.float32),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
"max": np.ones(stat_shape, dtype=np.float32),
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
"min": np.zeros(stat_shape, dtype=np.float32),
"std": np.full(stat_shape, 0.25, dtype=np.float32),
"count": np.array([5]),
}
elif ft["dtype"] in ("float32", "float64", "int64"):