mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
Validate features during add_frame + Add 2D-to-5D + Add string (#720)
This commit is contained in:
@@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import (
|
||||
TASKS_PATH,
|
||||
append_jsonlines,
|
||||
check_delta_timestamps,
|
||||
check_frame_features,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
@@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
|
||||
# check the dtype and shape matches, etc.
|
||||
if "task" not in frame:
|
||||
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
|
||||
# Convert torch to numpy if needed
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
check_frame_features(frame, self.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
@@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
|
||||
self.episode_buffer[key].append(item)
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
@@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
|
||||
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
|
||||
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
|
||||
Reference in New Issue
Block a user