fix kuka_dataset_transform & filter

This commit is contained in:
Tavish
2025-02-21 19:51:27 +08:00
parent a96845809b
commit 43fe83852c
2 changed files with 6 additions and 14 deletions
+6 -1
View File
@@ -190,7 +190,12 @@ def create_lerobot_dataset(
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
features = generate_features_from_raw(builder, use_videos)
raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
filter_fn = lambda e: e["success"] if dataset_name == "kuka" else True
raw_dataset = (
builder.as_dataset(split="train")
.filter(filter_fn)
.map(partial(transform_raw_dataset, dataset_name=dataset_name))
)
if fps is None:
if dataset_name in OXE_DATASET_CONFIGS: