diff --git a/src/lerobot/utils/feature_utils.py b/src/lerobot/utils/feature_utils.py index 2a4886234..3ac32ae5f 100644 --- a/src/lerobot/utils/feature_utils.py +++ b/src/lerobot/utils/feature_utils.py @@ -69,6 +69,7 @@ def hw_to_dataset_features( for key, ftype in hw_features.items() if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) } + #TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera ! cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} if joint_fts and prefix == ACTION: @@ -86,11 +87,22 @@ def hw_to_dataset_features( } for key, shape in cam_fts.items(): - features[f"{prefix}.images.{key}"] = { - "dtype": "video" if use_video else "image", - "shape": shape, - "names": ["height", "width", "channels"], - } + dtype = "video" if use_video else "image" + if len(shape) == 2 or shape[2] == 1: + if len(shape) == 2: + shape = (shape[0], shape[1], 1) + features[f"{prefix}.depth_maps.{key}"] = { + "dtype": dtype, + "shape": shape, + "names": ["height", "width", "channels"], + "info": {dtype + ".is_depth_map": True}, + } + else: + features[f"{prefix}.images.{key}"] = { + "dtype": dtype, + "shape": shape, + "names": ["height", "width", "channels"], + } _validate_feature_names(features) return features @@ -120,7 +132,10 @@ def build_dataset_frame( elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) elif ft["dtype"] in ["image", "video"]: - frame[key] = values[key.removeprefix(f"{prefix}.images.")] + if ft["info"].get(ft["dtype"] + ".is_depth_map"): + frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")] + else: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] return frame