mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
feat(depth shape): ensuring depth maps shape is always including the channel
This commit is contained in:
@@ -55,10 +55,10 @@ IMAGE_FEATURES = {
|
||||
|
||||
DEPTH_FEATURES = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.depth.laptop": {
|
||||
"observation.images.laptop_depth": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96),
|
||||
"names": ["height", "width"],
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
@@ -69,11 +69,13 @@ def _make_dummy_stats(features: dict) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] in ("image", "video"):
|
||||
channels = ft["shape"][-1]
|
||||
stat_shape = (channels, 1, 1)
|
||||
stats[key] = {
|
||||
"max": np.ones((3, 1, 1), dtype=np.float32),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32),
|
||||
"min": np.zeros((3, 1, 1), dtype=np.float32),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32),
|
||||
"max": np.ones(stat_shape, dtype=np.float32),
|
||||
"mean": np.full(stat_shape, 0.5, dtype=np.float32),
|
||||
"min": np.zeros(stat_shape, dtype=np.float32),
|
||||
"std": np.full(stat_shape, 0.25, dtype=np.float32),
|
||||
"count": np.array([5]),
|
||||
}
|
||||
elif ft["dtype"] in ("float32", "float64", "int64"):
|
||||
|
||||
Reference in New Issue
Block a user