mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
refactor(datasets): module cleanup (#3169)
This commit is contained in:
@@ -13,15 +13,31 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
    """Calculate episode data index for testing.

    Walks the ``"episode_index"`` column in order and, for every run of
    consecutive equal values, records the index of the run's first frame
    (``"from"``) and the index one past its last frame (``"to"``).

    Args:
        hf_dataset: Dataset supporting ``len()`` and exposing an
            ``"episode_index"`` column.

    Returns:
        ``{"from": Tensor, "to": Tensor}`` holding ``torch.long`` indices.
        Both tensors are empty when the dataset has no rows.
    """
    if len(hf_dataset) == 0:
        # An untyped ``torch.tensor([])`` defaults to a float dtype; pin the
        # dtype so the empty result matches the integer tensors produced for
        # non-empty datasets.
        return {
            "from": torch.tensor([], dtype=torch.long),
            "to": torch.tensor([], dtype=torch.long),
        }

    episode_data_index: dict[str, list[int]] = {"from": [], "to": []}
    current_episode = None
    for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
        if episode_idx != current_episode:
            # A new episode starts at this frame.
            episode_data_index["from"].append(idx)
            if current_episode is not None:
                # The previous episode ends (exclusive) where this one begins.
                episode_data_index["to"].append(idx)
            current_episode = episode_idx

    # Close the last episode: its exclusive end is one past the final frame.
    episode_data_index["to"].append(idx + 1)

    return {k: torch.tensor(v, dtype=torch.long) for k, v in episode_data_index.items()}
|
||||
|
||||
|
||||
def test_drop_n_first_frames():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user