refactor(datasets): module cleanup (#3169)

This commit is contained in:
Steven Palma
2026-03-15 20:42:15 -07:00
committed by GitHub
parent a07b1d76f1
commit 7c2ec31793
9 changed files with 38 additions and 745 deletions
+17 -1
View File
@@ -13,15 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import Dataset
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.sampler import EpisodeAwareSampler
from lerobot.datasets.utils import (
hf_transform_to_torch,
)
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
"""Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}."""
episode_data_index: dict[str, list[int]] = {"from": [], "to": []}
current_episode = None
if len(hf_dataset) == 0:
return {"from": torch.tensor([]), "to": torch.tensor([])}
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
if episode_idx != current_episode:
episode_data_index["from"].append(idx)
if current_episode is not None:
episode_data_index["to"].append(idx)
current_episode = episode_idx
episode_data_index["to"].append(idx + 1)
return {k: torch.tensor(v) for k, v in episode_data_index.items()}
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{