diff --git a/openx_rlds.py b/openx_rlds.py index 24346f9..cf873a8 100644 --- a/openx_rlds.py +++ b/openx_rlds.py @@ -81,7 +81,7 @@ def transform_raw_dataset(episode, dataset_name): def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bool = True): - dataset_name = builder.name + dataset_name = Path(builder.data_dir).parent.name state_names = [f"motor_{i}" for i in range(8)] if dataset_name in OXE_DATASET_CONFIGS: