diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 12273cb1d..974762b0b 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -65,7 +65,6 @@ import argparse import gc import logging import time -from collections.abc import Iterator from pathlib import Path import numpy as np @@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD -class EpisodeSampler(torch.utils.data.Sampler): - def __init__(self, dataset: LeRobotDataset, episode_index: int): - from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] - to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] - self.frame_ids = range(from_idx, to_idx) - - def __iter__(self) -> Iterator: - return iter(self.frame_ids) - - def __len__(self) -> int: - return len(self.frame_ids) - - def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.ndim == 3 @@ -119,12 +105,10 @@ def visualize_dataset( repo_id = dataset.repo_id logging.info("Loading dataloader") - episode_sampler = EpisodeSampler(dataset, episode_index) dataloader = torch.utils.data.DataLoader( dataset, num_workers=num_workers, batch_size=batch_size, - sampler=episode_sampler, ) logging.info("Starting Rerun")