add tests/fixes

This commit is contained in:
root
2026-03-11 22:49:06 +00:00
parent f0848c6887
commit 819c1b9710
8 changed files with 306 additions and 144 deletions
+168
View File
@@ -23,11 +23,18 @@ These tests verify that:
- Subtask handling gracefully handles missing data
"""
import numpy as np
import pandas as pd
import pytest
import torch
from lerobot.data_processing.data_annotations.subtask_annotations import EpisodeSkills, Skill
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import (
create_subtask_index_array,
create_subtasks_dataframe,
save_subtasks,
)
class TestSubtaskDataset:
@@ -188,3 +195,164 @@ class TestSubtaskEdgeCases:
)
else:
subtask_map[idx] = subtask
class TestCreateSubtasksDataframe:
"""Tests for create_subtasks_dataframe in utils."""
def test_empty_annotations(self):
"""Empty annotations produce empty DataFrame and empty mapping."""
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe({})
assert len(subtasks_df) == 0
assert list(subtasks_df.columns) == ["subtask_index"]
assert skill_to_subtask_idx == {}
def test_single_episode_single_skill(self):
"""Single episode with one skill produces one row and correct mapping."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Pick",
skills=[Skill("pick", 0.0, 1.0)],
),
}
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
assert len(subtasks_df) == 1
assert subtasks_df.index.tolist() == ["pick"]
assert subtasks_df.loc["pick", "subtask_index"] == 0
assert skill_to_subtask_idx == {"pick": 0}
def test_multiple_episodes_overlapping_skills(self):
"""Multiple episodes with overlapping skill names yield unique sorted skills."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Ep0",
skills=[
Skill("place", 0.0, 0.5),
Skill("pick", 0.5, 1.0),
],
),
1: EpisodeSkills(
episode_index=1,
description="Ep1",
skills=[Skill("pick", 0.0, 1.0)],
),
}
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
# Sorted order: pick, place
assert subtasks_df.index.tolist() == ["pick", "place"]
assert int(subtasks_df.loc["pick", "subtask_index"]) == 0
assert int(subtasks_df.loc["place", "subtask_index"]) == 1
assert skill_to_subtask_idx["pick"] == 0
assert skill_to_subtask_idx["place"] == 1
def test_skills_sorted_alphabetically(self):
"""Subtask rows are in alphabetical order by skill name."""
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Ep",
skills=[
Skill("z_final", 0.0, 0.33),
Skill("a_first", 0.33, 0.66),
Skill("m_mid", 0.66, 1.0),
],
),
}
subtasks_df, _ = create_subtasks_dataframe(annotations)
assert subtasks_df.index.tolist() == ["a_first", "m_mid", "z_final"]
assert list(subtasks_df["subtask_index"]) == [0, 1, 2]
class TestSaveSubtasks:
"""Tests for save_subtasks in utils."""
def test_save_subtasks_creates_file(self, tmp_path):
"""save_subtasks writes meta/subtasks.parquet and creates parent dir."""
subtasks_df = pd.DataFrame(
[{"subtask": "pick", "subtask_index": 0}, {"subtask": "place", "subtask_index": 1}]
).set_index("subtask")
save_subtasks(subtasks_df, tmp_path)
out = tmp_path / "meta" / "subtasks.parquet"
assert out.exists()
read_df = pd.read_parquet(out)
pd.testing.assert_frame_equal(read_df.reset_index(), subtasks_df.reset_index())
def test_save_subtasks_content_matches(self, tmp_path):
"""Saved parquet round-trips with same content."""
subtasks_df = pd.DataFrame(
[{"subtask": "a", "subtask_index": 0}, {"subtask": "b", "subtask_index": 1}]
).set_index("subtask")
save_subtasks(subtasks_df, tmp_path)
read_df = pd.read_parquet(tmp_path / "meta" / "subtasks.parquet")
assert read_df.index.tolist() == subtasks_df.index.tolist()
assert list(read_df["subtask_index"]) == list(subtasks_df["subtask_index"])
class TestCreateSubtaskIndexArray:
"""Tests for create_subtask_index_array in utils."""
@pytest.fixture
def dataset_with_episodes(self, tmp_path, empty_lerobot_dataset_factory):
"""Dataset with two episodes (10 frames each) for index-array tests."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "subtask_idx", features=features)
for _ in range(10):
dataset.add_frame({"state": torch.randn(2), "task": "Task A"})
dataset.save_episode()
for _ in range(10):
dataset.add_frame({"state": torch.randn(2), "task": "Task B"})
dataset.save_episode()
dataset.finalize()
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_unannotated_all_minus_one(self, dataset_with_episodes):
"""With no annotations, all frame indices are -1."""
skill_to_subtask_idx = {"pick": 0, "place": 1}
arr = create_subtask_index_array(dataset_with_episodes, {}, skill_to_subtask_idx)
assert len(arr) == len(dataset_with_episodes)
assert arr.dtype == np.int64
assert np.all(arr == -1)
def test_annotated_episode_assigns_by_timestamp(self, dataset_with_episodes):
"""Frames in an annotated episode get subtask index from skill time ranges."""
# Dataset uses DEFAULT_FPS=30. Episode 0: 10 frames -> timestamps 0, 1/30, ..., 9/30 (~0.3s).
# Skills: "pick" [0, 0.2), "place" [0.2, 0.5). At 30 fps: 0.2s = 6 frames, so frames 0-5 = pick, 6-9 = place.
annotations = {
0: EpisodeSkills(
episode_index=0,
description="Pick and place",
skills=[
Skill("pick", 0.0, 0.2), # frames 0-5 at 30 fps
Skill("place", 0.2, 0.5), # frames 6-9 at 30 fps
],
),
}
skill_to_subtask_idx = {"pick": 0, "place": 1}
arr = create_subtask_index_array(dataset_with_episodes, annotations, skill_to_subtask_idx)
assert len(arr) == 20
# Episode 0: from_index=0, to_index=10 at 30 fps
for i in range(6):
assert arr[i] == 0, f"frame {i} should be pick"
for i in range(6, 10):
assert arr[i] == 1, f"frame {i} should be place"
# Episode 1 not annotated
for i in range(10, 20):
assert arr[i] == -1
def test_partial_annotations_leave_others_minus_one(self, dataset_with_episodes):
"""Only annotated episodes get non -1 indices; others stay -1."""
annotations = {
1: EpisodeSkills(
episode_index=1,
description="Place only",
skills=[Skill("place", 0.0, 1.0)],
),
}
skill_to_subtask_idx = {"place": 0}
arr = create_subtask_index_array(dataset_with_episodes, annotations, skill_to_subtask_idx)
for i in range(10):
assert arr[i] == -1
for i in range(10, 20):
assert arr[i] == 0