mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(is_depth): simplifying is_depth nested name + legacy support
This commit is contained in:
@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.fps": 30,
|
||||
"video.channels": 3,
|
||||
"video.is_depth_map": false,
|
||||
"is_depth_map": false,
|
||||
"video.g": 2,
|
||||
"video.crf": 30,
|
||||
"video.preset": "fast",
|
||||
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
||||
|
||||
Two sources contribute to the `info` block:
|
||||
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -342,14 +342,20 @@ class LeRobotDatasetMetadata:
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos or images.
|
||||
|
||||
A depth key is a feature whose ``info`` dict carries ``"<dtype>.is_depth_map": True``.
|
||||
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||
"""
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
# TODO(CarolinePascal): Make sure the legacy video_info works here as well.
|
||||
if (ft.get("info") or {}).get(ft["dtype"] + ".is_depth_map", False)
|
||||
]
|
||||
|
||||
def _is_depth(ft: dict) -> bool:
|
||||
info = ft.get("info") or {}
|
||||
video_info = ft.get("video_info") or {}
|
||||
return (
|
||||
info.get("is_depth_map", False)
|
||||
or info.get("video.is_depth_map", False)
|
||||
or video_info.get("video.is_depth_map", False)
|
||||
)
|
||||
|
||||
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
|
||||
@@ -155,7 +155,7 @@ class DatasetWriter:
|
||||
return ep_buffer
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
path_template = DEFAULT_DEPTH_PATH if self.image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
|
||||
@@ -955,7 +955,6 @@ def get_video_info(
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
@@ -976,7 +975,7 @@ def get_video_info(
|
||||
if field_name == "vcodec":
|
||||
continue
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -88,21 +88,18 @@ def hw_to_dataset_features(
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
dtype = "video" if use_video else "image"
|
||||
if len(shape) == 2 or shape[2] == 1:
|
||||
if len(shape) == 2:
|
||||
shape = (shape[0], shape[1], 1)
|
||||
features[f"{prefix}.depth_maps.{key}"] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {dtype + ".is_depth_map": True},
|
||||
}
|
||||
else:
|
||||
if len(shape) == 3 and shape[2] in (1, 3):
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"is_depth_map": shape[2] == 1},
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Camera feature '{key}' has shape {shape}. "
|
||||
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
|
||||
)
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -132,10 +129,7 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
if ft["info"].get(ft["dtype"] + ".is_depth_map"):
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
|
||||
else:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
@@ -164,11 +158,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
else:
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
|
||||
@@ -59,7 +59,7 @@ DEPTH_FEATURES = {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 1),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
"info": {"is_depth_map": True},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
|
||||
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path):
|
||||
"""``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
|
||||
"""``depth_keys`` selects only features carrying ``is_depth_map=True`` in info."""
|
||||
features = {
|
||||
**VIDEO_FEATURES,
|
||||
**DEPTH_FEATURES,
|
||||
@@ -164,8 +164,8 @@ def test_depth_keys_property_filters_by_marker(tmp_path):
|
||||
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.depth.laptop"}
|
||||
assert meta.depth_keys == ["observation.depth.laptop"]
|
||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
|
||||
assert meta.depth_keys == ["observation.images.laptop_depth"]
|
||||
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
|
||||
@@ -368,7 +368,7 @@ class TestGetVideoInfo:
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
@@ -463,7 +463,7 @@ class TestEncodeVideoFrames:
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
|
||||
Vendored
+1
-1
@@ -39,7 +39,7 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.crf": 30,
|
||||
"video.preset": 12,
|
||||
"video.fast_decode": 0,
|
||||
"video.is_depth_map": False,
|
||||
"is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
|
||||
Reference in New Issue
Block a user