Remove local_files_only and use codebase_version instead of branches (#734)

This commit is contained in:
Simon Alibert
2025-02-19 08:36:32 +01:00
committed by GitHub
parent 624eaf1175
commit fbf2f2222a
18 changed files with 253 additions and 198 deletions
+32 -63
View File
@@ -13,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
@@ -31,9 +31,11 @@ import pyarrow.compute as pc
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from packaging import version
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.common.datasets.backward_compatibility import V21_MESSAGE, BackwardCompatibilityError
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -200,7 +202,7 @@ def write_task(task_index: int, task: dict, local_dir: Path):
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> dict:
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
@@ -231,7 +233,9 @@ def load_episodes_stats(local_dir: Path) -> dict:
}
def backward_compatible_episodes_stats(stats, episodes: list[int]) -> dict[str, dict[str, np.ndarray]]:
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
@@ -265,73 +269,38 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
v_check = version.parse(version_to_check)
v_current = version.parse(current_version)
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=version_to_check))
def get_hub_safe_version(repo_id: str, version: str) -> str:
def get_repo_versions(repo_id: str) -> list[version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
repo_versions = []
for ref in repo_refs:
with contextlib.suppress(version.InvalidVersion):
repo_versions.append(version.parse(ref))
logging.warning(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
return "main"
else:
return version
return repo_versions
def get_safe_revision(repo_id: str, revision: str) -> str:
"""Returns the version if available on repo, otherwise return the latest available."""
api = HfApi()
if api.revision_exists(repo_id, revision, repo_type="dataset"):
return revision
hub_versions = get_repo_versions(repo_id)
return f"v{max(hub_versions)}"
def get_hf_features_from_features(features: dict) -> datasets.Features: