mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
remove the sampler cause the relative index is added (#2521)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -65,7 +65,6 @@ import argparse
|
|||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterator
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||||
|
|
||||||
|
|
||||||
class EpisodeSampler(torch.utils.data.Sampler):
|
|
||||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
|
||||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
|
||||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
|
||||||
self.frame_ids = range(from_idx, to_idx)
|
|
||||||
|
|
||||||
def __iter__(self) -> Iterator:
|
|
||||||
return iter(self.frame_ids)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.frame_ids)
|
|
||||||
|
|
||||||
|
|
||||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||||
assert chw_float32_torch.dtype == torch.float32
|
assert chw_float32_torch.dtype == torch.float32
|
||||||
assert chw_float32_torch.ndim == 3
|
assert chw_float32_torch.ndim == 3
|
||||||
@@ -119,12 +105,10 @@ def visualize_dataset(
|
|||||||
repo_id = dataset.repo_id
|
repo_id = dataset.repo_id
|
||||||
|
|
||||||
logging.info("Loading dataloader")
|
logging.info("Loading dataloader")
|
||||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
sampler=episode_sampler,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Starting Rerun")
|
logging.info("Starting Rerun")
|
||||||
|
|||||||
Reference in New Issue
Block a user