diff --git a/src/lerobot/datasets/dataset_reader.py b/src/lerobot/datasets/dataset_reader.py index 59aaa40e5..ae5934283 100644 --- a/src/lerobot/datasets/dataset_reader.py +++ b/src/lerobot/datasets/dataset_reader.py @@ -126,10 +126,53 @@ class DatasetReader: def _load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" features = get_hf_features_from_features(self._meta.features) + # Datasets annotated with the PR1 language columns may have been + # written without registering those columns in ``meta/info.json`` + # (e.g. they predate ``CODEBASE_VERSION="v3.1"`` and were + # back-filled by ``lerobot-annotate``). Probe a single parquet + # shard and graft the column features on so the strict + # ``Dataset.from_parquet`` cast doesn't fail with + # ``column names don't match``. + features = self._extend_features_with_language_columns(features) hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset + def _extend_features_with_language_columns( + self, features: datasets.Features + ) -> datasets.Features: + """Add ``language_persistent`` / ``language_events`` to ``features`` + when the underlying parquet shards declare them but the metadata + doesn't. No-op when neither column is present or both are + already registered. + """ + # Find any one parquet to peek at; bail if there are none yet + # (the dataset will fail later for an unrelated reason and we + # want that error to surface as-is). + try: + sample = next((self.root / "data").glob("*/*.parquet")) + except StopIteration: + return features + + from pyarrow import parquet as _pq # noqa: PLC0415 + + schema_names = set(_pq.read_schema(sample).names) + from .language import ( # noqa: PLC0415 + LANGUAGE_EVENTS, + LANGUAGE_PERSISTENT, + language_events_column_feature, + language_persistent_column_feature, + ) + + extra: dict[str, object] = {} + if LANGUAGE_PERSISTENT in schema_names and LANGUAGE_PERSISTENT not in features: + extra[LANGUAGE_PERSISTENT] = language_persistent_column_feature() + if LANGUAGE_EVENTS in schema_names and LANGUAGE_EVENTS not in features: + extra[LANGUAGE_EVENTS] = language_events_column_feature() + if not extra: + return features + return datasets.Features({**features, **extra}) + def _check_cached_episodes_sufficient(self) -> bool: """Check if the cached dataset contains all requested episodes and their video files.""" if self.hf_dataset is None or len(self.hf_dataset) == 0: