mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
feat(output unit): adding support for output unit specification at dataset reading/training time
Co-authored-by: Wensi (Vince) Ai <59036629+wensi-ai@users.noreply.github.com>
This commit is contained in:
@@ -35,12 +35,17 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# When True, RGB video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
# Physical unit to which depth maps are dequantized at load time: "mm" (millimetres) or "m" (metres).
|
||||
# Has no effect on datasets without depth cameras.
|
||||
depth_output_unit: str = "mm"
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.depth_output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"depth_output_unit must be 'm' or 'mm', got {self.depth_output_unit!r}")
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
|
||||
@@ -54,6 +54,7 @@ class DatasetReader:
|
||||
delta_timestamps: dict[str, list[float]] | None,
|
||||
image_transforms: Callable | None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
):
|
||||
"""Initialize the reader with metadata, filtering, and transform config.
|
||||
|
||||
@@ -71,6 +72,10 @@ class DatasetReader:
|
||||
relative timestamp offsets for temporal context windows.
|
||||
image_transforms: Optional torchvision v2 transform applied to
|
||||
visual features.
|
||||
return_uint8: If True, return RGB video frames as raw uint8 tensors
|
||||
instead of normalized float32.
|
||||
depth_output_unit: Physical unit to which depth maps are dequantized
|
||||
(``"m"`` or ``"mm"``). Defaults to ``"mm"``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self.root = root
|
||||
@@ -79,6 +84,7 @@ class DatasetReader:
|
||||
self._video_backend = video_backend
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
|
||||
self.hf_dataset: datasets.Dataset | None = None
|
||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||
@@ -266,6 +272,7 @@ class DatasetReader:
|
||||
depth_max=depth_encoder.depth_max,
|
||||
shift=depth_encoder.shift,
|
||||
use_log=depth_encoder.use_log,
|
||||
output_unit=self._depth_output_unit,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
return_uint8=True,
|
||||
depth_output_unit=cfg.dataset.depth_output_unit,
|
||||
tolerance_s=cfg.tolerance_s,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -58,6 +58,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
depth_output_unit: str = "mm",
|
||||
batch_encoding_size: int = 1,
|
||||
camera_encoder: VideoEncoderConfig | None = None,
|
||||
depth_encoder: DepthEncoderConfig | None = None,
|
||||
@@ -212,6 +213,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._depth_output_unit = depth_output_unit
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
@@ -252,6 +254,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
|
||||
# Load actual data
|
||||
@@ -321,6 +324,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
delta_timestamps=self.delta_timestamps,
|
||||
image_transforms=self.image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
depth_output_unit=self._depth_output_unit,
|
||||
)
|
||||
return self.reader
|
||||
|
||||
@@ -722,6 +726,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
@@ -820,6 +825,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._depth_output_unit = "mm"
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
|
||||
if obj._requested_root is not None:
|
||||
|
||||
Reference in New Issue
Block a user