mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fixed bug in crop_dataset_roi.py
added missing buffer.pt in server dir Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -187,43 +187,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
# 2. Process each episode in the original dataset.
|
||||
episodes_info = original_dataset.meta.episodes
|
||||
# (Sort episodes by episode_index for consistency.)
|
||||
|
||||
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
|
||||
# Use the first task from the episode metadata (or "unknown" if not provided)
|
||||
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
|
||||
|
||||
for ep in tqdm(episodes_info[:3], desc="Processing episodes"):
|
||||
ep_index = ep.pop("episode_index")
|
||||
# Use the first task from the episode metadata (or "unknown" if not provided)
|
||||
task = ep["tasks"][0] if ep.get("tasks") else "unknown"
|
||||
last_episode_index = 0
|
||||
for sample in tqdm(original_dataset):
|
||||
episode_index = sample.pop("episode_index")
|
||||
if episode_index != last_episode_index:
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
last_episode_index = episode_index
|
||||
sample.pop("frame_index")
|
||||
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
|
||||
new_sample = sample.copy()
|
||||
# Loop over each observation key that should be cropped/resized.
|
||||
for key, params in crop_params_dict.items():
|
||||
if key in new_sample:
|
||||
top, left, height, width = params
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(new_sample[key], top, left, height, width)
|
||||
resized = F.resize(cropped, resize_size)
|
||||
new_sample[key] = resized
|
||||
# Add the transformed frame to the new dataset.
|
||||
new_dataset.add_frame(new_sample)
|
||||
|
||||
# Reset the episode buffer in the new dataset (this will store frames for one episode).
|
||||
new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index)
|
||||
|
||||
# 3. Filter and process all frames belonging to this episode.
|
||||
# Here we loop over the entire dataset and select the frames with the matching episode_index.
|
||||
# (Depending on the dataset size, you might want a more efficient method.)
|
||||
ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index]
|
||||
|
||||
for sample in tqdm(ep_frames):
|
||||
sample.pop("episode_index")
|
||||
sample.pop("frame_index")
|
||||
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
|
||||
new_sample = sample.copy()
|
||||
# Loop over each observation key that should be cropped/resized.
|
||||
for key, params in crop_params_dict.items():
|
||||
if key in new_sample:
|
||||
top, left, height, width = params
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(new_sample[key], top, left, height, width)
|
||||
resized = F.resize(cropped, resize_size)
|
||||
new_sample[key] = resized
|
||||
# Add the transformed frame to the new dataset.
|
||||
new_dataset.add_frame(new_sample)
|
||||
|
||||
# 4. Save the episode (this writes the parquet file and image files).
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
# save last episode
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
|
||||
# Optionally, consolidate the new dataset to compute statistics and update video info.
|
||||
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
|
||||
|
||||
new_dataset.push_to_hub(tags=None)
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user