feat(features): route 2D camera shapes to observation.depth.<key>

This commit is contained in:
CarolinePascal
2026-05-20 15:50:46 +02:00
parent f15348e769
commit 085f574301
+21 -6
View File
@@ -69,6 +69,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items() for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
} }
# TODO(CarolinePascal): we should not rely on the shape to determine whether a feature is a camera!
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION: if joint_fts and prefix == ACTION:
@@ -86,11 +87,22 @@ def hw_to_dataset_features(
} }
for key, shape in cam_fts.items(): for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = { dtype = "video" if use_video else "image"
"dtype": "video" if use_video else "image", if len(shape) == 2 or shape[2] == 1:
"shape": shape, if len(shape) == 2:
"names": ["height", "width", "channels"], shape = (shape[0], shape[1], 1)
} features[f"{prefix}.depth_maps.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
"info": {dtype + ".is_depth_map": True},
}
else:
features[f"{prefix}.images.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features) _validate_feature_names(features)
return features return features
@@ -120,7 +132,10 @@ def build_dataset_frame(
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]: elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")] if ft["info"].get(ft["dtype"] + ".is_depth_map"):
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
else:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame return frame