feat(is_depth): simplifying is_depth nested name + legacy support

This commit is contained in:
CarolinePascal
2026-05-21 14:26:16 +02:00
parent e87933302d
commit 15647f50a2
8 changed files with 37 additions and 38 deletions
+2 -2
View File
@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
"video.pix_fmt": "yuv420p",
"video.fps": 30,
"video.channels": 3,
"video.is_depth_map": false,
"is_depth_map": false,
"video.g": 2,
"video.crf": 30,
"video.preset": "fast",
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
Two sources contribute to the `info` block:
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
<Tip>
+13 -7
View File
@@ -342,14 +342,20 @@ class LeRobotDatasetMetadata:
def depth_keys(self) -> list[str]:
"""Keys to access depth-map modalities stored as videos or images.
A depth key is a feature whose ``info`` dict carries ``"<dtype>.is_depth_map": True``.
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
"""
return [
key
for key, ft in self.features.items()
# TODO(CarolinePascal): Make sure the legacy video_info works here as well.
if (ft.get("info") or {}).get(ft["dtype"] + ".is_depth_map", False)
]
def _is_depth(ft: dict) -> bool:
info = ft.get("info") or {}
video_info = ft.get("video_info") or {}
return (
info.get("is_depth_map", False)
or info.get("video.is_depth_map", False)
or video_info.get("video.is_depth_map", False)
)
return [key for key, ft in self.features.items() if _is_depth(ft)]
@property
def camera_keys(self) -> list[str]:
+1 -1
View File
@@ -155,7 +155,7 @@ class DatasetWriter:
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
path_template = DEFAULT_DEPTH_PATH if self.image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
+1 -2
View File
@@ -955,7 +955,6 @@ def get_video_info(
video_info["video.width"] = video_stream.width
video_info["video.codec"] = video_stream.codec.canonical_name
video_info["video.pix_fmt"] = video_stream.pix_fmt
video_info["video.is_depth_map"] = False
# Calculate fps from r_frame_rate
video_info["video.fps"] = int(video_stream.base_rate)
@@ -976,7 +975,7 @@ def get_video_info(
if field_name == "vcodec":
continue
video_info.setdefault(f"video.{field_name}", field_value)
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
return video_info
+13 -19
View File
@@ -88,21 +88,18 @@ def hw_to_dataset_features(
for key, shape in cam_fts.items():
dtype = "video" if use_video else "image"
if len(shape) == 2 or shape[2] == 1:
if len(shape) == 2:
shape = (shape[0], shape[1], 1)
features[f"{prefix}.depth_maps.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
"info": {dtype + ".is_depth_map": True},
}
else:
if len(shape) == 3 and shape[2] in (1, 3):
features[f"{prefix}.images.{key}"] = {
"dtype": dtype,
"shape": shape,
"names": ["height", "width", "channels"],
"info": {"is_depth_map": shape[2] == 1},
}
else:
raise ValueError(
f"Camera feature '{key}' has shape {shape}. "
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
)
_validate_feature_names(features)
return features
@@ -132,10 +129,7 @@ def build_dataset_frame(
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
elif ft["dtype"] in ["image", "video"]:
if ft["info"].get(ft["dtype"] + ".is_depth_map"):
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
else:
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
return frame
@@ -164,11 +158,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
type = FeatureType.VISUAL
if len(shape) != 3:
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
else:
names = ft["names"]
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):
+4 -4
View File
@@ -59,7 +59,7 @@ DEPTH_FEATURES = {
"dtype": "video",
"shape": (64, 96, 1),
"names": ["height", "width", "channels"],
"info": {"video.is_depth_map": True},
"info": {"is_depth_map": True},
},
}
@@ -155,7 +155,7 @@ def test_create_without_videos_has_no_video_path(tmp_path):
def test_depth_keys_property_filters_by_marker(tmp_path):
"""``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
"""``depth_keys`` selects only features carrying ``is_depth_map=True`` in info."""
features = {
**VIDEO_FEATURES,
**DEPTH_FEATURES,
@@ -164,8 +164,8 @@ def test_depth_keys_property_filters_by_marker(tmp_path):
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
)
assert set(meta.video_keys) == {"observation.images.laptop", "observation.depth.laptop"}
assert meta.depth_keys == ["observation.depth.laptop"]
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
assert meta.depth_keys == ["observation.images.laptop_depth"]
def test_depth_keys_empty_when_no_marker(tmp_path):
+2 -2
View File
@@ -368,7 +368,7 @@ class TestGetVideoInfo:
assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30
assert info["video.channels"] == 3
assert info["video.is_depth_map"] is False
assert info["is_depth_map"] is False
assert info["has_audio"] is False
assert "video.g" not in info
assert "video.crf" not in info
@@ -463,7 +463,7 @@ class TestEncodeVideoFrames:
assert info["video.codec"] == "av1"
assert info["video.pix_fmt"] == "yuv420p"
assert info["video.fps"] == 30
assert info["video.is_depth_map"] is False
assert info["is_depth_map"] is False
assert info["has_audio"] is False
# Encoder config
assert info["video.g"] == 4
+1 -1
View File
@@ -39,7 +39,7 @@ DUMMY_VIDEO_INFO = {
"video.crf": 30,
"video.preset": 12,
"video.fast_decode": 0,
"video.is_depth_map": False,
"is_depth_map": False,
"has_audio": False,
}
DUMMY_CAMERA_FEATURES = {