feat(features): route 2D camera shapes to observation.depth.<key>

This commit is contained in:
CarolinePascal
2026-05-20 15:50:46 +02:00
parent f15348e769
commit 085f574301
+21 -6
View File
@@ -69,6 +69,7 @@ def hw_to_dataset_features(
for key, ftype in hw_features.items()
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
}
#TODO(CarolinePascal): we should not rely on the shape to determine if a feature is a camera !
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == ACTION:
@@ -86,11 +87,22 @@ def hw_to_dataset_features(
}
for key, shape in cam_fts.items():
features[f"{prefix}.images.{key}"] = {
"dtype": "video" if use_video else "image",
"shape": shape,
"names": ["height", "width", "channels"],
}
dtype = "video" if use_video else "image"
if len(shape) == 2 or shape[2] == 1:
if len(shape) == 2:
shape = (shape[0], shape[1], 1)
features[f"{prefix}.depth_maps.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
"info": {dtype + ".is_depth_map": True},
}
else:
features[f"{prefix}.images.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
}
_validate_feature_names(features)
return features
@@ -120,7 +132,10 @@ def build_dataset_frame(
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
if ft["info"].get(ft["dtype"] + ".is_depth_map"):
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
else:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame