mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
feat(is_depth): simplifying is_depth nested name + legacy support
This commit is contained in:
@@ -82,7 +82,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
|||||||
"video.pix_fmt": "yuv420p",
|
"video.pix_fmt": "yuv420p",
|
||||||
"video.fps": 30,
|
"video.fps": 30,
|
||||||
"video.channels": 3,
|
"video.channels": 3,
|
||||||
"video.is_depth_map": false,
|
"is_depth_map": false,
|
||||||
"video.g": 2,
|
"video.g": 2,
|
||||||
"video.crf": 30,
|
"video.crf": 30,
|
||||||
"video.preset": "fast",
|
"video.preset": "fast",
|
||||||
@@ -97,7 +97,7 @@ After the first episode of a video stream is encoded, the encoder configuration
|
|||||||
|
|
||||||
Two sources contribute to the `info` block:
|
Two sources contribute to the `info` block:
|
||||||
|
|
||||||
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `video.is_depth_map`, plus `audio.*` if an audio stream is present.
|
- **Stream-derived** (read back from the encoded MP4 with PyAV): `video.height`, `video.width`, `video.codec`, `video.pix_fmt`, `video.fps`, `video.channels`, `is_depth_map`, plus `audio.*` if an audio stream is present.
|
||||||
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
- **Encoder-derived** (taken from `VideoEncoderConfig`): `video.g`, `video.crf`, `video.preset`, `video.fast_decode`, `video.video_backend`, `video.extra_options`.
|
||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|||||||
@@ -342,14 +342,20 @@ class LeRobotDatasetMetadata:
|
|||||||
def depth_keys(self) -> list[str]:
|
def depth_keys(self) -> list[str]:
|
||||||
"""Keys to access depth-map modalities stored as videos or images.
|
"""Keys to access depth-map modalities stored as videos or images.
|
||||||
|
|
||||||
A depth key is a feature whose ``info`` dict carries ``"<dtype>.is_depth_map": True``.
|
A depth key is a feature whose ``info`` dict carries ``"is_depth_map": True``
|
||||||
|
(or the legacy ``"video.is_depth_map"`` inside ``info`` or ``video_info``).
|
||||||
"""
|
"""
|
||||||
return [
|
|
||||||
key
|
def _is_depth(ft: dict) -> bool:
|
||||||
for key, ft in self.features.items()
|
info = ft.get("info") or {}
|
||||||
# TODO(CarolinePascal): Make sure the legacy video_info works here as well.
|
video_info = ft.get("video_info") or {}
|
||||||
if (ft.get("info") or {}).get(ft["dtype"] + ".is_depth_map", False)
|
return (
|
||||||
]
|
info.get("is_depth_map", False)
|
||||||
|
or info.get("video.is_depth_map", False)
|
||||||
|
or video_info.get("video.is_depth_map", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [key for key, ft in self.features.items() if _is_depth(ft)]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def camera_keys(self) -> list[str]:
|
def camera_keys(self) -> list[str]:
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ class DatasetWriter:
|
|||||||
return ep_buffer
|
return ep_buffer
|
||||||
|
|
||||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||||
path_template = DEFAULT_DEPTH_PATH if self.image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
path_template = DEFAULT_DEPTH_PATH if image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
|
||||||
fpath = path_template.format(
|
fpath = path_template.format(
|
||||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -955,7 +955,6 @@ def get_video_info(
|
|||||||
video_info["video.width"] = video_stream.width
|
video_info["video.width"] = video_stream.width
|
||||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||||
video_info["video.is_depth_map"] = False
|
|
||||||
|
|
||||||
# Calculate fps from r_frame_rate
|
# Calculate fps from r_frame_rate
|
||||||
video_info["video.fps"] = int(video_stream.base_rate)
|
video_info["video.fps"] = int(video_stream.base_rate)
|
||||||
@@ -976,7 +975,7 @@ def get_video_info(
|
|||||||
if field_name == "vcodec":
|
if field_name == "vcodec":
|
||||||
continue
|
continue
|
||||||
video_info.setdefault(f"video.{field_name}", field_value)
|
video_info.setdefault(f"video.{field_name}", field_value)
|
||||||
video_info["video.is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
video_info["is_depth_map"] = isinstance(video_encoder, DepthEncoderConfig)
|
||||||
|
|
||||||
return video_info
|
return video_info
|
||||||
|
|
||||||
|
|||||||
@@ -88,21 +88,18 @@ def hw_to_dataset_features(
|
|||||||
|
|
||||||
for key, shape in cam_fts.items():
|
for key, shape in cam_fts.items():
|
||||||
dtype = "video" if use_video else "image"
|
dtype = "video" if use_video else "image"
|
||||||
if len(shape) == 2 or shape[2] == 1:
|
if len(shape) == 3 and shape[2] in (1, 3):
|
||||||
if len(shape) == 2:
|
|
||||||
shape = (shape[0], shape[1], 1)
|
|
||||||
features[f"{prefix}.depth_maps.{key}"] = {
|
|
||||||
"dtype": dtype,
|
|
||||||
"shape": shape,
|
|
||||||
"names": ["height", "width", "channels"],
|
|
||||||
"info": {dtype + ".is_depth_map": True},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
features[f"{prefix}.images.{key}"] = {
|
features[f"{prefix}.images.{key}"] = {
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"shape": shape,
|
"shape": shape,
|
||||||
"names": ["height", "width", "channels"],
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": {"is_depth_map": shape[2] == 1},
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Camera feature '{key}' has shape {shape}. "
|
||||||
|
f"Expected a 3-tuple (H, W, C), e.g. (480, 640, 3) for RGB or (480, 640, 1) for depth."
|
||||||
|
)
|
||||||
|
|
||||||
_validate_feature_names(features)
|
_validate_feature_names(features)
|
||||||
return features
|
return features
|
||||||
@@ -132,10 +129,7 @@ def build_dataset_frame(
|
|||||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||||
elif ft["dtype"] in ["image", "video"]:
|
elif ft["dtype"] in ["image", "video"]:
|
||||||
if ft["info"].get(ft["dtype"] + ".is_depth_map"):
|
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||||
frame[key] = values[key.removeprefix(f"{prefix}.depth_maps.")]
|
|
||||||
else:
|
|
||||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
@@ -164,11 +158,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
type = FeatureType.VISUAL
|
type = FeatureType.VISUAL
|
||||||
if len(shape) != 3:
|
if len(shape) != 3:
|
||||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||||
|
else:
|
||||||
names = ft["names"]
|
names = ft["names"]
|
||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == OBS_ENV_STATE:
|
elif key == OBS_ENV_STATE:
|
||||||
type = FeatureType.ENV
|
type = FeatureType.ENV
|
||||||
elif key.startswith(OBS_STR):
|
elif key.startswith(OBS_STR):
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ DEPTH_FEATURES = {
|
|||||||
"dtype": "video",
|
"dtype": "video",
|
||||||
"shape": (64, 96, 1),
|
"shape": (64, 96, 1),
|
||||||
"names": ["height", "width", "channels"],
|
"names": ["height", "width", "channels"],
|
||||||
"info": {"video.is_depth_map": True},
|
"info": {"is_depth_map": True},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,7 +155,7 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
|||||||
|
|
||||||
|
|
||||||
def test_depth_keys_property_filters_by_marker(tmp_path):
|
def test_depth_keys_property_filters_by_marker(tmp_path):
|
||||||
"""``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
|
"""``depth_keys`` selects only features carrying ``is_depth_map=True`` in info."""
|
||||||
features = {
|
features = {
|
||||||
**VIDEO_FEATURES,
|
**VIDEO_FEATURES,
|
||||||
**DEPTH_FEATURES,
|
**DEPTH_FEATURES,
|
||||||
@@ -164,8 +164,8 @@ def test_depth_keys_property_filters_by_marker(tmp_path):
|
|||||||
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
|
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert set(meta.video_keys) == {"observation.images.laptop", "observation.depth.laptop"}
|
assert set(meta.video_keys) == {"observation.images.laptop", "observation.images.laptop_depth"}
|
||||||
assert meta.depth_keys == ["observation.depth.laptop"]
|
assert meta.depth_keys == ["observation.images.laptop_depth"]
|
||||||
|
|
||||||
|
|
||||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||||
|
|||||||
@@ -368,7 +368,7 @@ class TestGetVideoInfo:
|
|||||||
assert info["video.pix_fmt"] == "yuv420p"
|
assert info["video.pix_fmt"] == "yuv420p"
|
||||||
assert info["video.fps"] == 30
|
assert info["video.fps"] == 30
|
||||||
assert info["video.channels"] == 3
|
assert info["video.channels"] == 3
|
||||||
assert info["video.is_depth_map"] is False
|
assert info["is_depth_map"] is False
|
||||||
assert info["has_audio"] is False
|
assert info["has_audio"] is False
|
||||||
assert "video.g" not in info
|
assert "video.g" not in info
|
||||||
assert "video.crf" not in info
|
assert "video.crf" not in info
|
||||||
@@ -463,7 +463,7 @@ class TestEncodeVideoFrames:
|
|||||||
assert info["video.codec"] == "av1"
|
assert info["video.codec"] == "av1"
|
||||||
assert info["video.pix_fmt"] == "yuv420p"
|
assert info["video.pix_fmt"] == "yuv420p"
|
||||||
assert info["video.fps"] == 30
|
assert info["video.fps"] == 30
|
||||||
assert info["video.is_depth_map"] is False
|
assert info["is_depth_map"] is False
|
||||||
assert info["has_audio"] is False
|
assert info["has_audio"] is False
|
||||||
# Encoder config
|
# Encoder config
|
||||||
assert info["video.g"] == 4
|
assert info["video.g"] == 4
|
||||||
|
|||||||
Vendored
+1
-1
@@ -39,7 +39,7 @@ DUMMY_VIDEO_INFO = {
|
|||||||
"video.crf": 30,
|
"video.crf": 30,
|
||||||
"video.preset": 12,
|
"video.preset": 12,
|
||||||
"video.fast_decode": 0,
|
"video.fast_decode": 0,
|
||||||
"video.is_depth_map": False,
|
"is_depth_map": False,
|
||||||
"has_audio": False,
|
"has_audio": False,
|
||||||
}
|
}
|
||||||
DUMMY_CAMERA_FEATURES = {
|
DUMMY_CAMERA_FEATURES = {
|
||||||
|
|||||||
Reference in New Issue
Block a user