mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
Bugfix: Add tests for image deletion and fix mixed image-video deletion (#2592)
* Add tests for image deletion and fix mixed-image-video deletion * Fix docstring whitespace * Remove debug print Signed-off-by: Alex Tyshka <atyshka15@gmail.com> * Remove inaccurate comment * Remove batched video test --------- Signed-off-by: Alex Tyshka <atyshka15@gmail.com>
This commit is contained in:
@@ -1498,7 +1498,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
episode_index = self.episode_buffer["episode_index"]
|
episode_index = self.episode_buffer["episode_index"]
|
||||||
if isinstance(episode_index, np.ndarray):
|
if isinstance(episode_index, np.ndarray):
|
||||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||||
for cam_key in self.meta.camera_keys:
|
for cam_key in self.meta.image_keys:
|
||||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||||
if img_dir.is_dir():
|
if img_dir.is_dir():
|
||||||
shutil.rmtree(img_dir)
|
shutil.rmtree(img_dir)
|
||||||
|
|||||||
@@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|||||||
image_array_to_pil_image(image)
|
image_array_to_pil_image(image)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
"""Verify temporary image directories are removed for image features after saving episode."""
|
||||||
|
# Image feature: images should be deleted after saving episode
|
||||||
|
image_key = "image"
|
||||||
|
features_image = {
|
||||||
|
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||||
|
}
|
||||||
|
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||||
|
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
ds_img.save_episode()
|
||||||
|
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||||
|
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
"""Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`."""
|
||||||
|
# Video feature: when batch_encoding_size == 1 temporary images should be deleted
|
||||||
|
vid_key = "video"
|
||||||
|
features_video = {
|
||||||
|
vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||||
|
}
|
||||||
|
|
||||||
|
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||||
|
ds_vid.batch_encoding_size = 1
|
||||||
|
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
ds_vid.save_episode()
|
||||||
|
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||||
|
assert not vid_img_dir.exists(), (
|
||||||
|
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
|
||||||
|
image_key = "image"
|
||||||
|
vid_key = "video"
|
||||||
|
features_mixed = {
|
||||||
|
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
|
||||||
|
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||||
|
}
|
||||||
|
ds_mixed = empty_lerobot_dataset_factory(
|
||||||
|
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||||
|
)
|
||||||
|
ds_mixed.add_frame(
|
||||||
|
{
|
||||||
|
"image": np.random.rand(*DUMMY_CHW),
|
||||||
|
"video": np.random.rand(*DUMMY_HWC),
|
||||||
|
"task": "Dummy task",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ds_mixed.save_episode()
|
||||||
|
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||||
|
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||||
|
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||||
|
assert vid_img_dir.exists(), (
|
||||||
|
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts):
|
# TODO(aliberts):
|
||||||
# - [ ] test various attributes & state from init and create
|
# - [ ] test various attributes & state from init and create
|
||||||
# - [ ] test init with episodes and check num_frames
|
# - [ ] test init with episodes and check num_frames
|
||||||
|
|||||||
Reference in New Issue
Block a user