feat(depth maps writer): adding support for raw depth maps recording with image writer

This commit is contained in:
CarolinePascal
2026-05-20 16:42:16 +02:00
parent 1dafb4acf6
commit 33a3b5a982
3 changed files with 57 additions and 6 deletions
+3 -1
View File
@@ -55,6 +55,7 @@ from .io_utils import (
from .utils import (
DEFAULT_EPISODES_PATH,
DEFAULT_IMAGE_PATH,
DEFAULT_DEPTH_PATH,
update_chunk_file_indices,
)
from .video_utils import (
@@ -154,7 +155,8 @@ class DatasetWriter:
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
path_template = DEFAULT_DEPTH_PATH if self.image_key in self._meta.depth_keys else DEFAULT_IMAGE_PATH
fpath = path_template.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index
)
return self._root / fpath
+53 -5
View File
@@ -42,10 +42,43 @@ def safe_stop_image_writer(func):
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
Behaviour by shape:
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
The native dtype is preserved using the matching PIL mode
(``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
(existing behaviour, gated by ``range_check``).
Other shapes / channel counts raise ``NotImplementedError`` or
``ValueError``.
"""
#TODO(CarolinePascal): 4 dimensions RGB-D images
if image_array.ndim not in (2, 3):
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
)
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
# caller emits (H, W), (1, H, W), or (H, W, 1).
if image_array.ndim == 3:
if image_array.shape[0] == 1:
image_array = image_array[0]
elif image_array.shape[-1] == 1:
image_array = image_array[..., 0]
if image_array.ndim == 2:
if image_array.dtype not in [np.uint16, np.float32]:
raise ValueError(
f"Unsupported single-channel image dtype: {image_array.dtype}. "
f"Supported dtypes: {sorted(str(d) for d in [np.uint16, np.float32])}."
)
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
# 3D path: must be RGB (3 channels), channels-first or channels-last.
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
@@ -71,13 +104,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
return PIL.Image.fromarray(image_array)
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
PNG uses ``compress_level`` (0-9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
"""
suffix = Path(fpath).suffix.lower()
if suffix == ".png":
return {"compress_level": compress_level}
if suffix in (".tif", ".tiff"):
return {"compression": "raw"}
return {}
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
"""
Saves a NumPy array or PIL Image to a file.
This function handles both NumPy arrays and PIL Image objects, converting
the former to a PIL Image before saving. It includes error handling for
the save operation.
the save operation. The output format is inferred from the *fpath*
extension: ``.png`` PNG with ``compress_level``, ``.tiff`` / ``.tif``
lossless raw depth maps (TIFF).
Args:
image (np.ndarray | PIL.Image.Image): The image data to save.
@@ -101,7 +149,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
img.save(fpath, **save_kwargs_for_path(fpath, compress_level))
except Exception as e:
logger.error("Error writing image %s: %s", fpath, e)
+1
View File
@@ -93,6 +93,7 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"