mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
remove the sampler cause the relative index is added (#2521)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -65,7 +65,6 @@ import argparse
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -119,12 +105,10 @@ def visualize_dataset(
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
sampler=episode_sampler,
|
||||
)
|
||||
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
Reference in New Issue
Block a user