mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
feat(features): route 2D camera shapes to observation.depth.<key>
This commit is contained in:
@@ -69,6 +69,7 @@ def hw_to_dataset_features(
|
|||||||
for key, ftype in hw_features.items()
|
for key, ftype in hw_features.items()
|
||||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||||
}
|
}
|
||||||
|
# TODO(CarolinePascal): we should not rely on the shape to determine whether a feature is a camera!
|
||||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||||
|
|
||||||
if joint_fts and prefix == ACTION:
|
if joint_fts and prefix == ACTION:
|
||||||
@@ -86,11 +87,22 @@ def hw_to_dataset_features(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for key, shape in cam_fts.items():
|
for key, shape in cam_fts.items():
|
||||||
features[f"{prefix}.images.{key}"] = {
|
dtype = "video" if use_video else "image"
|
||||||
"dtype": "video" if use_video else "image",
|
if len(shape) == 2 or shape[2] == 1:
|
||||||
"shape": shape,
|
if len(shape) == 2:
|
||||||
"names": ["height", "width", "channels"],
|
shape = (shape[0], shape[1], 1)
|
||||||
}
|
features[f"{prefix}.depth_maps.{key}"] = {
|
||||||
|
"dtype": dtype,
|
||||||
|
"shape": shape,
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": {dtype + ".is_depth_map": True},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
features[f"{prefix}.images.{key}"] = {
|
||||||
|
"dtype": dtype,
|
||||||
|
"shape": shape,
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
}
|
||||||
|
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
return features
|
return features
|
||||||
@@ -120,7 +132,10 @@ def build_dataset_frame(
|
|||||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
elif ft["dtype"] in ["image", "video"]:
|
||||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
if ft["info"].get(ft["dtype"] + ".is_depth_map"):
|
||||||
|
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
|
||||||
|
else:
|
||||||
|
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user