mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Bugfix: Add tests for image deletion and fix mixed image-video deletion (#2592)
* Add tests for image deletion and fix mixed-image-video deletion * Fix docstring whitespace * Remove debug print Signed-off-by: Alex Tyshka <atyshka15@gmail.com> * Remove inaccurate comment * Remove batched video test --------- Signed-off-by: Alex Tyshka <atyshka15@gmail.com>
This commit is contained in:
@@ -1498,7 +1498,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
@@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
|
||||
def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for image features after saving episode."""
|
||||
# Image feature: images should be deleted after saving episode
|
||||
image_key = "image"
|
||||
features_image = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_img.save_episode()
|
||||
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
|
||||
|
||||
def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`."""
|
||||
# Video feature: when batch_encoding_size == 1 temporary images should be deleted
|
||||
vid_key = "video"
|
||||
features_video = {
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
|
||||
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||
ds_vid.batch_encoding_size = 1
|
||||
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_vid.save_episode()
|
||||
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||
assert not vid_img_dir.exists(), (
|
||||
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||
)
|
||||
|
||||
|
||||
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
|
||||
image_key = "image"
|
||||
vid_key = "video"
|
||||
features_mixed = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds_mixed = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||
)
|
||||
ds_mixed.add_frame(
|
||||
{
|
||||
"image": np.random.rand(*DUMMY_CHW),
|
||||
"video": np.random.rand(*DUMMY_HWC),
|
||||
"task": "Dummy task",
|
||||
}
|
||||
)
|
||||
ds_mixed.save_episode()
|
||||
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
assert vid_img_dir.exists(), (
|
||||
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
|
||||
Reference in New Issue
Block a user