diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ec5ac4372..43e2442d3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -22,20 +22,21 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S - Short, concrete bullets of the modifications (files/behaviour). - Short note if this introduces breaking changes and migration steps. -## How was this tested +## How was this tested (or how to run locally) - Tests added: list new tests or test files. - Manual checks / dataset runs performed. +- Instructions for the reviewer -## How to run locally (reviewer) +Example: -- Run the relevant tests: +- Ran the relevant tests: ```bash pytest -q tests/ -k ``` -- Run a quick example or CLI (if applicable): +- Reproduce with a quick example or CLI (if applicable): ```bash lerobot-train --some.option=true diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index e3ae71cc9..a75ecc121 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -20,8 +20,8 @@ on: workflow_dispatch: # Run on the 1st and 15th of every month at 09:00 UTC - schedule: - - cron: '0 2 1,15 * *' + # schedule: + # - cron: '0 2 1,15 * *' permissions: contents: read diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..cf58f6cdb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,48 @@ +# Security Policy + +## Project Status & Philosophy + +`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues. + +Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. 
We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions. + +## Reporting a Vulnerability + +To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab. + +The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance. + +#### Hugging Face Security Team + +Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps. + +#### Open Source Disclosures + +If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software. + +## Supported Versions + +Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch). + +| Version | Supported | +| -------- | --------- | +| Latest | ✅ | +| < Latest | ❌ | + +## Secure Usage Guidelines + +`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe. + +### Remote Artefacts (Weights & Policies) + +Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format. 
+ +`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots. + +To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files. + +### Remote Code + +Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code. + +Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository. diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 7e27eb93e..e3ffa6b32 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -12,23 +12,42 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu ### Setting Up the Frodobots SDK -The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how: +The robot needs the [Frodobots SDK](https://github.com/frodobots-org/earth-rovers-sdk) running on your computer. Here's how: 1. Download and install the SDK: ```bash -git clone https://github.com/Frodobots/earth-rovers-sdk.git +git clone https://github.com/frodobots-org/earth-rovers-sdk.git cd earth-rovers-sdk pip install -r requirements.txt ``` -2. Start the SDK: +2. Save Credentials: + +Write your .env variables with the SDK API key and bot name provided by the Frodobots team. 
+ +```bash +SDK_API_TOKEN=your_sdk_api_token_here +BOT_SLUG=your_bot_slug_here +CHROME_EXECUTABLE_PATH=/path/to/chrome_or_chromium +# Default value is MAP_ZOOM_LEVEL=18 https://wiki.openstreetmap.org/wiki/Zoom_levels +MAP_ZOOM_LEVEL=18 +MISSION_SLUG=your_mission_slug_here +# Image quality between 0.1 and 1.0 (default: 0.8) +# Recommended: 0.8 for better performance +IMAGE_QUALITY=0.8 +# Image format: jpeg, png or webp (default: png) +# Recommended: jpeg for better performance and lower bandwidth usage +IMAGE_FORMAT=jpeg +``` + +3. Start the SDK: ```bash hypercorn main:app --reload ``` -3. Open your web browser and go to `http://localhost:8000`, then click "Join" +4. Open your web browser and go to `http://localhost:8000`, then click "Join" The SDK gives you: diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx index ba6464460..df103d0dd 100644 --- a/docs/source/envhub.mdx +++ b/docs/source/envhub.mdx @@ -2,14 +2,32 @@ The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community. -## Overview +## What is EnvHub? -With EnvHub, you can: +EnvHub lets you create custom robotics simulation environments with your own robot models and scenarios, and make them easily usable by anyone through the LeRobot framework. -- Load environments from the Hub instantly -- Share your custom simulation tasks with the community -- Version control your environments using Git -- Distribute complex physics simulations without packaging hassles +EnvHub packages are stored on the Hugging Face Hub, and can be seamlessly pulled and used in your AI robotics projects through LeRobot with a single line of code. + +Thanks to EnvHub, you can: + +1. 
**Create and publish environments** to the Hugging Face Hub as Git repositories, and distribute complex physics simulations without packaging hassles +2. **Load environments** dynamically, without installing them as packages +3. **Version and track** environment changes using Git semantics +4. **Discover** new simulation tasks shared by the community + +This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, or create your own custom robot and environment without worrying about dependency conflicts or complex installation procedures. + +When you create an EnvHub package, you can build anything you want inside it and use any simulation tool you like: this is your own space to play with. The only requirement is that the package contains an `env.py` file that defines the environment and allows LeRobot to load and use your EnvHub package. + +This `env.py` file needs to expose a small API so LeRobot can load and run it. In particular, you must provide a `make_env(n_envs: int = 1, use_async_envs: bool = False)` or `make_env(n_envs: int = 1, use_async_envs: bool = False, cfg: EnvConfig)` function, which is the main entry point for LeRobot. It should return one of: + +- A `gym.vector.VectorEnv` (most common) +- A single `gym.Env` (will be automatically wrapped) +- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks) + +You can also pass an `EnvConfig` object to `make_env` to configure the environment (e.g. the number of environments, task, camera name, initial states, control mode, episode length, etc.). + +Finally, your environment must implement the standard `gym.vector.VectorEnv` interface so it works with LeRobot, including methods like `reset` and `step`. ## Quick Start @@ -29,17 +47,6 @@ env = make_env("lerobot/cartpole-env", trust_remote_code=True) hash for reproducibility and security. -## What is EnvHub? 
- -EnvHub is a framework that allows researchers and developers to: - -1. **Publish environments** to the Hugging Face Hub as Git repositories -2. **Load environments** dynamically without installing them as packages -3. **Version and track** environment changes using Git semantics -4. **Discover** new simulation tasks shared by the community - -This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures. - ## Repository Structure To make your environment loadable from the Hub, your repository must contain at minimum: diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx index 29e16ea0a..9e662604e 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh # Local-only: Save to a custom output directory (no hub push) lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video # Save with new repo_id (local storage) lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video + --operation.type convert_image_to_video # Convert and push to Hugging Face Hub lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --push_to_hub true # Convert with custom video codec and quality settings lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.vcodec libsvtav1 \ --operation.pix_fmt yuv420p \ @@ -124,16 +124,23 @@ lerobot-edit-dataset 
\ # Convert only specific episodes lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.episode_indices "[0, 1, 2, 5, 10]" # Convert with multiple workers for parallel processing lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.num_workers 8 + +# For memory-constrained systems, users can now specify limits: +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_image_to_video \ + --operation.max_episodes_per_batch 50 \ + --operation.max_frames_per_batch 10000 ``` **Parameters:** diff --git a/pyproject.toml b/pyproject.toml index 067cb1df8..fa4b22bdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", "pyserial>=3.5,<4.0", - "wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf) + "wandb>=0.24.0,<0.25.0", "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency @@ -97,7 +97,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.57.1,<5.0.0"] -grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) +grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index 081db0504..56910e67f 100644 --- a/src/lerobot/async_inference/constants.py 
+++ b/src/lerobot/async_inference/constants.py @@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS DEFAULT_OBS_QUEUE_TIMEOUT = 2 # All action chunking policies -SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] +SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"] # TODO: Add all other robots SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 455caf0fe..94ffe602e 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -19,6 +19,7 @@ import logging import shutil from pathlib import Path +import datasets import pandas as pd import tqdm @@ -32,6 +33,7 @@ from lerobot.datasets.utils import ( DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, get_file_size_in_mb, + get_hf_features_from_features, get_parquet_file_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, @@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si } unique_chunk_file_ids = sorted(unique_chunk_file_ids) + contains_images = len(dst_meta.image_keys) > 0 + + # retrieve features schema for proper image typing in parquet + hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx ) - df = pd.read_parquet(src_path) + if contains_images: + # Use HuggingFace datasets to read source data to preserve image format + src_ds = datasets.Dataset.from_parquet(str(src_path)) + df = src_ds.to_pandas() + else: + df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) data_idx = append_or_create_parquet_file( @@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, 
chunk_si data_files_size_in_mb, chunk_size, DEFAULT_DATA_PATH, - contains_images=len(dst_meta.image_keys) > 0, + contains_images=contains_images, aggr_root=dst_meta.root, + hf_features=hf_features, ) return data_idx @@ -488,6 +500,7 @@ def append_or_create_parquet_file( default_path: str, contains_images: bool = False, aggr_root: Path = None, + hf_features: datasets.Features | None = None, ): """Appends data to an existing parquet file or creates a new one based on size constraints. @@ -503,6 +516,7 @@ def append_or_create_parquet_file( default_path: Format string for generating file paths. contains_images: Whether the data contains images requiring special handling. aggr_root: Root path for the aggregated dataset. + hf_features: Optional HuggingFace Features schema for proper image typing. Returns: dict: Updated index dictionary with current chunk and file indices. @@ -512,7 +526,7 @@ def append_or_create_parquet_file( if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) if contains_images: - to_parquet_with_hf_images(df, dst_path) + to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) return idx @@ -527,12 +541,17 @@ def append_or_create_parquet_file( final_df = df target_path = new_path else: - existing_df = pd.read_parquet(dst_path) + if contains_images: + # Use HuggingFace datasets to read existing data to preserve image format + existing_ds = datasets.Dataset.from_parquet(str(dst_path)) + existing_df = existing_ds.to_pandas() + else: + existing_df = pd.read_parquet(dst_path) final_df = pd.concat([existing_df, df], ignore_index=True) target_path = dst_path if contains_images: - to_parquet_with_hf_images(final_df, target_path) + to_parquet_with_hf_images(final_df, target_path, features=hf_features) else: final_df.to_parquet(target_path) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 2fb68dca1..e2928e2a6 100644 --- 
a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -26,6 +26,7 @@ This module provides utilities for: import logging import shutil from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import datasets @@ -51,7 +52,8 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.datasets.video_utils import encode_video_frames, get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: @@ -1083,3 +1085,561 @@ def _copy_episodes_metadata_and_stats( else: if src_dataset.meta.stats: write_stats(src_dataset.meta.stats, dst_meta.root) + + +def _save_episode_images_for_video( + dataset: LeRobotDataset, + imgs_dir: Path, + img_key: str, + episode_index: int, + num_workers: int = 4, +) -> None: + """Save images from a specific episode and camera to disk for video encoding. 
+ + Args: + dataset: The LeRobot dataset to extract images from + imgs_dir: Directory to save images to + img_key: The image key (camera) to extract + episode_index: Index of the episode to save + num_workers: Number of threads for parallel image saving + """ + # Create directory + imgs_dir.mkdir(parents=True, exist_ok=True) + + # Get dataset without torch format for PIL image access + hf_dataset = dataset.hf_dataset.with_format(None) + + # Select only this camera's images + imgs_dataset = hf_dataset.select_columns(img_key) + + # Get episode start and end indices + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] + + # Get all items for this episode + episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) + + # Define function to save a single image + def save_single_image(i_item_tuple): + i, item = i_item_tuple + img = item[img_key] + # Use frame-XXXXXX.png format to match encode_video_frames expectations + img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100) + return i + + # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) + items = list(enumerate(episode_dataset)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(save_single_image, item) for item in items] + for future in as_completed(futures): + future.result() # This will raise any exceptions that occurred + + +def _save_batch_episodes_images( + dataset: LeRobotDataset, + imgs_dir: Path, + img_key: str, + episode_indices: list[int], + num_workers: int = 4, +) -> list[float]: + """Save images from multiple episodes to disk for batch video encoding. 
+ + Args: + dataset: The LeRobot dataset to extract images from + imgs_dir: Directory to save images to + img_key: The image key (camera) to extract + episode_indices: List of episode indices to save + num_workers: Number of threads for parallel image saving + + Returns: + List of episode durations in seconds + """ + imgs_dir.mkdir(parents=True, exist_ok=True) + hf_dataset = dataset.hf_dataset.with_format(None) + imgs_dataset = hf_dataset.select_columns(img_key) + + # Define function to save a single image with global frame index + # Defined once outside the loop to avoid repeated closure creation + def save_single_image(i_item_tuple, base_frame_idx, img_key_param): + i, item = i_item_tuple + img = item[img_key_param] + # Use global frame index for naming + img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100) + return i + + episode_durations = [] + frame_idx = 0 + + for ep_idx in episode_indices: + # Get episode range + from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx] + to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx] + episode_length = to_idx - from_idx + episode_durations.append(episode_length / dataset.fps) + + # Get episode images + episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) + + # Save images + items = list(enumerate(episode_dataset)) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items] + for future in as_completed(futures): + future.result() + + frame_idx += episode_length + + return episode_durations + + +def _iter_episode_batches( + episode_indices: list[int], + episode_lengths: dict[int, int], + size_per_frame_mb: float, + video_file_size_limit: float, + max_episodes: int | None, + max_frames: int | None, +): + """Generator that yields batches of episode indices for video encoding. 
+ + Groups episodes into batches that respect size and memory constraints: + - Stays under video file size limit + - Respects maximum episodes per batch (if specified) + - Respects maximum frames per batch (if specified) + + Args: + episode_indices: List of episode indices to batch + episode_lengths: Dictionary mapping episode index to episode length + size_per_frame_mb: Estimated size per frame in MB + video_file_size_limit: Maximum video file size in MB + max_episodes: Maximum number of episodes per batch (None = no limit) + max_frames: Maximum number of frames per batch (None = no limit) + + Yields: + List of episode indices for each batch + """ + batch_episodes = [] + estimated_size = 0.0 + total_frames = 0 + + for ep_idx in episode_indices: + ep_length = episode_lengths[ep_idx] + ep_estimated_size = ep_length * size_per_frame_mb + + # we check if adding this episode would exceed any constraint + would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit + would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes + would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames + + if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames): + # yield current batch before adding this episode + yield batch_episodes + # start a new batch with current episode + batch_episodes = [ep_idx] + estimated_size = ep_estimated_size + total_frames = ep_length + else: + # add to current batch + batch_episodes.append(ep_idx) + estimated_size += ep_estimated_size + total_frames += ep_length + + # yield final batch if not empty + if batch_episodes: + yield batch_episodes + + +def _estimate_frame_size_via_calibration( + dataset: LeRobotDataset, + img_key: str, + episode_indices: list[int], + temp_dir: Path, + fps: int, + vcodec: str, + pix_fmt: str, + g: int, + crf: int, + fast_decode: int, + num_calibration_frames: int = 30, +) -> float: + """Estimate 
MB per frame by encoding a small calibration sample. + + Encodes a representative sample of frames using the exact codec parameters + to measure actual compression ratio, which is more accurate than heuristics. + + Args: + dataset: Source dataset with images. + img_key: Image key to calibrate (e.g., "observation.images.top"). + episode_indices: List of episode indices being processed. + temp_dir: Temporary directory for calibration files. + fps: Frames per second for video encoding. + vcodec: Video codec (libsvtav1, h264, hevc). + pix_fmt: Pixel format (yuv420p, etc.). + g: GOP size (group of pictures). + crf: Constant Rate Factor (quality). + fast_decode: Fast decode tuning parameter. + num_calibration_frames: Number of frames to use for calibration (default: 30). + + Returns: + Estimated size in MB per frame based on actual encoding. + """ + calibration_dir = temp_dir / "calibration" / img_key + calibration_dir.mkdir(parents=True, exist_ok=True) + + try: + # Select a representative episode (prefer middle episode if available) + calibration_ep_idx = episode_indices[len(episode_indices) // 2] + + # Get episode range + from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx] + to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx] + episode_length = to_idx - from_idx + + # Use up to num_calibration_frames from this episode + num_frames = min(num_calibration_frames, episode_length) + + # Get frames from dataset + hf_dataset = dataset.hf_dataset.with_format(None) + sample_indices = range(from_idx, from_idx + num_frames) + + # Save calibration frames + for i, idx in enumerate(sample_indices): + img = hf_dataset[idx][img_key] + img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100) + + # Encode calibration video + calibration_video_path = calibration_dir / "calibration.mp4" + encode_video_frames( + imgs_dir=calibration_dir, + video_path=calibration_video_path, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, 
+ fast_decode=fast_decode, + overwrite=True, + ) + + # Measure actual compressed size + video_size_bytes = calibration_video_path.stat().st_size + video_size_mb = video_size_bytes / BYTES_PER_MIB + size_per_frame_mb = video_size_mb / num_frames + + logging.info( + f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB " + f"= {size_per_frame_mb:.4f} MB/frame for {img_key}" + ) + + return size_per_frame_mb + + finally: + # Clean up calibration files + if calibration_dir.exists(): + shutil.rmtree(calibration_dir) + + +def _copy_data_without_images( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_indices: list[int], + img_keys: list[str], +) -> None: + """Copy data files without image columns. + + Args: + src_dataset: Source dataset + dst_meta: Destination metadata + episode_indices: Episodes to include + img_keys: Image keys to remove + """ + from lerobot.datasets.utils import DATA_DIR + + data_dir = src_dataset.root / DATA_DIR + parquet_files = sorted(data_dir.glob("*/*.parquet")) + + if not parquet_files: + raise ValueError(f"No parquet files found in {data_dir}") + + episode_set = set(episode_indices) + + for src_path in tqdm(parquet_files, desc="Processing data files"): + df = pd.read_parquet(src_path).reset_index(drop=True) + + # Filter to only include selected episodes + df = df[df["episode_index"].isin(episode_set)].copy() + + if len(df) == 0: + continue + + # Remove image columns + columns_to_drop = [col for col in img_keys if col in df.columns] + if columns_to_drop: + df = df.drop(columns=columns_to_drop) + + # Get chunk and file indices from path + relative_path = src_path.relative_to(src_dataset.root) + chunk_dir = relative_path.parts[1] + file_name = relative_path.parts[2] + chunk_idx = int(chunk_dir.split("-")[1]) + file_idx = int(file_name.split("-")[1].split(".")[0]) + + # Write to destination without pandas index + dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet" + 
dst_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(dst_path, index=False) + + +# Video conversion constants +BYTES_PER_KIB = 1024 +BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB + + +def convert_image_to_video_dataset( + dataset: LeRobotDataset, + output_dir: Path, + repo_id: str | None = None, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int = 2, + crf: int = 30, + fast_decode: int = 0, + episode_indices: list[int] | None = None, + num_workers: int = 4, + max_episodes_per_batch: int | None = None, + max_frames_per_batch: int | None = None, +) -> LeRobotDataset: + """Convert an image dataset to a video dataset. + + Creates a new LeRobotDataset with images encoded as videos, following the proper + LeRobot dataset structure with videos stored in chunked MP4 files. + + Args: + dataset: The source LeRobot dataset with images + output_dir: Directory to save the new video dataset + repo_id: Repository ID for the new dataset (default: original_id + "_video") + vcodec: Video codec (default: libsvtav1) + pix_fmt: Pixel format (default: yuv420p) + g: Group of pictures size (default: 2) + crf: Constant rate factor (default: 30) + fast_decode: Fast decode tuning (default: 0) + episode_indices: List of episode indices to convert (None = all episodes) + num_workers: Number of threads for parallel processing (default: 4) + max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit) + max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit) + + Returns: + New LeRobotDataset with images encoded as videos + """ + # Check that it's an image dataset + if len(dataset.meta.video_keys) > 0: + raise ValueError( + f"This operation is for image datasets only. 
Video dataset provided: {dataset.repo_id}" + ) + + # Get all image keys + hf_dataset = dataset.hf_dataset.with_format(None) + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] + + if len(img_keys) == 0: + raise ValueError(f"No image keys found in dataset {dataset.repo_id}") + + # Determine which episodes to process + if episode_indices is None: + episode_indices = list(range(dataset.meta.total_episodes)) + + if repo_id is None: + repo_id = f"{dataset.repo_id}_video" + + logging.info( + f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" + ) + logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") + + # Create new features dict, converting image features to video features + new_features = {} + for key, value in dataset.meta.features.items(): + if key not in img_keys: + new_features[key] = value + else: + # Convert image key to video format + new_features[key] = value.copy() + new_features[key]["dtype"] = "video" # Change dtype from "image" to "video" + # Video info will be updated after episodes are encoded + + # Create new metadata for video dataset + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=True, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + # Create temporary directory for image extraction + temp_dir = output_dir / "temp_images" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Process all episodes and batch encode videos + # Use dictionary for O(1) episode metadata lookups instead of O(n) linear search + all_episode_metadata = {} + fps = int(dataset.fps) + + try: + # Build episode metadata entries first + logging.info("Building episode metadata...") + cumulative_frame_idx = 0 + for ep_idx in 
episode_indices: + src_episode = dataset.meta.episodes[ep_idx] + ep_length = src_episode["length"] + ep_meta = { + "episode_index": ep_idx, + "length": ep_length, + "dataset_from_index": cumulative_frame_idx, + "dataset_to_index": cumulative_frame_idx + ep_length, + } + if "data/chunk_index" in src_episode: + ep_meta["data/chunk_index"] = src_episode["data/chunk_index"] + ep_meta["data/file_index"] = src_episode["data/file_index"] + all_episode_metadata[ep_idx] = ep_meta + cumulative_frame_idx += ep_length + + # Process each camera and batch encode multiple episodes together + video_file_size_limit = new_meta.video_files_size_in_mb + + # Pre-compute episode lengths for batching + episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices} + + for img_key in tqdm(img_keys, desc="Processing cameras"): + # Estimate size per frame by encoding a small calibration sample + # This provides accurate compression ratio for the specific codec parameters + size_per_frame_mb = _estimate_frame_size_via_calibration( + dataset=dataset, + img_key=img_key, + episode_indices=episode_indices, + temp_dir=temp_dir, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + ) + + logging.info(f"Processing camera: {img_key}") + chunk_idx, file_idx = 0, 0 + cumulative_timestamp = 0.0 + + # Process episodes in batches to stay under size limit + for batch_episodes in _iter_episode_batches( + episode_indices=episode_indices, + episode_lengths=episode_lengths, + size_per_frame_mb=size_per_frame_mb, + video_file_size_limit=video_file_size_limit, + max_episodes=max_episodes_per_batch, + max_frames=max_frames_per_batch, + ): + total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes) + logging.info( + f" Encoding batch of {len(batch_episodes)} episodes " + f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames" + ) + + # Save images for all episodes in this batch + imgs_dir = temp_dir 
/ f"batch_{chunk_idx}_{file_idx}" / img_key + episode_durations = _save_batch_episodes_images( + dataset=dataset, + imgs_dir=imgs_dir, + img_key=img_key, + episode_indices=batch_episodes, + num_workers=num_workers, + ) + + # Encode all batched episodes into single video + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=chunk_idx, file_index=file_idx + ) + video_path.parent.mkdir(parents=True, exist_ok=True) + + encode_video_frames( + imgs_dir=imgs_dir, + video_path=video_path, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + overwrite=True, + ) + + # Clean up temporary images + shutil.rmtree(imgs_dir) + + # Update metadata for each episode in the batch + for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True): + from_timestamp = cumulative_timestamp + to_timestamp = cumulative_timestamp + duration + cumulative_timestamp = to_timestamp + + # Find episode metadata entry and add video metadata (O(1) dictionary lookup) + ep_meta = all_episode_metadata[ep_idx] + ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx + ep_meta[f"videos/{img_key}/file_index"] = file_idx + ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp + ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp + + # Move to next video file for next batch + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size) + cumulative_timestamp = 0.0 + + # Copy and transform data files (removing image columns) + _copy_data_without_images(dataset, new_meta, episode_indices, img_keys) + + # Save episode metadata + episodes_df = pd.DataFrame(list(all_episode_metadata.values())) + episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet" + episodes_path.parent.mkdir(parents=True, exist_ok=True) + episodes_df.to_parquet(episodes_path, index=False) + + # Update metadata info + new_meta.info["total_episodes"] = len(episode_indices) + 
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values()) + new_meta.info["total_tasks"] = dataset.meta.total_tasks + new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} + + # Update video info for all image keys (now videos) + # We need to manually set video info since update_video_info() checks video_keys first + for img_key in img_keys: + if not new_meta.features[img_key].get("info", None): + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=0, file_index=0 + ) + new_meta.info["features"][img_key]["info"] = get_video_info(video_path) + + write_info(new_meta.info, new_meta.root) + + # Copy stats and tasks + if dataset.meta.stats is not None: + # Remove image stats + new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys} + write_stats(new_stats, new_meta.root) + + if dataset.meta.tasks is not None: + write_tasks(dataset.meta.tasks, new_meta.root) + + finally: + # Clean up temporary directory + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + logging.info(f"Completed converting {dataset.repo_id} to video format") + logging.info(f"New dataset saved to: {output_dir}") + + # Return new dataset + return LeRobotDataset(repo_id=repo_id, root=output_dir) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 7231bc78d..5c8df37e3 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -941,17 +941,30 @@ class LeRobotDataset(torch.utils.data.Dataset): else: return get_hf_features_from_features(self.features) - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices( + self, abs_idx: int, ep_idx: int + ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: + """Compute query indices for delta timestamps. + + Args: + abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). 
+ ep_idx: The episode index. + + Returns: + A tuple of (query_indices, padding) where: + - query_indices: Dict mapping keys to lists of absolute indices to query + - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions + """ ep = self.meta.episodes[ep_idx] ep_start = ep["dataset_from_index"] ep_end = ep["dataset_to_index"] query_indices = { - key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] + key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx] + [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] ) for key, delta_idx in self.delta_indices.items() } @@ -1043,10 +1056,12 @@ class LeRobotDataset(torch.utils.data.Dataset): self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() + # Use the absolute index from the dataset for delta timestamp calculations + abs_idx = item["index"].item() query_indices = None if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(idx, ep_idx) + query_indices, padding = self._get_query_indices(abs_idx, ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): @@ -1516,7 +1531,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index = self.episode_buffer["episode_index"] if isinstance(episode_index, np.ndarray): episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] - for cam_key in self.meta.camera_keys: + for cam_key in self.meta.image_keys: img_dir = self._get_image_file_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 
91203bb22..013fb34df 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -1188,12 +1188,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: ) -def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: +def to_parquet_with_hf_images( + df: pandas.DataFrame, path: Path, features: datasets.Features | None = None +) -> None: """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. This way, it can be loaded by HF dataset and correctly formatted images are returned. + + Args: + df: DataFrame to write to parquet. + path: Path to write the parquet file. + features: Optional HuggingFace Features schema. If provided, ensures image columns + are properly typed as Image() in the parquet schema. """ # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only - datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) + ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) + ds.to_parquet(path) def item_to_torch(item: dict) -> dict: diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 74882ad18..96c5cf102 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -293,9 +293,9 @@ class LiberoEnv(gym.Env): def reset(self, seed=None, **kwargs): super().reset(seed=seed) self._env.seed(seed) - if self.init_states and self._init_states is not None: - self._env.set_init_state(self._init_states[self._init_state_id]) raw_obs = self._env.reset() + if self.init_states and self._init_states is not None: + raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. 
diff --git a/src/lerobot/motors/feetech/tables.py b/src/lerobot/motors/feetech/tables.py index 91e844a72..56500e527 100644 --- a/src/lerobot/motors/feetech/tables.py +++ b/src/lerobot/motors/feetech/tables.py @@ -205,6 +205,7 @@ MODEL_BAUDRATE_TABLE = { # Sign-Magnitude encoding bits STS_SMS_SERIES_ENCODINGS_TABLE = { + "Present_Load": 10, "Homing_Offset": 11, "Goal_Position": 15, "Goal_Velocity": 15, diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 17eaa8063..91bee994a 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -32,7 +32,7 @@ import serial from deepdiff import DeepDiff from tqdm import tqdm -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up NameOrID: TypeAlias = str | int @@ -411,6 +411,7 @@ class MotorsBus(abc.ABC): """bool: `True` if the underlying serial port is open.""" return self.port_handler.is_open + @check_if_already_connected def connect(self, handshake: bool = True) -> None: """Open the serial port and initialise communication. @@ -422,10 +423,6 @@ class MotorsBus(abc.ABC): DeviceAlreadyConnectedError: The port is already open. ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." - ) self._connect(handshake) self.set_timeout() @@ -447,6 +444,7 @@ class MotorsBus(abc.ABC): def _handshake(self) -> None: pass + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """Close the serial port (optionally disabling torque first). @@ -455,10 +453,6 @@ class MotorsBus(abc.ABC): closing the port. 
This can prevent damaging motors if they are left applying resisting torque after disconnect. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first." - ) if disable_torque: self.port_handler.clearPort() @@ -907,6 +901,7 @@ class MotorsBus(abc.ABC): """ pass + @check_if_not_connected def read( self, data_name: str, @@ -927,10 +922,6 @@ class MotorsBus(abc.ABC): Returns: Value: Raw or normalised value depending on *normalize*. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) id_ = self.motors[motor].id model = self.motors[motor].model @@ -981,6 +972,7 @@ class MotorsBus(abc.ABC): return value, comm, error + @check_if_not_connected def write( self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 ) -> None: @@ -999,10 +991,6 @@ class MotorsBus(abc.ABC): normalize (bool, optional): Enable or disable normalisation. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) id_ = self.motors[motor].id model = self.motors[motor].model @@ -1044,6 +1032,7 @@ class MotorsBus(abc.ABC): return comm, error + @check_if_not_connected def sync_read( self, data_name: str, @@ -1063,10 +1052,6 @@ class MotorsBus(abc.ABC): Returns: dict[str, Value]: Mapping *motor name → value*. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." 
- ) self._assert_protocol_is_compatible("sync_read") @@ -1139,6 +1124,7 @@ class MotorsBus(abc.ABC): # for id_ in motor_ids: # value = self.sync_reader.getData(id_, address, length) + @check_if_not_connected def sync_write( self, data_name: str, @@ -1160,10 +1146,6 @@ class MotorsBus(abc.ABC): normalize (bool, optional): If `True` (default) convert values from the user range to raw units. num_retry (int, optional): Retry attempts. Defaults to `0`. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) ids_values = self._get_ids_values_dict(values) models = [self._id_to_model(id_) for id_ in ids_values] diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index fd9baa9b1..9a479b8f9 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -32,16 +32,22 @@ Notes: from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below. 
""" +import builtins import os from collections import deque +from pathlib import Path +from typing import TypeVar import torch from torch import Tensor +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION +from lerobot.utils.constants import ACTION, OBS_IMAGES + +T = TypeVar("T", bound="GrootPolicy") class GrootPolicy(PreTrainedPolicy): @@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy): """Reset policy state when environment resets.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: GrootConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Load Groot policy from pretrained model. + + Handles two cases: + 1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model + 2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors + + Args: + pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint + config: Optional GrootConfig. 
If None, loads from checkpoint or creates default + force_download: Force download even if cached + resume_download: Resume interrupted download + proxies: Proxy settings + token: HuggingFace authentication token + cache_dir: Cache directory path + local_files_only: Only use local files + revision: Specific model revision + strict: Strict state dict loading + **kwargs: Additional arguments (passed to config) + + Returns: + Initialized GrootPolicy instance with loaded model + """ + from huggingface_hub import hf_hub_download + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE + from huggingface_hub.errors import HfHubHTTPError + + print( + "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n" + f"Loading pretrained model from: {pretrained_name_or_path}" + ) + + model_id = str(pretrained_name_or_path) + is_finetuned_checkpoint = False + + # Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors) + try: + if os.path.isdir(model_id): + is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE)) + else: + # Try to download the safetensors file to check if it exists + try: + hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=False, # Just check, don't force download + proxies=proxies, + token=token, + local_files_only=local_files_only, + ) + is_finetuned_checkpoint = True + except HfHubHTTPError: + is_finetuned_checkpoint = False + except Exception: + is_finetuned_checkpoint = False + + if is_finetuned_checkpoint: + # This is a fine-tuned LeRobot checkpoint - use parent class loading + print("Detected fine-tuned LeRobot checkpoint, loading with state dict...") + return super().from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + config=config, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + 
local_files_only=local_files_only, + revision=revision, + strict=strict, + **kwargs, + ) + + # This is a base GR00T model - load it fresh + print("Detected base GR00T model, loading from HuggingFace...") + + if config is None: + # Create default config with the pretrained path + config = GrootConfig(base_model_path=str(pretrained_name_or_path)) + + # Add minimal visual feature required for validation + # validate_features() will automatically add state and action features + # These are placeholders - actual robot features come from the preprocessor + if not config.input_features: + config.input_features = { + f"{OBS_IMAGES}.camera": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), # Default image size from config + ), + } + else: + # Override the base_model_path with the provided path + config.base_model_path = str(pretrained_name_or_path) + + # Pass through any additional config overrides from kwargs + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + + # Create a fresh policy instance - this will automatically load the GR00T model + # in __init__ via _create_groot_model() + policy = cls(config) + + policy.eval() + return policy + def get_optim_params(self) -> dict: return self.parameters() diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 0445d6c00..58b5dc07b 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy): loss = losses.mean() loss_dict["loss"] = loss.item() return loss, loss_dict + + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for PI0 fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": 
target_modules, + "modules_to_save": [], + } diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 11d8b4d68..104ec63bf 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy): loss = losses.mean() loss_dict["loss"] = loss.item() return loss, loss_dict + + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for PI0.5 fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": target_modules, + "modules_to_save": [], + } diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index a1499d077..e730b78a7 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -13,6 +13,7 @@ # limitations under the License. import abc import builtins +import dataclasses import logging import os from importlib.resources import files @@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): card = ModelCard.from_template(card_data, template_str=template_card) card.validate() return card + + def wrap_with_peft( + self, + peft_config=None, + peft_cli_overrides: dict | None = None, + ) -> "PreTrainedPolicy": + """ + Wrap this policy with PEFT adapters for parameter-efficient fine-tuning. + + This method is the single entry point for PEFT integration. Subclasses should + override `_get_default_peft_targets()` to provide default target modules, and + `_validate_peft_config()` for policy-specific validation. + + Args: + peft_config: Optional PEFT adapter configuration (e.g., LoraConfig). + If provided, used directly (with CLI overrides applied). 
+ peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.) + These are merged with policy defaults to build the final config. + """ + from peft import get_peft_model + + # If user provided a complete config, use it directly (with overrides) + if peft_config is not None: + final_config = peft_config + if peft_cli_overrides: + final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides) + else: + # Build config from defaults + CLI overrides + final_config = self._build_peft_config(peft_cli_overrides or {}) + + # Validate the configuration + self._validate_peft_config(final_config) + + # Freeze base parameters, only adapter params will be trained + for p in self.parameters(): + p.requires_grad_(False) + + # Store pretrained path for PEFT's base_model_name_or_path + if self.config.pretrained_path: + self.name_or_path = str(self.config.pretrained_path) + + # Wrap with PEFT + peft_model = get_peft_model(self, final_config) + + # Mark config as using PEFT for proper loading later + peft_model.config.use_peft = True + + logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})") + return peft_model + + def _get_default_peft_targets(self) -> dict[str, any] | None: + """ + Return default PEFT target modules for this policy. + + Override this in subclasses to provide policy-specific defaults. These defaults + are PEFT-method agnostic - they only specify which modules to target. + + """ + return None + + def _validate_peft_config(self, peft_config) -> None: + """ + Validate the PEFT configuration for this policy. + + Override this in subclasses to add policy-specific validation or warnings. + The default implementation checks that a pretrained_path exists. + + Args: + peft_config: The PEFT configuration to validate. + + Raises: + ValueError: If the configuration is invalid. + """ + if not self.config.pretrained_path: + raise ValueError( + "Training from scratch using PEFT is unlikely to yield good results. 
" + "Supply a `policy.pretrained_path` to fine-tune an existing model." + ) + + def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict: + """ + Preprocess CLI overrides: rename keys and handle method-specific init_type. + + Args: + cli_overrides: Dict of CLI options (will be copied, not mutated). + peft_method_type: The PeftType enum value for the PEFT method. + + Returns: + Preprocessed dict with renamed keys and init_type mapped to method-specific key. + """ + from peft import PeftType + + cli_overrides = cli_overrides.copy() + + # Handle the full_training_modules -> modules_to_save rename + if "full_training_modules" in cli_overrides: + cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules") + + # Remove method_type as it's handled separately + cli_overrides.pop("method_type", None) + + # Handle init_type specially based on PEFT method + init_type = cli_overrides.pop("init_type", None) + if init_type is not None: + if peft_method_type == PeftType.LORA: + cli_overrides["init_lora_weights"] = init_type + elif peft_method_type == PeftType.MISS: + cli_overrides["init_weights"] = init_type + else: + raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.") + + return cli_overrides + + def _build_peft_config(self, cli_overrides: dict): + """Build a PEFT config from policy defaults and CLI overrides.""" + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType + + # Determine PEFT method type (default to LORA) + method_type_str = cli_overrides.get("method_type") or "lora" + peft_method_type = PeftType[method_type_str.upper()] + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] + + # Preprocess CLI overrides + cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type) + + # Start with policy defaults, apply CLI overrides + config_dict = dict(self._get_default_peft_targets() or {}) + for key, value in cli_overrides.items(): + if value is not None: 
+ config_dict[key] = value + + # Ensure we have target_modules + if not config_dict.get("target_modules"): + raise ValueError( + f"Policy '{self.name}' does not define default target_modules. " + "Please pass --peft.target_modules explicitly." + ) + + return peft_config_cls(**config_dict) + + def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict): + """Apply CLI overrides to an existing PEFT config.""" + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType + + # Get method type from existing config or CLI override + method_type_str = cli_overrides.get("method_type") + if method_type_str: + peft_method_type = PeftType[method_type_str.upper()] + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] + else: + peft_method_type = PeftType(peft_config.peft_type) + peft_config_cls = type(peft_config) + + # Preprocess CLI overrides + cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type) + + # Start with existing config, apply CLI overrides + config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")} + for key, value in cli_overrides.items(): + if value is not None: + config_dict[key] = value + + return peft_config_cls(**config_dict) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index f998661f9..c611e9ba2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy): actions = pad_vector(batch[ACTION], self.config.max_action_dim) return actions + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for SmolVLA fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))" + return { + 
"target_modules": target_modules, + "modules_to_save": [], + } + + def _validate_peft_config(self, peft_config) -> None: + """Validate PEFT configuration for SmolVLA.""" + super()._validate_peft_config(peft_config) + if not self.config.load_vlm_weights: + import logging + + logging.warning( + "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. " + "Set `load_vlm_weights=True` to fine-tune the existing policy." + ) + def pad_tensor(tensor, max_len, pad_value=0): """ diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index 784a95577..cdf6efde1 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -24,7 +24,8 @@ import numpy as np import requests from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..robot import Robot from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig @@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot): """Check if robot is connected to SDK.""" return self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """Connect to robot via Frodobots SDK. 
@@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot): DeviceAlreadyConnectedError: If robot is already connected DeviceNotConnectedError: If cannot connect to SDK server """ - if self._is_connected: - raise DeviceAlreadyConnectedError(f"{self.name} is already connected") # Verify SDK is running and accessible try: @@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: float, } + @check_if_not_connected def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. @@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot): Robot telemetry is retrieved from /data endpoint. All SDK values are normalized to appropriate ranges for dataset recording. """ - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") observation = {} @@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot): return observation + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Send action to robot via SDK. @@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot): Actions are sent to SDK via POST /control endpoint. SDK expects commands in range [-1, 1]. """ - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") # Extract action values and convert to float linear = float(action.get(ACTION_LINEAR_VEL, 0.0)) @@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: angular, } + @check_if_not_connected def disconnect(self) -> None: """Disconnect from robot. 
@@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot): Raises: DeviceNotConnectedError: If robot is not connected """ - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") # Stop the robot before disconnecting try: diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 4be8a0b17..5fd9c4d1d 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -25,7 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -82,13 +82,12 @@ class HopeJrArm(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. 
""" - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect(handshake=False) if not self.is_calibrated and calibrate: @@ -128,10 +127,8 @@ class HopeJrArm(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position", self.other_motors) @@ -149,10 +146,8 @@ class HopeJrArm(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} # Cap goal position when too far away from present position. @@ -165,10 +160,8 @@ class HopeJrArm(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 73fb4464f..1e5c72b72 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -25,7 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_hope_jr import 
HopeJrHandConfig @@ -118,10 +118,8 @@ class HopeJrHand(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: self.calibrate() @@ -159,10 +157,8 @@ class HopeJrHand(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict = {} # Read hand position @@ -181,18 +177,14 @@ class HopeJrHand(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} self.bus.sync_write("Goal_Position", goal_pos) return action + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index a1d001ba8..fee0adba9 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from 
..robot import Robot from ..utils import ensure_safe_goal_position @@ -84,13 +84,12 @@ class KochFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -182,10 +181,8 @@ class KochFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -202,6 +199,7 @@ class KochFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. @@ -215,8 +213,6 @@ class KochFollower(Robot): Returns: RobotAction: The action sent to the motors, potentially clipped. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -231,10 +227,8 @@ class KochFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index c84e81001..54848f49d 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -29,7 +29,7 @@ from lerobot.motors.feetech import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -109,10 +109,8 @@ class LeKiwi(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -339,10 +337,8 @@ class LeKiwi(Robot): "theta.vel": theta, } # m/s and deg/s + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read actuators position for arm and vel for base start = time.perf_counter() arm_pos = self.bus.sync_read("Present_Position", self.arm_motors) @@ -370,6 +366,7 
@@ class LeKiwi(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. @@ -383,8 +380,6 @@ class LeKiwi(Robot): Returns: RobotAction: the action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")} base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")} @@ -412,10 +407,8 @@ class LeKiwi(Robot): self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5) logger.info("Base motors stopped") + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.stop_base() self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index bb865dc10..1d5ea64a6 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -24,7 +24,8 @@ import numpy as np from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..robot import Robot from .config_lekiwi import LeKiwiClientConfig @@ -112,14 +113,10 @@ class LeKiwiClient(Robot): def is_calibrated(self) -> bool: pass + @check_if_already_connected def connect(self) -> None: """Establishes ZMQ sockets with the remote mobile robot""" - if self._is_connected: - raise DeviceAlreadyConnectedError( - "LeKiwi Daemon is already connected. Do not run `robot.connect()` twice." 
- ) - zmq = self._zmq self.zmq_context = zmq.Context() self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH) @@ -252,14 +249,13 @@ class LeKiwiClient(Robot): return new_frames, new_state + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Capture observations from the remote robot: current follower arm positions, present wheel speeds (converted to body-frame velocities: x, y, theta), and a camera frame. Receives over ZMQ, translate to body-frame vel """ - if not self._is_connected: - raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.") frames, obs_dict = self._get_data() @@ -307,6 +303,7 @@ class LeKiwiClient(Robot): def configure(self): pass + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ @@ -318,10 +315,6 @@ class LeKiwiClient(Robot): Returns: np.ndarray: the action sent to the motors, potentially clipped. """ - if not self._is_connected: - raise DeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()`." - ) self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space @@ -332,13 +325,10 @@ class LeKiwiClient(Robot): action_sent[ACTION] = actions return action_sent + @check_if_not_connected def disconnect(self): """Cleans ZMQ comms""" - if not self._is_connected: - raise DeviceNotConnectedError( - "LeKiwi is not connected. You need to run `robot.connect()` before disconnecting." 
- ) self.zmq_observation_socket.close() self.zmq_cmd_socket.close() self.zmq_context.term() diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index 14668b3a7..a171affbd 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -84,6 +84,7 @@ class OmxFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ For OMX robots that come pre-calibrated: @@ -91,8 +92,6 @@ class OmxFollower(Robot): - This allows using pre-calibrated robots without manual calibration - If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095) """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -165,10 +164,8 @@ class OmxFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -185,6 +182,7 @@ class OmxFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. 
@@ -198,8 +196,6 @@ class OmxFollower(Robot): Returns: RobotAction: The action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -214,10 +210,8 @@ class OmxFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 011a0061e..b4d11fe3f 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -26,7 +26,7 @@ from lerobot.motors.feetech import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -85,13 +85,12 @@ class SOFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. 
""" - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -176,10 +175,8 @@ class SOFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -196,6 +193,7 @@ class SOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. @@ -209,8 +207,6 @@ class SOFollower(Robot): Returns: RobotAction: the action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -225,10 +221,8 @@ class SOFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index e835b1de6..4ba6ce44f 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -66,23 +66,23 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" -Convert image dataset to video format (saves locally): +Convert image dataset to video format and save locally: python -m 
lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video -Convert image dataset and save with new repo_id: +Convert image dataset to video format and save with new repo_id: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video + --operation.type convert_image_to_video -Convert and push to hub: +Convert image dataset to video format and push to hub: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --push_to_hub true Using JSON config file: @@ -92,24 +92,19 @@ Using JSON config file: import logging import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path -import pandas as pd -from tqdm import tqdm - from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( + convert_image_to_video_dataset, delete_episodes, merge_datasets, remove_feature, split_dataset, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import write_stats, write_tasks -from lerobot.datasets.video_utils import encode_video_frames, get_video_info -from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.utils import init_logging @@ -138,8 +133,8 @@ class RemoveFeatureConfig: @dataclass -class ConvertToVideoConfig: - type: str = "convert_to_video" +class ConvertImageToVideoConfig: + type: str = "convert_image_to_video" output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -148,12 
+143,16 @@ class ConvertToVideoConfig: fast_decode: int = 0 episode_indices: list[int] | None = None num_workers: int = 4 + max_episodes_per_batch: int | None = None + max_frames_per_batch: int | None = None @dataclass class EditDatasetConfig: repo_id: str - operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig + operation: ( + DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + ) root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -297,362 +296,7 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() -def save_episode_images_for_video( - dataset: LeRobotDataset, - imgs_dir: Path, - img_key: str, - episode_index: int, - num_workers: int = 4, -) -> None: - """Save images from a specific episode and camera to disk for video encoding. - - Args: - dataset: The LeRobot dataset to extract images from - imgs_dir: Directory to save images to - img_key: The image key (camera) to extract - episode_index: Index of the episode to save - num_workers: Number of threads for parallel image saving - """ - # Create directory - imgs_dir.mkdir(parents=True, exist_ok=True) - - # Get dataset without torch format for PIL image access - hf_dataset = dataset.hf_dataset.with_format(None) - - # Select only this camera's images - imgs_dataset = hf_dataset.select_columns(img_key) - - # Get episode start and end indices - from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] - to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] - - # Get all items for this episode - episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) - - # Define function to save a single image - def save_single_image(i_item_tuple): - i, item = i_item_tuple - img = item[img_key] - # Use frame-XXXXXX.png format to match encode_video_frames expectations - img.save(str(imgs_dir / 
f"frame-{i:06d}.png"), quality=100) - return i - - # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) - items = list(enumerate(episode_dataset)) - - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(save_single_image, item) for item in items] - for future in as_completed(futures): - future.result() # This will raise any exceptions that occurred - - -def encode_episode_videos( - dataset: LeRobotDataset, - new_meta: LeRobotDatasetMetadata, - episode_index: int, - vcodec: str, - pix_fmt: str, - g: int, - crf: int, - fast_decode: int, - temp_dir: Path, - num_image_workers: int = 4, -) -> dict[str, dict]: - """Encode videos for a single episode and return video metadata. - - Args: - dataset: Source dataset with images - new_meta: Metadata object for the new video dataset - episode_index: Episode index to process - vcodec: Video codec - pix_fmt: Pixel format - g: Group of pictures size - crf: Constant rate factor - fast_decode: Fast decode tuning - temp_dir: Temporary directory for images - num_image_workers: Number of workers for saving images - - Returns: - Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps) - """ - hf_dataset = dataset.hf_dataset.with_format(None) - img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] - - video_metadata = {} - fps = int(dataset.fps) # Convert to int for PyAV compatibility - episode_length = dataset.meta.episodes["length"][episode_index] - episode_duration = episode_length / dataset.fps # Use original fps for duration calculation - - for img_key in img_keys: - # Save images temporarily - imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key - save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers) - - # Determine chunk and file indices - # For simplicity, we'll put each episode in its own file - chunk_idx = episode_index // new_meta.chunks_size - file_idx = 
episode_index % new_meta.chunks_size - - # Create video path in the new dataset structure - video_path = new_meta.root / new_meta.video_path.format( - video_key=img_key, chunk_index=chunk_idx, file_index=file_idx - ) - video_path.parent.mkdir(parents=True, exist_ok=True) - - # Encode video - encode_video_frames( - imgs_dir=imgs_dir, - video_path=video_path, - fps=fps, - vcodec=vcodec, - pix_fmt=pix_fmt, - g=g, - crf=crf, - fast_decode=fast_decode, - overwrite=True, - ) - - # Clean up temporary images - shutil.rmtree(imgs_dir) - - # Store video metadata - video_metadata[img_key] = { - f"videos/{img_key}/chunk_index": chunk_idx, - f"videos/{img_key}/file_index": file_idx, - f"videos/{img_key}/from_timestamp": 0.0, - f"videos/{img_key}/to_timestamp": episode_duration, - } - - return video_metadata - - -def convert_dataset_to_videos( - dataset: LeRobotDataset, - output_dir: Path, - repo_id: str | None = None, - vcodec: str = "libsvtav1", - pix_fmt: str = "yuv420p", - g: int = 2, - crf: int = 30, - fast_decode: int = 0, - episode_indices: list[int] | None = None, - num_workers: int = 4, -) -> LeRobotDataset: - """Convert image-based dataset to video-based dataset. - - Creates a new LeRobotDataset with videos instead of images, following the proper - LeRobot dataset structure with videos stored in chunked MP4 files. 
- - Args: - dataset: The source LeRobot dataset with images - output_dir: Directory to save the new video dataset - repo_id: Repository ID for the new dataset (default: original_id + "_video") - vcodec: Video codec (default: libsvtav1) - pix_fmt: Pixel format (default: yuv420p) - g: Group of pictures size (default: 2) - crf: Constant rate factor (default: 30) - fast_decode: Fast decode tuning (default: 0) - episode_indices: List of episode indices to convert (None = all episodes) - num_workers: Number of threads for parallel processing (default: 4) - - Returns: - New LeRobotDataset with videos - """ - # Check that it's an image dataset - if len(dataset.meta.video_keys) > 0: - raise ValueError( - f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}" - ) - - # Get all image keys - hf_dataset = dataset.hf_dataset.with_format(None) - img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] - - if len(img_keys) == 0: - raise ValueError(f"No image keys found in dataset {dataset.repo_id}") - - # Determine which episodes to process - if episode_indices is None: - episode_indices = list(range(dataset.meta.total_episodes)) - - if repo_id is None: - repo_id = f"{dataset.repo_id}_video" - - logging.info( - f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" - ) - logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") - - # Create new features dict, converting image features to video features - new_features = {} - for key, value in dataset.meta.features.items(): - if key not in img_keys: - new_features[key] = value - else: - # Convert image key to video format - new_features[key] = value.copy() - new_features[key]["dtype"] = "video" # Change dtype from "image" to "video" - # Video info will be updated after episodes are encoded - - # Create new metadata for video dataset - new_meta = LeRobotDatasetMetadata.create( - repo_id=repo_id, - 
fps=dataset.meta.fps, - features=new_features, - robot_type=dataset.meta.robot_type, - root=output_dir, - use_videos=True, - chunks_size=dataset.meta.chunks_size, - data_files_size_in_mb=dataset.meta.data_files_size_in_mb, - video_files_size_in_mb=dataset.meta.video_files_size_in_mb, - ) - - # Create temporary directory for image extraction - temp_dir = output_dir / "temp_images" - temp_dir.mkdir(parents=True, exist_ok=True) - - # Process each episode - all_episode_metadata = [] - - try: - for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"): - # Get episode metadata from source - src_episode = dataset.meta.episodes[ep_idx] - - # Encode videos for this episode - video_metadata = encode_episode_videos( - dataset=dataset, - new_meta=new_meta, - episode_index=ep_idx, - vcodec=vcodec, - pix_fmt=pix_fmt, - g=g, - crf=crf, - fast_decode=fast_decode, - temp_dir=temp_dir, - num_image_workers=num_workers, - ) - - # Build episode metadata - episode_meta = { - "episode_index": ep_idx, - "length": src_episode["length"], - "dataset_from_index": ep_idx * src_episode["length"], - "dataset_to_index": (ep_idx + 1) * src_episode["length"], - } - - # Add video metadata - for img_key in img_keys: - episode_meta.update(video_metadata[img_key]) - - # Add data chunk/file info (using same structure as source) - if "data/chunk_index" in src_episode: - episode_meta["data/chunk_index"] = src_episode["data/chunk_index"] - episode_meta["data/file_index"] = src_episode["data/file_index"] - - all_episode_metadata.append(episode_meta) - - # Copy and transform data files (removing image columns) - _copy_data_without_images(dataset, new_meta, episode_indices, img_keys) - - # Save episode metadata - episodes_df = pd.DataFrame(all_episode_metadata) - episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet" - episodes_path.parent.mkdir(parents=True, exist_ok=True) - episodes_df.to_parquet(episodes_path, index=False) - - # Update metadata info - 
new_meta.info["total_episodes"] = len(episode_indices) - new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata) - new_meta.info["total_tasks"] = dataset.meta.total_tasks - new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} - - # Update video info for all image keys (now videos) - # We need to manually set video info since update_video_info() checks video_keys first - for img_key in img_keys: - if not new_meta.features[img_key].get("info", None): - video_path = new_meta.root / new_meta.video_path.format( - video_key=img_key, chunk_index=0, file_index=0 - ) - new_meta.info["features"][img_key]["info"] = get_video_info(video_path) - - from lerobot.datasets.utils import write_info - - write_info(new_meta.info, new_meta.root) - - # Copy stats and tasks - if dataset.meta.stats is not None: - # Remove image stats - new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys} - write_stats(new_stats, new_meta.root) - - if dataset.meta.tasks is not None: - write_tasks(dataset.meta.tasks, new_meta.root) - - finally: - # Clean up temporary directory - if temp_dir.exists(): - shutil.rmtree(temp_dir) - - logging.info(f"✓ Completed converting {dataset.repo_id} to video format") - logging.info(f"New dataset saved to: {output_dir}") - - # Return new dataset - return LeRobotDataset(repo_id=repo_id, root=output_dir) - - -def _copy_data_without_images( - src_dataset: LeRobotDataset, - dst_meta: LeRobotDatasetMetadata, - episode_indices: list[int], - img_keys: list[str], -) -> None: - """Copy data files without image columns. 
- - Args: - src_dataset: Source dataset - dst_meta: Destination metadata - episode_indices: Episodes to include - img_keys: Image keys to remove - """ - from lerobot.datasets.utils import DATA_DIR - - data_dir = src_dataset.root / DATA_DIR - parquet_files = sorted(data_dir.glob("*/*.parquet")) - - if not parquet_files: - raise ValueError(f"No parquet files found in {data_dir}") - - episode_set = set(episode_indices) - - for src_path in tqdm(parquet_files, desc="Processing data files"): - df = pd.read_parquet(src_path).reset_index(drop=True) - - # Filter to only include selected episodes - df = df[df["episode_index"].isin(episode_set)].copy() - - if len(df) == 0: - continue - - # Remove image columns - columns_to_drop = [col for col in img_keys if col in df.columns] - if columns_to_drop: - df = df.drop(columns=columns_to_drop) - - # Get chunk and file indices from path - relative_path = src_path.relative_to(src_dataset.root) - chunk_dir = relative_path.parts[1] - file_name = relative_path.parts[2] - chunk_idx = int(chunk_dir.split("-")[1]) - file_idx = int(file_name.split("-")[1].split(".")[0]) - - # Write to destination without pandas index - dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet" - dst_path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(dst_path, index=False) - - -def handle_convert_to_video(cfg: EditDatasetConfig) -> None: +def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) @@ -664,8 +308,12 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: if cfg.new_repo_id: # Use new_repo_id for both local storage and hub push output_repo_id = cfg.new_repo_id - output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id - logging.info(f"Saving to new dataset: 
{cfg.new_repo_id}") + # Place new dataset as a sibling to the original dataset + # Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir) + # Extract just the dataset name (after last slash) for the local directory + local_dir_name = cfg.new_repo_id.split("/")[-1] + output_dir = dataset.root.parent / local_dir_name + logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}") elif output_dir_config: # Use custom output directory for local-only storage output_dir = Path(output_dir_config) @@ -675,12 +323,15 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: else: # Auto-generate name: append "_video" to original repo_id output_repo_id = f"{cfg.repo_id}_video" - output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id + # Place new dataset as a sibling to the original dataset + # Extract just the dataset name (after last slash) for the local directory + local_dir_name = output_repo_id.split("/")[-1] + output_dir = dataset.root.parent / local_dir_name logging.info(f"Saving to auto-generated location: {output_dir}") logging.info(f"Converting dataset {cfg.repo_id} to video format") - new_dataset = convert_dataset_to_videos( + new_dataset = convert_image_to_video_dataset( dataset=dataset, output_dir=output_dir, repo_id=output_repo_id, @@ -691,6 +342,8 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: fast_decode=getattr(cfg.operation, "fast_decode", 0), episode_indices=getattr(cfg.operation, "episode_indices", None), num_workers=getattr(cfg.operation, "num_workers", 4), + max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None), + max_frames_per_batch=getattr(cfg.operation, "max_frames_per_batch", None), ) logging.info("Video dataset created successfully!") @@ -718,8 +371,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) - elif operation_type == 
"convert_to_video": - handle_convert_to_video(cfg) + elif operation_type == "convert_image_to_video": + handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index aca3c8672..bba59385a 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -150,92 +150,6 @@ def update_policy( return train_metrics, output_dict -def get_default_peft_configuration(policy_type): - """Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint.""" - - common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" - - if policy_type == "smolvla": - return { - "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))", - "modules_to_save": [], - } - elif policy_type in ("pi0", "pi05"): - return { - "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))", - "modules_to_save": [], - } - - return {"modules_to_save": None} - - -def wrap_policy_in_peft_model(cfg, policy): - from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model - - # Disable all gradients because we'll only train the parameters selected by the PEFT method. - # Layers that should receive gradients anyway need to be listed in `modules_to_save`. - for p in policy.parameters(): - p.requires_grad_(False) - - if not cfg.policy.pretrained_path: - raise ValueError( - "Training from scratch using PEFT. This is unlikely to yield good results. " - "Supply a `policy.path` to fine-tune an existing model." - ) - - if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: - logging.warning( - "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set " - "`load_vlm_weights=True` to fine-tune the existing policy." 
- ) - - peft_config_policy = get_default_peft_configuration(cfg.policy.type) - peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {} - peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT - peft_method_type = PeftType[peft_config_cli["method_type"].upper()] - peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] - - # Handle specific CLI overrides - for key in ["target_modules", "modules_to_save", "r"]: - if peft_config_cli[key] is not None: - peft_config_policy[key] = peft_config_cli[key] - - if "target_modules" not in peft_config_policy: - raise ValueError( - f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually." - ) - - # Init method depends on the used PEFT method, your specific PEFT method - # might not be considered here, in that case an error is raised. - if peft_config_cli["init_type"] is not None: - if peft_method_type == "LORA": - peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"] - elif peft_method_type == "MISS": - peft_config_policy["init_weights"] = peft_config_cli["init_type"] - else: - raise ValueError( - f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}." - ) - - # PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the - # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the - # adapter, not the base model. - if policy.config.pretrained_path: - policy.name_or_path = str(policy.config.pretrained_path) - - # Finally wrap the policy in a PEFT model - policy = get_peft_model( - policy, - peft_config_cls(**peft_config_policy), - ) - - # Make sure that the config is tagged as using PEFT so that the loading code can take the - # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights. 
- policy.config.use_peft = True - - return policy - - @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ @@ -313,9 +227,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. eval_env = None - if cfg.eval_freq > 0 and cfg.env is not None: - if is_main_process: - logging.info("Creating env") + if cfg.eval_freq > 0 and cfg.env is not None and is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) if is_main_process: @@ -328,7 +241,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.peft is not None: logging.info("Using PEFT! Wrapping model.") - policy = wrap_policy_in_peft_model(cfg, policy) + # Convert CLI peft config to dict for overrides + peft_cli_overrides = dataclasses.asdict(cfg.peft) + policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 45c46c100..90bf2a92d 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,6 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig +from lerobot.utils.decorators import check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -91,6 +92,7 @@ class BiSOLeader(Teleoperator): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_action(self) -> dict[str, float]: action_dict = {} diff --git 
a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 4dbb49c1d..69cb0f971 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -21,6 +21,7 @@ from typing import Any import numpy as np from lerobot.processor import RobotAction +from lerobot.utils.decorators import check_if_not_connected from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -85,6 +86,7 @@ class GamepadTeleop(Teleoperator): self.gamepad = Gamepad() self.gamepad.start() + @check_if_not_connected def get_action(self) -> RobotAction: # Update the controller to get fresh inputs self.gamepad.update() @@ -158,6 +160,7 @@ class GamepadTeleop(Teleoperator): self.gamepad.stop() self.gamepad = None + @property def is_connected(self) -> bool: """Check if gamepad is connected.""" return self.gamepad is not None diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 43116f5c0..178eed544 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -22,7 +22,7 @@ from pprint import pformat import serial from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator @@ -93,10 +93,8 @@ class HomunculusArm(Teleoperator): with self.serial_lock: return self.serial.is_open and self.thread.is_alive() + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - if not self.serial.is_open: self.serial.open() self.thread.start() @@ -299,6 
+297,7 @@ class HomunculusArm(Teleoperator): except Exception as e: logger.debug(f"Error reading frame in background thread for {self}: {e}") + @check_if_not_connected def get_action(self) -> dict[str, float]: joint_positions = self._read() return {f"{joint}.pos": pos for joint, pos in joint_positions.items()} @@ -306,10 +305,8 @@ class HomunculusArm(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.stop_event.set() self.thread.join(timeout=1) self.serial.close() diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index fefeec1e8..c4393d660 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -24,7 +24,7 @@ import serial from lerobot.motors import MotorCalibration from lerobot.motors.motors_bus import MotorNormMode from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator @@ -119,10 +119,8 @@ class HomunculusGlove(Teleoperator): with self.serial_lock: return self.serial.is_open and self.thread.is_alive() + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - if not self.serial.is_open: self.serial.open() self.thread.start() @@ -325,6 +323,7 @@ class HomunculusGlove(Teleoperator): except Exception as e: logger.debug(f"Error reading frame in background thread for {self}: {e}") + 
@check_if_not_connected def get_action(self) -> dict[str, float]: joint_positions = self._read() return homunculus_glove_to_hope_jr_hand( @@ -334,10 +333,8 @@ class HomunculusGlove(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.stop_event.set() self.thread.join(timeout=1) self.serial.close() diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 55c158da8..919f463d3 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -22,7 +22,7 @@ from queue import Queue from typing import Any from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -86,12 +86,8 @@ class KeyboardTeleop(Teleoperator): def is_calibrated(self) -> bool: pass + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError( - "Keyboard is already connected. Do not run `robot.connect()` twice." - ) - if PYNPUT_AVAILABLE: logging.info("pynput is available - enabling local keyboard listener.") self.listener = keyboard.Listener( @@ -125,14 +121,10 @@ class KeyboardTeleop(Teleoperator): def configure(self): pass + @check_if_not_connected def get_action(self) -> RobotAction: before_read_t = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." 
- ) - self._drain_pressed_keys() # Generate action based on current key states @@ -144,11 +136,8 @@ class KeyboardTeleop(Teleoperator): def send_feedback(self, feedback: dict[str, Any]) -> None: pass + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`." - ) if self.listener is not None: self.listener.stop() @@ -182,12 +171,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } + @check_if_not_connected def get_action(self) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." - ) - self._drain_pressed_keys() delta_x = 0.0 delta_y = 0.0 @@ -375,6 +360,7 @@ class KeyboardRoverTeleop(KeyboardTeleop): # Only remove key if it's being released self.current_pressed.pop(key_char, None) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get the current action based on pressed keys. @@ -384,11 +370,6 @@ class KeyboardRoverTeleop(KeyboardTeleop): """ before_read_t = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`." 
- ) - self._drain_pressed_keys() linear_velocity = 0.0 diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py index 0409f2e57..87084b6b9 100644 --- a/src/lerobot/teleoperators/koch_leader/koch_leader.py +++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_koch_leader import KochLeaderConfig @@ -69,10 +69,8 @@ class KochLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -161,10 +159,8 @@ class KochLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} @@ -176,9 +172,7 @@ class KochLeader(Teleoperator): # TODO(rcadene, aliberts): Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/omx_leader/omx_leader.py b/src/lerobot/teleoperators/omx_leader/omx_leader.py index c0e49b558..4423be714 100644 --- 
a/src/lerobot/teleoperators/omx_leader/omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/omx_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_omx_leader import OmxLeaderConfig @@ -68,10 +68,8 @@ class OmxLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -142,10 +140,8 @@ class OmxLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} @@ -157,9 +153,7 @@ class OmxLeader(Teleoperator): # TODO(rcadene, aliberts): Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index 91e613190..221ee8083 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -28,7 +28,7 @@ from teleop import Teleop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from 
lerobot.teleoperators.teleoperator import Teleoperator -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.rotation import Rotation logger = logging.getLogger(__name__) @@ -81,10 +81,8 @@ class IOSPhone(BasePhone, Teleoperator): def is_connected(self) -> bool: return self._group is not None + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") lookup = hebi.Lookup() time.sleep(2.0) @@ -164,6 +162,7 @@ class IOSPhone(BasePhone, Teleoperator): pos = ar_pos - rot.apply(self.config.camera_offset) return True, pos, rot, pose + @check_if_not_connected def get_action(self) -> dict: has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose() if not has_pose or not self.is_calibrated: @@ -204,10 +203,8 @@ class IOSPhone(BasePhone, Teleoperator): "phone.enabled": self._enabled, } + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._group = None @@ -227,10 +224,8 @@ class AndroidPhone(BasePhone, Teleoperator): def is_connected(self) -> bool: return self._teleop is not None + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - logger.info("Starting teleop stream for Android...") self._teleop = Teleop() self._teleop.subscribe(self._android_callback) @@ -318,6 +313,7 @@ class AndroidPhone(BasePhone, Teleoperator): self._latest_pose = pose self._latest_message = message + @check_if_not_connected def get_action(self) -> dict: ok, raw_pos, raw_rot, pose = self._read_current_pose() if not ok or not self.is_calibrated: @@ -350,10 +346,8 @@ class AndroidPhone(BasePhone, 
Teleoperator): "phone.enabled": self._enabled, } + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._teleop = None if self._teleop_thread and self._teleop_thread.is_alive(): self._teleop_thread.join(timeout=1.0) diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py index 578aaa7b2..db076b20f 100644 --- a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py @@ -26,7 +26,8 @@ if TYPE_CHECKING or _reachy2_sdk_available: else: ReachySDK = None -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig @@ -126,10 +127,8 @@ class Reachy2Teleoperator(Teleoperator): def is_connected(self) -> bool: return self.reachy.is_connected() if self.reachy is not None else False + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.reachy = ReachySDK(self.config.ip_address) if not self.is_connected: @@ -146,12 +145,10 @@ class Reachy2Teleoperator(Teleoperator): def configure(self) -> None: pass + @check_if_not_connected def get_action(self) -> dict[str, float]: start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - joint_action: dict[str, float] = {} vel_action: dict[str, float] = {} diff --git a/src/lerobot/teleoperators/so_leader/so_leader.py b/src/lerobot/teleoperators/so_leader/so_leader.py index 760ef2eb1..a10e3a61f 100644 
--- a/src/lerobot/teleoperators/so_leader/so_leader.py +++ b/src/lerobot/teleoperators/so_leader/so_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_so_leader import SOLeaderTeleopConfig @@ -66,10 +66,8 @@ class SOLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -139,6 +137,7 @@ class SOLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: start = time.perf_counter() action = self.bus.sync_read("Present_Position") @@ -151,10 +150,8 @@ class SOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/utils/decorators.py b/src/lerobot/utils/decorators.py new file mode 100644 index 000000000..8fc2f9a07 --- /dev/null +++ b/src/lerobot/utils/decorators.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + + +def check_if_not_connected(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__} is not connected. Run `.connect()` first." + ) + return func(self, *args, **kwargs) + + return wrapper + + +def check_if_already_connected(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.") + return func(self, *args, **kwargs) + + return wrapper diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 0206a8ac9..a499b96c7 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -21,12 +21,23 @@ from typing import Any from draccus.choice_types import ChoiceRegistry -def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: - """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py - Check if the package spec exists and grab its version to avoid importing a local directory. - **Note:** this doesn't work for all packages. 
+def is_package_available( + pkg_name: str, import_name: str | None = None, return_version: bool = False +) -> tuple[bool, str] | bool: """ - package_exists = importlib.util.find_spec(pkg_name) is not None + Check if the package spec exists and grab its version to avoid importing a local directory. + + Args: + pkg_name: The name of the package as installed via pip (e.g. "python-can"). + import_name: The actual name used to import the package (e.g. "can"). + Defaults to pkg_name if not provided. + return_version: Whether to return the version string. + """ + if import_name is None: + import_name = pkg_name + + # Check if the module spec exists using the import name + package_exists = importlib.util.find_spec(import_name) is not None package_version = "N/A" if package_exists: try: @@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": try: - package = importlib.import_module(pkg_name) + package = importlib.import_module(import_name) temp_version = getattr(package, "__version__", "N/A") # Check if the version contains "dev" if "dev" in temp_version: @@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False - elif pkg_name == "grpc": - package = importlib.import_module(pkg_name) - package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index 29583d4fa..c3ee37c8f 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -62,7 +62,7 @@ class MockPolicy: @pytest.fixture -@require_package("grpc") +@require_package("grpcio", "grpc") def 
policy_server(): """Fresh `PolicyServer` instance with a stubbed-out policy model.""" # Import only when the test actually runs (after decorator check) diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index b710a3a4b..031c29d60 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -16,6 +16,7 @@ from unittest.mock import patch +import datasets import torch from lerobot.datasets.aggregate import aggregate_datasets @@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): for key in aggr_ds.meta.video_keys: assert key in item, f"Video key {key} missing from item {i}" assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" + + +def assert_image_schema_preserved(aggr_ds): + """Test that HuggingFace Image feature schema is preserved in aggregated parquet files. + + This verifies the fix for a bug where image columns were written with a generic + struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of + the proper Image() feature type, causing HuggingFace Hub viewer to display + raw dict objects instead of image thumbnails. + """ + image_keys = aggr_ds.meta.image_keys + if not image_keys: + return + + # Check that parquet files have proper Image schema + data_dir = aggr_ds.root / "data" + parquet_files = list(data_dir.rglob("*.parquet")) + assert len(parquet_files) > 0, "No parquet files found in aggregated dataset" + + for parquet_file in parquet_files: + # Load with HuggingFace datasets to check schema + ds = datasets.Dataset.from_parquet(str(parquet_file)) + + for image_key in image_keys: + feature = ds.features.get(image_key) + assert feature is not None, f"Image key '{image_key}' not found in parquet schema" + assert isinstance(feature, datasets.Image), ( + f"Image key '{image_key}' should have Image() feature type, " + f"but got {type(feature).__name__}: {feature}. 
" + "This indicates image schema was not preserved during aggregation." + ) + + +def assert_image_frames_integrity(aggr_ds, ds_0, ds_1): + """Test that image frames are correctly preserved after aggregation.""" + image_keys = aggr_ds.meta.image_keys + if not image_keys: + return + + def images_equal(img1, img2): + return torch.allclose(img1, img2) + + # Test the section corresponding to the first dataset (ds_0) + for i in range(len(ds_0)): + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in image_keys: + assert images_equal(aggr_ds[i][key], ds_0[i][key]), ( + f"Image frames at position {i} should be equal between aggregated and ds_0" + ) + + # Test the section corresponding to the second dataset (ds_1) + for i in range(len(ds_0), len(ds_0) + len(ds_1)): + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in image_keys: + assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), ( + f"Image frames at position {i} should be equal between aggregated and ds_1" + ) + + +def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): + """Test aggregation of image-based datasets preserves HuggingFace Image schema. + + This test specifically verifies that: + 1. Image-based datasets can be aggregated correctly + 2. The HuggingFace Image() feature type is preserved in parquet files + 3. Image data integrity is maintained across aggregation + 4. Images can be properly decoded after aggregation + + This catches the bug where to_parquet_with_hf_images() was not passing + the features schema, causing image columns to be written as generic + struct types instead of Image() types. 
+ """ + ds_0_num_frames = 50 + ds_1_num_frames = 75 + ds_0_num_episodes = 2 + ds_1_num_episodes = 3 + + # Create two image-based datasets (use_videos=False) + ds_0 = lerobot_dataset_factory( + root=tmp_path / "image_0", + repo_id=f"{DUMMY_REPO_ID}_image_0", + total_episodes=ds_0_num_episodes, + total_frames=ds_0_num_frames, + use_videos=False, # Image-based dataset + ) + ds_1 = lerobot_dataset_factory( + root=tmp_path / "image_1", + repo_id=f"{DUMMY_REPO_ID}_image_1", + total_episodes=ds_1_num_episodes, + total_frames=ds_1_num_frames, + use_videos=False, # Image-based dataset + ) + + # Verify source datasets have image keys + assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys" + assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys" + + # Aggregate the datasets + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr", + aggr_root=tmp_path / "image_aggr", + ) + + # Load the aggregated dataset + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "image_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr") + + # Verify aggregated dataset has image keys + assert len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys" + assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets" + + # Run standard aggregation assertions + expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes + expected_total_frames = ds_0_num_frames + ds_1_num_frames + + assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames) + assert_dataset_content_integrity(aggr_ds, ds_0, ds_1) + assert_metadata_consistency(aggr_ds, 
ds_0, ds_1) + assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) + + # Image-specific assertions + assert_image_schema_preserved(aggr_ds) + assert_image_frames_integrity(aggr_ds, ds_0, ds_1) + + # Verify images can be accessed and have correct shape + sample_item = aggr_ds[0] + for image_key in aggr_ds.meta.image_keys: + img = sample_item[image_key] + assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor" + assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)" + assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" + + assert_dataset_iteration_works(aggr_ds) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 3a4516fc8..35a369de9 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import ( remove_feature, split_dataset, ) -from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos +from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset @pytest.fixture @@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features -def test_convert_dataset_to_videos(tmp_path): +def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -1071,7 +1071,7 @@ def test_convert_dataset_to_videos(tmp_path): assert "observation.image" in source_dataset.meta.features # Convert to video dataset (only first 2 episodes for speed) - video_dataset = convert_dataset_to_videos( + video_dataset = convert_image_to_video_dataset( dataset=source_dataset, output_dir=output_dir, repo_id="lerobot/pusht_video", @@ -1113,7 +1113,7 @@ def test_convert_dataset_to_videos(tmp_path): shutil.rmtree(output_dir) -def 
test_convert_dataset_to_videos_subset_episodes(tmp_path): +def test_convert_image_to_video_dataset_subset_episodes(tmp_path): """Test converting only specific episodes from lerobot/pusht_image to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -1132,7 +1132,7 @@ def test_convert_dataset_to_videos_subset_episodes(tmp_path): # Convert only episode 0 to video (subset of loaded episodes) episode_indices = [0] - video_dataset = convert_dataset_to_videos( + video_dataset = convert_image_to_video_dataset( dataset=source_dataset, output_dir=output_dir, repo_id="lerobot/pusht_video_subset", diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 4c91c55c0..27c51b3c4 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): image_array_to_pil_image(image) +def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for image features after saving episode.""" + # Image feature: images should be deleted after saving episode + image_key = "image" + features_image = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]} + } + ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) + ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_img.save_episode() + img_dir = ds_img._get_image_file_dir(0, image_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + + +def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`.""" + # Video feature: when batch_encoding_size == 1 temporary images should be deleted + vid_key = "video" + features_video = { + vid_key: {"dtype": "video", "shape": DUMMY_CHW, 
"names": ["channels", "height", "width"]} + } + + ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) + ds_vid.batch_encoding_size = 1 + ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_vid.save_episode() + vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + assert not vid_img_dir.exists(), ( + "Temporary image directory should be removed when batch_encoding_size == 1" + ) + + +def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed appropriately when both image and video features are present.""" + image_key = "image" + vid_key = "video" + features_mixed = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}, + vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]}, + } + ds_mixed = empty_lerobot_dataset_factory( + root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2 + ) + ds_mixed.add_frame( + { + "image": np.random.rand(*DUMMY_CHW), + "video": np.random.rand(*DUMMY_HWC), + "task": "Dummy task", + } + ) + ds_mixed.save_episode() + img_dir = ds_mixed._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + assert vid_img_dir.exists(), ( + "Temporary image directory should not be removed for video features when batch_encoding_size == 2" + ) + + # TODO(aliberts): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames @@ -1392,3 +1451,202 @@ def test_valid_video_codecs_constant(): assert "hevc" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS assert len(VALID_VIDEO_CODECS) == 3 + + +def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Regression test for bug where delta_timestamps incorrectly 
marked all frames as padded when using episodes filter. + + The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset + instead of the absolute index when comparing against episode boundaries (ep_start, ep_end). + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Create 3 episodes with 10 frames each + frames_per_episode = 10 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 (middle episode) with delta_timestamps + delta_ts = {"observation.state": [0.0]} # Just the current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + ) + + # Verify the filtered dataset has the correct length + assert len(filtered_dataset) == frames_per_episode + + # Check that no frames are marked as padded (since delta=0 should always be valid) + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False, f"Frame {idx} incorrectly marked as padded" + # Verify we're getting data from episode 1 + assert frame["episode_index"].item() == 1 + + +def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, 
+ } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 3 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 with delta_timestamps that go beyond episode boundaries + # fps=10, so 0.1s = 1 frame offset + delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, # Slightly less than half a frame at 10fps + ) + + assert len(filtered_dataset) == frames_per_episode + + # Check padding at the start of the episode (first frame) + first_frame = filtered_dataset[0] + is_pad = first_frame["observation.state_is_pad"].tolist() + # At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not + assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}" + + # Check middle frame (no padding expected) + mid_frame = filtered_dataset[2] + is_pad = mid_frame["observation.state_is_pad"].tolist() + assert is_pad == [False, False, False, False, False], f"Middle frame padding incorrect: {is_pad}" + + # Check padding at the end of the episode (last frame) + last_frame = filtered_dataset[4] + is_pad = last_frame["observation.state_is_pad"].tolist() + # At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be + assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}" + + +def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Test delta_timestamps with multiple non-consecutive 
episodes selected.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 5 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(5): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episodes 1 and 3 (non-consecutive) + delta_ts = {"observation.state": [0.0]} + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1, 3], + delta_timestamps=delta_ts, + ) + + assert len(filtered_dataset) == 2 * frames_per_episode + + # All frames should have valid (non-padded) data for delta=0 + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False + + # Verify we're getting the correct episodes + episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))] + expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode + assert episode_indices == expected_episodes + + +def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps returns the correct observation values, not just correct padding.""" + features = { + "observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 2 episodes with known values + # Episode 0: frames with values 0, 1, 2, 3, 4 + # Episode 1: frames with values 10, 11, 12, 13, 14 + frames_per_episode = 5 + for ep_idx in range(2): + for frame_idx in range(frames_per_episode): + value = ep_idx * 10 + 
frame_idx + dataset.add_frame( + { + "observation.state": torch.tensor([value], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episode 1 with delta that looks at previous frame + delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, + ) + + # Check frame 2 of episode 1 (which has absolute index 7, value 12) + frame = filtered_dataset[2] + state_values = frame["observation.state"].tolist() + # Should get [11, 12] - the previous and current values within episode 1 + assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}" + + # Check first frame - previous frame should be clamped to episode start (padded) + first_frame = filtered_dataset[0] + state_values = first_frame["observation.state"].tolist() + is_pad = first_frame["observation.state_is_pad"].tolist() + # Previous frame is outside episode, so it's clamped to first frame and marked as padded + assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}" + assert is_pad == [True, False], f"Expected [True, False], got {is_pad}" diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index d997cb6d4..f69a2c02a 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -22,7 +22,7 @@ from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from tests.mocks.mock_motors_bus import MockMotorsBus @@ -98,10 +98,8 @@ class MockRobot(Robot): def is_connected(self) -> bool: 
return self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self._is_connected = True if calibrate: self.calibrate() @@ -110,19 +108,15 @@ class MockRobot(Robot): def is_calibrated(self) -> bool: return self._is_calibrated + @check_if_not_connected def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_calibrated = True def configure(self) -> None: pass + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.config.random_values: return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} else: @@ -130,14 +124,10 @@ class MockRobot(Robot): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - return action + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_connected = False diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 04479bad9..89174dadf 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected @TeleoperatorConfig.register_subclass("mock_teleop") @@ -68,10 +68,8 @@ class MockTeleop(Teleoperator): def is_connected(self) -> bool: return 
self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self._is_connected = True if calibrate: self.calibrate() @@ -80,19 +78,15 @@ class MockTeleop(Teleoperator): def is_calibrated(self) -> bool: return self._is_calibrated + @check_if_not_connected def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_calibrated = True def configure(self) -> None: pass + @check_if_not_connected def get_action(self) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.config.random_values: return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} else: @@ -100,12 +94,9 @@ class MockTeleop(Teleoperator): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } - def send_feedback(self, feedback: dict[str, Any]) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") + @check_if_not_connected + def send_feedback(self, feedback: dict[str, Any]) -> None: ... 
+ @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_connected = False diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index ec67f1889..54e4d2870 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -64,7 +64,7 @@ def close_service_stub(channel, server): server.stop(None) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_success(): from lerobot.rl.actor import establish_learner_connection @@ -81,7 +81,7 @@ def test_establish_learner_connection_success(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_failure(): from lerobot.rl.actor import establish_learner_connection @@ -100,7 +100,7 @@ def test_establish_learner_connection_failure(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_push_transitions_to_transport_queue(): from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.transport.utils import bytes_to_transitions @@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue(): assert_transitions_equal(deserialized_transition, transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_transitions_stream(): from lerobot.rl.actor import transitions_stream @@ -167,7 +167,7 @@ def test_transitions_stream(): assert streamed_data[2].data == b"transition_data_3" -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_interactions_stream(): from lerobot.rl.actor import interactions_stream diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 5d95dee04..e13862d82 100644 --- a/tests/rl/test_actor_learner.py +++ 
b/tests/rl/test_actor_learner.py @@ -88,7 +88,7 @@ def cfg(): return cfg -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_end_to_end_transitions_flow(cfg): from lerobot.rl.actor import ( @@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg): assert_transitions_equal(transition, input_transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) def test_end_to_end_interactions_flow(cfg): from lerobot.rl.actor import ( @@ -223,7 +223,7 @@ def test_end_to_end_interactions_flow(cfg): assert received == expected -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.timeout(10) def test_end_to_end_parameters_flow(cfg, data_size): diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py index e0f0292be..d967388f0 100644 --- a/tests/rl/test_learner_service.py +++ b/tests/rl/test_learner_service.py @@ -39,7 +39,7 @@ def learner_service_stub(): close_learner_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def create_learner_service_stub( shutdown_event: Event, parameters_queue: Queue, @@ -75,7 +75,7 @@ def create_learner_service_stub( return services_pb2_grpc.LearnerServiceStub(channel), channel, server -@require_package("grpc") +@require_package("grpcio", "grpc") def close_learner_service_stub(channel, server): channel.close() server.stop(None) @@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub): assert response == services_pb2.Empty() -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_interactions(): from lerobot.transport import services_pb2 @@ -135,7 +135,7 @@ def test_send_interactions(): assert interactions == [b"123", b"4", b"5", b"678"] -@require_package("grpc") +@require_package("grpcio", "grpc") 
@pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions(): from lerobot.transport import services_pb2 @@ -181,7 +181,7 @@ def test_send_transitions(): assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions_empty_stream(): from lerobot.transport import services_pb2 @@ -209,7 +209,7 @@ def test_send_transitions_empty_stream(): assert transitions_queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_stream_parameters(): import time @@ -267,7 +267,7 @@ def test_stream_parameters(): assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_with_shutdown(): from lerobot.transport import services_pb2 @@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown(): assert received_params == [b"param_batch_1", b"stop"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_waits_and_retries_on_empty_queue(): import threading diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 52825a24e..63632a8f4 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition from tests.utils import require_cuda, require_package -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_empty_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", 
"grpc") def test_bytes_buffer_size_small_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_large_buffer(): from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size @@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_empty_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data(): assert len(chunks) == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_single_chunk_small_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -82,7 +82,7 @@ def test_single_chunk_small_data(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_not_silent_mode(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -94,7 +94,7 @@ def test_not_silent_mode(): assert chunks[0].data == b"Some data" -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data(): assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") 
+@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_empty_data(): from lerobot.transport.utils import receive_bytes_in_chunks @@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_single_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_single_not_end_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_chunks(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_messages(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_shutdown_during_receive(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_only_begin_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", 
"grpc") def test_receive_bytes_in_chunks_missing_begin(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin(): # Tests for state_to_bytes and bytes_to_state_dict -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_empty_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict(): assert reconstructed == state_dict -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_to_state_dict_empty_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data(): bytes_to_state_dict(b"") -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_simple_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_to_state_dict_invalid_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data(): @require_cuda -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes_cuda(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def 
test_python_object_to_bytes_none(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -439,7 +439,7 @@ def test_python_object_to_bytes_none(): (1, 2, 3), ], ) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_simple_types(obj): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj): assert type(reconstructed) is type(obj) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_with_tensors(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors(): assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_empty_list(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list(): assert isinstance(reconstructed, list) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_single_transition(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition(): assert_transitions_equal(transitions[0], reconstructed[0]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) @@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition): assert_observation_equal(t1["next_state"], t2["next_state"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_observation_equal(o1: dict, o2: dict): """Helper to assert two observations are equal.""" assert 
set(o1.keys()) == set(o2.keys()) @@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict): assert torch.allclose(o1[key], o2[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_multiple_transitions(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions(): assert_transitions_equal(original, reconstructed_item) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_unknown_state(): from lerobot.transport.utils import receive_bytes_in_chunks diff --git a/tests/utils.py b/tests/utils.py index 800b7d4b3..38841db02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,7 +167,7 @@ def require_package_arg(func): return wrapper -def require_package(package_name): +def require_package(package_name, import_name=None): """ Decorator that skips the test if the specified package is not installed. """ @@ -175,7 +175,7 @@ def require_package(package_name): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if not is_package_available(package_name): + if not is_package_available(pkg_name=package_name, import_name=import_name): pytest.skip(f"{package_name} not installed") return func(*args, **kwargs)