From 49cb1ee7db7c37f0475a1a1ebe21b2775e54e224 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 30 Jun 2026 16:49:08 +0200 Subject: [PATCH] chore(policies): add explicit dataset dependecy to gr00t implementation --- pyproject.toml | 1 + src/lerobot/policies/groot/processor_groot.py | 9 +++++++-- src/lerobot/utils/import_utils.py | 1 + tests/policies/groot/test_groot_n1_7.py | 2 +- uv.lock | 7 +++++++ 5 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3a86563ad..329b364d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,6 +218,7 @@ groot = [ "lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[diffusers-dep]", + "lerobot[dataset]", # NOTE: processor_groot builds a LeRobotDataset for relative-action training stats "dm-tree>=0.1.8,<1.0.0", "timm>=1.0.0,<1.1.0", "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index 2cc507660..9ae85e667 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -27,7 +27,7 @@ import torchvision.transforms.v2.functional as tv_functional from einops import rearrange from torchvision.transforms import InterpolationMode -from lerobot.utils.import_utils import _transformers_available, require_package +from lerobot.utils.import_utils import _datasets_available, _transformers_available, require_package if TYPE_CHECKING or _transformers_available: from transformers import ( @@ -44,6 +44,11 @@ else: Qwen3VLProcessor = None Qwen3VLVideoProcessor = None +if TYPE_CHECKING or _datasets_available: + from lerobot.datasets.lerobot_dataset import LeRobotDataset +else: + LeRobotDataset = None + from lerobot.processor import ( AbsoluteActionsProcessorStep, AddBatchDimensionProcessorStep, @@ -811,7 +816,7 @@ def _make_relative_action_training_stats_from_dataset_meta( if dataset_meta is None or repo_id is None or root is None or fps is None: return None - from lerobot.datasets.lerobot_dataset import LeRobotDataset + require_package("datasets", extra="groot") delta_timestamps = {ACTION: [index / fps for index in config.action_delta_indices]} dataset = LeRobotDataset( diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index b0d894c04..112f76a96 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -129,6 +129,7 @@ _placo_available = is_package_available("placo") _hidapi_available = is_package_available("hidapi", import_name="hid") # Data / serialization +_datasets_available = is_package_available("datasets") _pandas_available = is_package_available("pandas") _faker_available = is_package_available("faker") diff --git a/tests/policies/groot/test_groot_n1_7.py b/tests/policies/groot/test_groot_n1_7.py index 07a9311ae..1da50349f 100644 --- a/tests/policies/groot/test_groot_n1_7.py +++ b/tests/policies/groot/test_groot_n1_7.py @@ -2256,7 +2256,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase assert kwargs["delta_timestamps"][ACTION] == [0.0, 1 / runtime_meta.fps] return _RelativeStatsDataset() - monkeypatch.setattr("lerobot.datasets.lerobot_dataset.LeRobotDataset", _fake_lerobot_dataset) + monkeypatch.setattr("lerobot.policies.groot.processor_groot.LeRobotDataset", _fake_lerobot_dataset) config._runtime_dataset_meta = runtime_meta preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=absolute_dataset_stats) diff --git a/uv.lock b/uv.lock index ed17b0945..7dd837ff2 100644 --- a/uv.lock +++ b/uv.lock @@ -2957,11 +2957,17 @@ gamepad = [ { name = "pygame" }, ] groot = [ + { name = "av" }, + { name = "datasets" }, { name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" }, { name = "diffusers" }, { name = "dm-tree" }, + { name = "jsonlines" }, + { name = "pandas" }, { name = "peft" }, + { name = "pyarrow" }, { name = "timm" }, + { name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" }, { name = "transformers" }, ] grpcio-dep = [ @@ -3240,6 +3246,7 @@ requires-dist = [ { name = "lerobot", extras = ["dataset"], marker = "extra == 'annotations'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" }, + { name = "lerobot", extras = ["dataset"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'hilserl'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" }, { name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" },