Compare commits

..

46 Commits

Author SHA1 Message Date
AdilZouitine 4457c978b2 Enhance keyboard teleoperation to capture both character and special keys. Clear current pressed keys in KeyboardEndEffectorTeleop for improved state management. 2025-10-03 16:51:22 +02:00
Pepijn a4bed41132 Improve docs pi (#2110)
* Improve docs and add numpy to pi install requirments

* fix formatting

* update command

* remvoe numpy dep
2025-10-03 12:06:18 +02:00
Michel Aractingi 5c8dd883be fix bug in augment_dataset_quantile_stats.py that was not detecting… (#2106)
* fix bug in `augment_dataset_quantile_stats.py` that was not detecting the image features because we were looping over hf_dataset. Now we loop over the dataset itself

* Update src/lerobot/datasets/v30/augment_dataset_quantile_stats.py

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-02 18:28:44 +02:00
Michel Aractingi 38f6fc816b (chore) improve v3 message, allow converting local datasets to V3 (#1948)
Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
2025-10-02 15:49:18 +02:00
Pepijn abde7be3b3 Add OpenPi, Pi0 and Pi0.5 (#1910)
* initial commit

* change device in test

* do detailed import

* adhere to python 3.11 syntax

* fix autodocstring

* additionally

* do same in other files

* add model. prefix to all keys in state dict

* use dummy stats

* add pi05

* also shorten action_steps

* fix test

* all test pass! and fix tokenizer max length between 05 and 0

* remove test

* fix transformer dependency

* fix test

* split pi0 and pi05 policy in seperate files

* fix test

* fix push to hub test

* add some comments, license and readme

* remove warning in config

* add pi05 to factory

* remove check

* rename action_horizon to chunk_size

* clean up padding of state and action (more in line with lerobot pi0)

* add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0

* fix key match from pytorch state dict (similar keys to openpi implementation now)

* also for pi05

* update to python 3.11

* revert to openpi transformer replace python 3.11

* fix(modeling pi0): nit  warning message

* use safeauto_docstring

* fix: remove unused param

* fix from pretrained

* add preprocess tests

* also compile forward method

* Do not add model prefix to normalization

* use same name for action and state dim as lerobot pi0 and remove fixed image keys

* load from pretrained_path

* temp: hardcode base model

* fix override self.pretrained_path = None overwrite

* rename to loss

* remove additional image augmentations, lerobot dataset already does this

* Add docs

* put tests in test folder

* Add test to instatiate all base models

* go back to python 3.10

* update docs

* adapt docs pi05

* change docs: finetune base model options

* minor docs fixes and dependencies

* remove todo

* cast float64 to float32 for mps

* skip if no transformers

* fix tests

* add new models to modelcard

* add back init

* fix circular input

* feat: only run pi test on GPU

* remove require_nightly_gpu

* replace decorator test_pi0_openpi

* rename action_dim, state_dim to max_action_dim, max_state_dim

* fix doc and constants

* cleanup tests

* fix from pretrained

* fix tests

* add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests

* fix, state is included in language not in flow head

* Move test to specific folder

* and paligemma task with newline

* remove add_special_tokens, not needed

* feedback pr

* Remove previous pi0 and rename pi0_openpi and pi05_openpi

* Add Quantile stats to LeRobotDataset (#1985)

* - Add RunningQuantileStats class for efficient histogram-based quantile computation
- Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
- Support quantile computation during episode collection and aggregation
- Add comprehensive function-based test suite (24 tests) for quantile functionality
- Maintain full backward compatibility with existing stats computation
- Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization

* style fixes, make quantiles computation by default to new datasets

* fix tests

* - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user
- Fortified tests.

* - add helper functions to reshape stats
- add missing test for quantiles

* - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles.
- Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles.

* style fixes

* Added missing lisence

* Simplify compute_stats

* - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles
- modified quantile computation instead of using the edge for the value, interpolate the values in the bin

* rename pi0/pi05 files

* Remove open pi patch and use custom transformer branch for now

* renaming

* fix

* Revert "fix"

This reverts commit 1ea65730ac.

* fix naming

* feet(pi0/pi0.5): add pipeline (#2009)

* feat(processor): convert openpi model with processor

* TODO: Make test works

* fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests

- Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`.
- Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`.
- Enhanced task handling in tests to ensure proper formatting and batch size consistency.
- Cleaned up commented-out test code for clarity.

* refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy

- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.

* refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration

- Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions.
- Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`.
- Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter.
- Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability.
- Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility.

* feat(processor): convert openpi model with processor

* TODO: Make test works

* fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests

- Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`.
- Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`.
- Enhanced task handling in tests to ensure proper formatting and batch size consistency.
- Cleaned up commented-out test code for clarity.

* refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy

- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.

* refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration

- Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions.
- Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`.
- Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter.
- Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability.
- Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility.

* refactor(pi05): update imports and rename configuration classes

- Changed imports to reflect the new naming convention for PI05 configuration and policy classes.
- Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency.
- Introduced a new processor file for PI05, implementing pre-processing and post-processing steps.
- Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase.

* update(pi05): increase tokenizer_max_length for improved processing

- Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences.
- This adjustment aims to improve the overall performance and flexibility of the PI05 configuration.

* add default for state (max_state_dim)

* correct naming

* fix import

* cleanup code

* remove unused test

* us quantiles for action

* move to device

* remove discrete state assert

* fix pi05 test

* move pi05 to device

* use base models in comparison tests

* small renames for tests

* change number of tokens pi05 test

* fix openpi tokenization in test

* fix hub test

* fix test

* assert lerobot vs openpi tests

---------

Co-authored-by: Pepijn <pepijn@huggingface.co>

* add headers

* add back previously removed imports

* update if statement load processor with dataset stats

* remove to avoid circular import

* inject dataset stats for pretrained models

* check normalization before applying

* add link to  quantile augument script

* fix(policies): transformers import for ci in PI0 & PI05 (#2039)

* fix(policies): transformers import for ci in PI0

* fix(policies): transformers import for ci in PI05

* test(processor): fix expected raise when normalization types are missing (#2040)

* switch normalization order pipeline for pi05

* Fix/quantiles script (#2064)

* refactor augment stats with quantiles script
add parallelization for faster processing
shift the quantile normalization between -1 1

* fix replay buffer tests

* fix comment

* overwrite the pipeline normalization features with the policy features

* remove double normalization overwrite

* cleanup from pretrained

* remove typo

* also set norm_map

* fix(augment_quantiles) images incorrectly divided by 255

* clamp quantiles

* link to lerobot base models

* rename tests

* encorperate PR feedback

* update docstring for RunningQuantileStats

* update doc links

* Revert "clamp quantiles"

This reverts commit 172207471c.

* fix self.paligemma

* fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1]

* fix libero doc and use different transformer branch

* use fix branch instead of feat

* update results libero

* add new line

* fix formatting

* precommit

* update results libero

* update libero doc

* update title

* final changes

* add quantiles to test

* run pre commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
2025-10-02 13:14:45 +02:00
Akhil Ivaturi b6c528a438 Making Envs module pass MyPy checks (#2048)
* Fix configs.py None MyPy error

* Use img_tensor instead of img in utils.py

* Add type assertion in factory.py

* Resolve merge conflict

* Uncomment envs moodule for mypy checks in pyproject.toml

---------

Signed-off-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
2025-10-01 16:11:48 +02:00
Adil Zouitine 6d331310ab feat(mypy): configure mypy settings and add module overrides for gradual typing (#2101) 2025-10-01 15:14:41 +02:00
Adil Zouitine 5dfdec9288 feat(mypy): enable type checking for envs module and configure mypy settings in pyproject.toml (#2099)
* feat(mypy): enable type checking for envs module and configure mypy settings in pyproject.toml

* Add mypy configuration to check only the envs module.
* Exclude examples, benchmarks, and tests from type checking.
* Set ignore_missing_imports to true and follow_imports to skip.

* chore: comment out mypy configuration in pyproject.toml and pre-commit-config.yaml

* Comment out mypy settings to disable type checking for the envs module.
* Update pre-commit configuration to reflect changes in mypy settings.
2025-10-01 13:19:51 +02:00
Caroline Pascal 50977a2c28 fix(video_path): setting video_path to None during conversion for images datasets (#2095) 2025-10-01 11:03:52 +02:00
Adil Zouitine a0d7627d81 feat(train): include input and output features in processor overrides for normalization (#2088) (#2090)
Signed-off-by: AdilZouitine <adilzouitinegm@gmail.com>
2025-09-29 17:37:26 +02:00
Adil Zouitine 1ad2da403d feat(policies): add noise parameter to action prediction methods (#2063)
* feat(policies): add noise parameter to action prediction methods

- Introduced `ActionSelectKwargs` TypedDict for better type hinting.
- Updated `predict_action_chunk` and `select_action` methods in `PreTrainedPolicy` and its subclasses to accept a `noise` parameter.
- Modified `generate_actions` and `conditional_sample` methods in `DiffusionModel` to utilize the new noise parameter for action generation.

* refactor(policies): make ActionSelectKwargs TypedDict fields optional

- Updated `ActionSelectKwargs` to inherit with `total=False`, allowing for optional fields.
2025-09-29 17:02:19 +02:00
Adil Zouitine 2d3a605b3c Revert feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087)
Revert "feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087)"

This reverts commit f173265354.
2025-09-29 16:55:52 +02:00
Adil Zouitine f173265354 feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087)
* feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep

* refactor(normalization): streamline feature reconstruction logic in _NormalizationMixin

* refactor(tests): remove unused preprocessor initialization in test_act_backbone_lr

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-09-29 16:02:15 +02:00
Steven Palma bbcf66bd82 chore: enable simplify in ruff lint (#2085) 2025-09-29 15:06:56 +02:00
Steven Palma c378a325f0 chore: enable pyugrade ruff lint (#2084) 2025-09-29 13:28:53 +02:00
Qizhi Chen 90684a9690 Improve V3 aggregate implementation (#2077)
* fix return type

* improve apply with vertorize op

* Update src/lerobot/datasets/aggregate.py

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-09-29 11:18:54 +02:00
Steven Palma f59eb54f5c chore: remove unused code (#2062) 2025-09-29 10:49:36 +02:00
Qizhi Chen 62e9849ffd use abs path when concatenating (#2076) 2025-09-28 14:18:22 +02:00
Francesco Capuano e3b572992e Save Cropped Dataset to Hub (#2071)
* fix: cast fps argument from dataset to int

* fix: typo

* fix: specify repo-id
2025-09-27 16:07:53 +02:00
Jade Choghari 5b647e3bcb docs(fix): libero example command (#2060)
Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-26 15:09:42 +02:00
Adil Zouitine ddfff054bc feat(train): enhance processor overrides with normalizer and unnormalizer stats (#2038) 2025-09-26 14:32:29 +02:00
Steven Palma 49918efbc1 chore(utils): remove unused code (#2059) 2025-09-26 14:30:17 +02:00
Steven Palma c5b5955c5a chore: replace hard-coded next values with constants throughout all the source code (#2056) 2025-09-26 14:30:07 +02:00
Michel Aractingi ec40ccde0d Bug in conversion from v2.1 script (#2057)
* False logic in setting the dataset to index in the meta data when converting from v2.1'

* Improved logging
2025-09-26 14:28:58 +02:00
Steven Palma d2782cf66b chore: replace hard-coded action values with constants throughout all the source code (#2055)
* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
2025-09-26 13:33:18 +02:00
Adil Zouitine 9627765ce2 chore(mypy): add mypy configuration and module overrides for gradual type checking (#2052) 2025-09-26 11:53:27 +02:00
Steven Palma 43d878a102 chore: replace hard-coded obs values with constants throughout all the source code (#2037)
* chore: replace hard-coded OBS values with constants throughout all the source code

* chore(tests): replace hard-coded OBS values with constants throughout all the test code
2025-09-25 15:36:47 +02:00
Steven Palma ddba994d73 chore(scripts): rename eval and train scripts (#2033) 2025-09-24 18:29:58 +02:00
Jade Choghari a87d4c9a74 (docs): small change in dataset name (#2032)
* small change

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

* update

Signed-off-by: Jade Choghari <chogharijade@gmail.com>

---------

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
2025-09-24 17:30:32 +02:00
Steven Palma 170c09e7f6 chore(utils): move queue utils and wandb_utils to their respective modules (#2030)
* chore(utils): move queue utils and wandb_utils to their respective modules

* fix(rl): remove double imports

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 17:10:52 +02:00
Steven Palma 853cc70194 chore(utils): remove unused utils legacy functions + rename init_rerun (#2031) 2025-09-24 17:10:27 +02:00
Steven Palma ec63225dc1 chore(utils): move encoding utils and process to their respective modules (#2029)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 16:47:37 +02:00
Steven Palma af1760f175 chore(utils): move benchmark and buffer to their respective modules (#2028) 2025-09-24 16:46:38 +02:00
Steven Palma 163df97c0c fix(docs): update outdated links (#2026) 2025-09-24 16:17:39 +02:00
Steven Palma cdd2bf1c4e chore(ci): update stale message (#2027) 2025-09-24 15:46:44 +02:00
Steven Palma 1cba47da20 chore(async): move async related code to its directory at top level (#2003)
* chore(async): move async related code to its directory at top level

* chore(style): apply pre-commit to renamed headers

* test(async): fix async imports

* docs(async): update async headers doc
2025-09-24 14:49:37 +02:00
Steven Palma 7359e18eb6 chore(scripts): move replay to scripts (#2021)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 14:48:23 +02:00
Steven Palma 13010647bc chore(scripts): move setup_motors to scripts (#2020)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 14:06:58 +02:00
Steven Palma acbc14f60a chore(scripts): move calibrate to scripts (#2024)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 14:06:48 +02:00
Steven Palma 2b59850f15 chore(scripts): move record to scripts (#2022)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 13:38:12 +02:00
Steven Palma 42e4b3d09e chore(scripts): move teleop to scripts (#2023) 2025-09-24 12:01:21 +02:00
Steven Palma 98bcda2d8b chore(scripts): move find_port to scripts (#2019) 2025-09-24 11:38:04 +02:00
Steven Palma a4178f385b feat(script): add entry point for find joints limits (#2010)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 11:28:56 +02:00
Steven Palma bd09b2153f chore(scripts): move find_cameras to scripts (#2018) 2025-09-24 11:14:48 +02:00
Steven Palma 1033680a57 chore: move errors to utils (#2017)
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2025-09-24 11:14:23 +02:00
Steven Palma 7cf04a5ec3 chore: move constants to utils (#2016) 2025-09-24 11:11:53 +02:00
236 changed files with 7490 additions and 4383 deletions
+2 -2
View File
@@ -31,11 +31,11 @@ env:
Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions.
WARN_ISSUE_MESSAGE: >
This issue has been automatically marked as stale because it has not had
recent activity (1 year). It will be closed if no further activity occurs.
recent activity (6 months). It will be closed if no further activity occurs.
Thank you for your contributions.
WARN_PR_MESSAGE: >
This PR has been automatically marked as stale because it has not had
recent activity (1 year). It will be closed if no further activity occurs.
recent activity (6 months). It will be closed if no further activity occurs.
Thank you for your contributions.
jobs:
+6 -5
View File
@@ -86,11 +86,12 @@ repos:
# TODO(Steven): Uncomment when ready to use
##### Static Analysis & Typing #####
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.16.0
# hooks:
# - id: mypy
# args: [--python-version=3.10]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.16.0
hooks:
- id: mypy
args: [--config-file=pyproject.toml]
exclude: ^(examples|benchmarks|tests)/
##### Docstring Checks #####
# - repo: https://github.com/akaihola/darglint2
-378
View File
@@ -1,378 +0,0 @@
"""
Benchmark memory footprint and inference latency of a policy on arbitrary devices.
This script loads a pretrained policy directly (similar to the async inference server)
and generates dummy input data based on the policy's input_features to perform
accurate benchmarking without requiring datasets.
"""
import argparse
import os
import signal
import statistics
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
import psutil
import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException:
pass
@contextmanager
def timeout(seconds):
def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"):
yield
return
old_handler = signal.signal(signal.SIGALRM, signal_handler)
try:
# signal.alarm expects integer seconds
# for float seconds, we can use setitimer
signal.setitimer(signal.ITIMER_REAL, seconds)
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old_handler)
def bytes_to_human(n: int) -> str:
for unit in ["B", "KB", "MB", "GB", "TB"]:
if n < 1024:
return f"{n:.2f} {unit}"
n /= 1024
return f"{n:.2f} PB"
def percentile(values: list[float], p: float) -> float:
if not values:
return float("nan")
k = (len(values) - 1) * (p / 100.0)
f = int(k)
c = min(f + 1, len(values) - 1)
if f == c:
return values[f]
return values[f] + (values[c] - values[f]) * (k - f)
def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dict:
"""Generate dummy observation data based on policy input features."""
dummy_obs = {}
for key, feature in input_features.items():
shape = feature.shape
if feature.type == FeatureType.VISUAL:
# Images: random values in [0, 1] range (already normalized)
dummy_obs[key] = torch.rand(shape, dtype=torch.float32, device=device)
elif feature.type in [FeatureType.STATE, FeatureType.ACTION, FeatureType.ENV]:
# State/action/env: random normal distribution
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
else:
# Default: random normal for unknown types
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
# Add batch dimension
for key in dummy_obs:
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# Add task string for language-conditioned policies
dummy_obs["task"] = ""
return dummy_obs
def main():
parser = argparse.ArgumentParser(description="Policy inference benchmark")
parser.add_argument(
"--policy-id", type=str, required=True, help="Model ID or local path to pretrained policy"
)
parser.add_argument(
"--policy-type", type=str, required=True, help="Type of policy (smolvla, act, diffusion, etc.)"
)
parser.add_argument(
"--device", type=str, default="mps", choices=["cuda", "cpu", "mps"], help="Device to run on"
)
parser.add_argument("--seed", type=int, default=42, help="Random seed")
parser.add_argument(
"--num-samples", type=int, default=100, help="Number of inference samples to benchmark"
)
parser.add_argument("--warmup", type=int, default=10, help="Number of warmup samples (not timed)")
parser.add_argument(
"--output-dir", type=str, default="outputs/benchmarks", help="Directory to save benchmark results"
)
parser.add_argument(
"--timeout",
type=float,
default=0.3,
help="Timeout for each inference pass in seconds (default: 0.3s = 300ms)",
)
args = parser.parse_args()
# Seed & deterministic-ish setup
torch.manual_seed(args.seed)
if args.device == "cuda":
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False # leave False to avoid perf cliffs
# Resolve device availability
device = args.device.lower()
if device == "cuda" and not torch.cuda.is_available():
print("[!] CUDA requested but unavailable. Falling back to CPU.")
device = "cpu"
elif device == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
print("[!] MPS requested but unavailable. Falling back to CPU.")
device = "cpu"
use_cuda = device == "cuda"
# Create output directory and log file
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
policy_name = args.policy_id.replace("/", "_").replace("\\", "_")
log_file = output_dir / f"benchmark_{args.policy_type}_{policy_name}_{device}_{timestamp}.txt"
# Load policy directly from pretrained (similar to async inference server)
print(f"Loading policy {args.policy_type} from {args.policy_id}...")
policy_class = get_policy_class(args.policy_type)
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
policy.eval()
policy.to(device)
print(f"Policy loaded on {device}")
print(f"Input features: {list(policy.config.input_features.keys())}")
print(f"Output features: {list(policy.config.output_features.keys())}")
# Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
# Helper to sync for fair timings
def _sync(dev_=device):
if dev_ == "cuda" and torch.cuda.is_available():
torch.cuda.synchronize()
elif dev_ == "mps" and hasattr(torch, "mps"):
try:
torch.mps.synchronize()
except AttributeError:
pass # MPS sync not available in this PyTorch version
# Warmup (to stabilize kernels/caches)
print("Warming up...")
with torch.no_grad():
policy.reset()
for _ in range(args.warmup):
_ = policy.select_action(dummy_observation)
_sync()
# Memory footprint before timing
process = psutil.Process(os.getpid())
rss_before = process.memory_info().rss
if use_cuda:
torch.cuda.reset_peak_memory_stats()
# PyTorch timing with Event objects for more accurate GPU timing
print(f"Running benchmark: {args.num_samples} samples...")
if use_cuda:
# Use CUDA Events for precise GPU timing
start_events = []
end_events = []
timeout_count = 0
with torch.no_grad():
for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
try:
with timeout(args.timeout):
start_event.record()
_ = policy.select_action(dummy_observation)
end_event.record()
start_events.append(start_event)
end_events.append(end_event)
except TimeoutException:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
end_events.append(None)
print(f"\n[!] Timeout on forward {forward + 1}")
continue
# Synchronize and collect timing results
torch.cuda.synchronize()
per_forward_ms = []
for start_event, end_event in zip(start_events, end_events, strict=True):
if start_event is None:
per_forward_ms.append(args.timeout * 1000)
else:
per_forward_ms.append(start_event.elapsed_time(end_event))
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
else:
# Use simple time.perf_counter for CPU/MPS timing with timeout
import time
per_forward_ms = []
timeout_count = 0
with torch.no_grad():
for sample in tqdm(range(args.num_samples), desc="Samples"):
try:
with timeout(args.timeout):
start_time = time.perf_counter()
_ = policy.select_action(dummy_observation)
end_time = time.perf_counter()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
continue
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
# Memory footprint after timing
rss_after = process.memory_info().rss
rss_delta = rss_after - rss_before
cuda_peak = torch.cuda.max_memory_allocated() if use_cuda else 0
# Sort timing results for percentile calculations
per_forward_ms_sorted = sorted(per_forward_ms)
mean_ms = statistics.fmean(per_forward_ms) if per_forward_ms else float("nan")
std_ms = statistics.pstdev(per_forward_ms) if len(per_forward_ms) > 1 else 0.0
min_ms = per_forward_ms_sorted[0] if per_forward_ms_sorted else float("nan")
max_ms = per_forward_ms_sorted[-1] if per_forward_ms_sorted else float("nan")
p50_ms = percentile(per_forward_ms_sorted, 50)
p95_ms = percentile(per_forward_ms_sorted, 95)
# Model size
num_params = sum(p.numel() for p in policy.parameters())
# Prepare results for logging
results = {
"timestamp": datetime.now().isoformat(),
"policy_type": args.policy_type,
"policy_id": args.policy_id,
"device": device,
"num_trials": args.num_samples,
"forwards_per_trial": 1,
"warmup": args.warmup,
"timeout_ms": args.timeout * 1000,
"seed": args.seed,
"num_params": num_params,
"timeout_count": timeout_count,
"latency_mean_ms": mean_ms,
"latency_std_ms": std_ms,
"latency_min_ms": min_ms,
"latency_max_ms": max_ms,
"latency_p50_ms": p50_ms,
"latency_p95_ms": p95_ms,
"cpu_rss_before": rss_before,
"cpu_rss_after": rss_after,
"cpu_rss_delta": rss_delta,
"cuda_peak_alloc": cuda_peak,
"input_features": list(policy.config.input_features.keys()),
"output_features": list(policy.config.output_features.keys()),
}
# Format and write results to log file
log_content = f"""
=== LeRobot Policy Inference Benchmark ===
Timestamp: {results["timestamp"]}
Policy: {results["policy_type"]} ({results["policy_id"]})
Device: {results["device"]}
Seed: {results["seed"]}
=== Model Information ===
Parameters: {results["num_params"]:,}
Input Features: {", ".join(results["input_features"])}
Output Features: {", ".join(results["output_features"])}
=== Benchmark Configuration ===
Samples: {results["num_trials"]}
Warmup: {results["warmup"]}
Total Measurements: {len(per_forward_ms)}
Timeout: {results["timeout_ms"]:.1f}ms
Timeouts: {results["timeout_count"]} / {results["num_trials"]}
=== Latency Results (ms) ===
Mean: {results["latency_mean_ms"]:.3f}
Std Dev: {results["latency_std_ms"]:.3f}
Min: {results["latency_min_ms"]:.3f}
Max: {results["latency_max_ms"]:.3f}
P50: {results["latency_p50_ms"]:.3f}
P95: {results["latency_p95_ms"]:.3f}
=== Memory Footprint ===
CPU RSS Before: {bytes_to_human(results["cpu_rss_before"])}
CPU RSS After: {bytes_to_human(results["cpu_rss_after"])}{bytes_to_human(results["cpu_rss_delta"])})
"""
if use_cuda:
log_content += f"CUDA Peak: {bytes_to_human(results['cuda_peak_alloc'])} (reset before timing)\n"
log_content += f"""
=== Raw Timing Data (first 20 measurements, ms) ===
{", ".join(f"{t:.3f}" for t in per_forward_ms[:20])}
{"..." if len(per_forward_ms) > 20 else ""}
=== Summary Statistics ===
Timing Method: {"CUDA Events" if use_cuda else "torch.utils.benchmark.Timer"}
Device Available: {torch.cuda.is_available() if device == "cuda" else torch.backends.mps.is_available() if device == "mps" else True}
PyTorch Version: {torch.__version__}
Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
"""
# Write to log file
with open(log_file, "w") as f:
f.write(log_content)
# Print to console (shorter version)
print("\n=== Inference Benchmark Results ===")
print(f"Policy: {args.policy_type} ({args.policy_id})")
print(f"Device: {device}")
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}")
print("\nLatency per forward (ms):")
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
print(f" min: {min_ms:.3f} max: {max_ms:.3f}")
print(f" p50: {p50_ms:.3f} p95: {p95_ms:.3f}")
print("\nMemory footprint:")
print(f" CPU RSS before: {bytes_to_human(rss_before)}")
print(f" CPU RSS after : {bytes_to_human(rss_after)}{bytes_to_human(rss_delta)})")
if use_cuda:
print(
f" CUDA peak allocated: {bytes_to_human(cuda_peak)} "
f"(reset by reset_peak_memory_stats before timing)"
)
print(f"\nResults saved to: {log_file}")
print("Benchmark completed successfully!")
if __name__ == "__main__":
main()
+3 -2
View File
@@ -35,12 +35,13 @@ import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm
from benchmarks.video.benchmark import TimeBenchmark
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames_torchvision,
encode_video_frames,
)
from lerobot.utils.benchmark import TimeBenchmark
from lerobot.utils.constants import OBS_IMAGE
BASE_ENCODING = OrderedDict(
[
@@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)
# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
imgs_dataset = hf_dataset.select_columns(img_keys[0])
for i, item in enumerate(
+5 -2
View File
@@ -28,11 +28,14 @@
title: "Datasets"
- sections:
- local: smolvla
title: Finetune SmolVLA
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
- local: libero
title: Using Libero
title: "Policies"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors
+8 -8
View File
@@ -31,7 +31,7 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif
You can spin up a policy server running:
```shell
python src/lerobot/scripts/server/policy_server.py \
python src/lerobot/async_inference/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
```
@@ -39,7 +39,7 @@ python src/lerobot/scripts/server/policy_server.py \
This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with:
```shell
python src/lerobot/scripts/server/robot_client.py \
python src/lerobot/async_inference/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
@@ -122,8 +122,8 @@ python -m lerobot.scripts.server.policy_server \
<!-- prettier-ignore-start -->
```python
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.policy_server import serve
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve
config = PolicyServerConfig(
host="localhost",
@@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio
<hfoptions id="start_robot_client">
<hfoption id="Command">
```bash
python src/lerobot/scripts/server/robot_client.py \
python src/lerobot/async_inference/robot_client.py \
--server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server
--robot.type=so100_follower \ # ROBOT: your robot type
--robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port
@@ -171,9 +171,9 @@ python src/lerobot/scripts/server/robot_client.py \
import threading
from lerobot.robots.so100_follower import SO100FollowerConfig
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.robot_client import RobotClient
from lerobot.scripts.server.helpers import visualize_action_queue_size
from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.robot_client import RobotClient
from lerobot.async_inference.helpers import visualize_action_queue_size
# 1. Create the robot instance
"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`"""
+8 -11
View File
@@ -95,7 +95,6 @@ class HILSerlProcessorConfig:
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
@@ -105,7 +104,6 @@ class ImagePreprocessingConfig:
class GripperConfig:
use_gripper: bool = True # Enable gripper control
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
class ResetConfig:
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
@@ -288,7 +286,6 @@ You can enable multiple observation processing features simultaneously:
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": true,
"add_ee_pose_to_observation": false,
"display_cameras": false
}
}
@@ -304,19 +301,19 @@ Before collecting demonstrations, you need to determine the appropriate operatio
This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates.
**Using find_joint_limits.py**
**Using lerobot-find-joint-limits**
This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training.
Bounding the action space will reduce the redundant exploration of the agent and guarantees safety.
```bash
python -m lerobot.scripts.find_joint_limits \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
lerobot-find-joint-limits \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.id=black \
--teleop.type=so100_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \
--teleop.id=blue
```
**Workflow**
+4 -4
View File
@@ -200,7 +200,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
NUM_EPISODES = 5
@@ -237,7 +237,7 @@ dataset = LeRobotDataset.create(
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording")
init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
@@ -517,7 +517,7 @@ from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerCon
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
from lerobot.record import record_loop
from lerobot.policies.factory import make_processor
@@ -557,7 +557,7 @@ dataset = LeRobotDataset.create(
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording")
init_rerun(session_name="recording")
# Connect the robot
robot.connect()
+1 -1
View File
@@ -277,7 +277,7 @@ leader.disconnect()
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+1 -1
View File
@@ -323,7 +323,7 @@ To replay an episode run the API example below, make sure to change `remote_ip`,
python examples/lekiwi/replay.py
```
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./il_robots)
## Evaluate your policy
+44 -4
View File
@@ -33,7 +33,7 @@ To Install LIBERO, after following LeRobot official instructions, just do:
Evaluate a policy on one LIBERO suite:
```bash
python src/lerobot/scripts/eval.py \
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_object \
@@ -52,7 +52,7 @@ python src/lerobot/scripts/eval.py \
Benchmark a policy across multiple suites at once:
```bash
python src/lerobot/scripts/eval.py \
lerobot-eval \
--policy.path="your-policy-id" \
--env.type=libero \
--env.task=libero_object,libero_spatial \
@@ -103,10 +103,11 @@ For reference, here is the **original dataset** published by Physical Intelligen
### Example training command
```bash
python src/lerobot/scripts/train.py \
lerobot-train \
--policy.type=smolvla \
--policy.repo_id=${HF_USER}/libero-test \
--dataset.repo_id=jadechoghari/smol-libero3 \
--policy.load_vlm_weights=true \
--dataset.repo_id=HuggingFaceVLA/libero \
--env.type=libero \
--env.task=libero_10 \
--output_dir=./outputs/ \
@@ -124,3 +125,42 @@ python src/lerobot/scripts/train.py \
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
## Reproducing π₀.₅ results
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
The finetuned model can be found here:
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
```bash
python src/lerobot/scripts/eval.py \
--output_dir=/logs/ \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--policy.path=pi05_libero_finetuned \
--policy.n_action_steps=10 \
--output_dir=./eval_logs/ \
--env.max_parallel_tasks=1
```
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
### Results
We obtain the following results on the LIBERO benchmark:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |
+1 -2
View File
@@ -136,13 +136,12 @@ Additionally you can customize mapping or safety limits by editing the processor
),
```
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits.
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
```examples/phone_to_so100/teleoperate.py
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
)
```
+79
View File
@@ -0,0 +1,79 @@
# π₀ (Pi0)
π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
## Model Overview
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
### The Vision for Physical Intelligence
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
### Architecture and Approach
π₀ combines several key innovations:
- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install Pi0 dependencies by running:
```bash
pip install -e ".[pi]"
```
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
## Usage
To use π₀ in LeRobot, specify the policy type as:
```python
policy.type=pi0
```
## Training
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
python src/lerobot/scripts/lerobot_train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
--job_name=pi0_training \
--policy.pretrained_path=lerobot/pi0_base \
--policy.repo_id=your_repo_id \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--policy.dtype=bfloat16 \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
### Key Training Parameters
- **`--policy.compile_model=true`**: Enables model compilation for faster training
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+107
View File
@@ -0,0 +1,107 @@
# π₀.₅ (Pi05) Policy
π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
## Model Overview
π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
### The Generalization Challenge
As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
### Co-Training on Heterogeneous Data
The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
5. **Multi-Environment Data**: Static robots deployed across many different homes
6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install Pi0.5 dependencies by running:
```bash
pip install -e ".[pi]"
```
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
```python
policy.type=pi05
```
## Training
### Training Command Example
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
python src/lerobot/scripts/lerobot_train.py\
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
--job_name=pi05_training \
--policy.repo_id=your_repo_id \
--policy.pretrained_path=lerobot/pi05_base \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--policy.dtype=bfloat16 \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
### Key Training Parameters
- **`--policy.compile_model=true`**: Enables model compilation for faster training
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
If your dataset is not converted with `quantiles`, you can convert it with the following command:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=your_dataset \
```
Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'`
## Performance Results
### Libero Benchmark Results
π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
| Benchmark | LeRobot Implementation | OpenPI Reference |
| ------------------ | ---------------------- | ---------------- |
| **Libero Spatial** | 97.0% | 98.8% |
| **Libero Object** | 99.0% | 98.2% |
| **Libero Goal** | 98.0% | 98.0% |
| **Libero 10** | 96.0% | 92.4% |
| **Average** | 97.5% | 96.85% |
These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+1 -1
View File
@@ -38,7 +38,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotActi
kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50,
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20,
),
GripperVelocityToJoint(),
],
+3 -3
View File
@@ -1,4 +1,4 @@
# Finetune SmolVLA
# SmolVLA
SmolVLA is Hugging Faces lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!
@@ -29,7 +29,7 @@ SmolVLA is Hugging Faces lightweight foundation model for robotics. Designed
## Collect a dataset
SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup.
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset)
We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](./il_robots)
<Tip>
@@ -93,7 +93,7 @@ lerobot-train --help
## Evaluate the finetuned model and run it in real-time
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset).
Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./il_robots).
Once you are logged in, you can run inference in your setup by doing:
```bash
+1 -1
View File
@@ -634,7 +634,7 @@ leader.disconnect()
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+1 -1
View File
@@ -430,7 +430,7 @@ leader.disconnect()
</hfoption>
</hfoptions>
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot)
Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots)
> [!TIP]
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
+4 -3
View File
@@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401
so100_follower,
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import (
init_logging,
@@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig):
robot = make_robot_from_config(cfg.robot)
dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode])
actions = dataset.hf_dataset.select_columns("action")
actions = dataset.hf_dataset.select_columns(ACTION)
robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action_array = actions[idx]["action"]
action_array = actions[idx][ACTION]
action = {}
for i, name in enumerate(dataset.features["action"]["names"]):
for i, name in enumerate(dataset.features[ACTION]["names"]):
key = f"{name.removeprefix('main_')}.pos"
action[key] = action_array[i].item()
+6 -5
View File
@@ -19,11 +19,12 @@ from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.processor import make_default_processors
from lerobot.record import record_loop
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
@@ -41,8 +42,8 @@ robot = LeKiwiClient(robot_config)
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
@@ -73,7 +74,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="lekiwi_evaluate")
init_rerun(session_name="lekiwi_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
+6 -5
View File
@@ -17,14 +17,15 @@
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import make_default_processors
from lerobot.record import record_loop
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
@@ -47,8 +48,8 @@ keyboard = KeyboardTeleop(keyboard_config)
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
action_features = hw_to_dataset_features(robot.action_features, ACTION)
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
dataset_features = {**action_features, **obs_features}
# Create the dataset
@@ -69,7 +70,7 @@ keyboard.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="lekiwi_record")
init_rerun(session_name="lekiwi_record")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
+3 -2
View File
@@ -19,6 +19,7 @@ import time
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config)
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -49,7 +50,7 @@ for idx in range(len(episode_frames)):
# Get recorded action from dataset
action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Send action to robot
+2 -2
View File
@@ -20,7 +20,7 @@ from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -41,7 +41,7 @@ leader_arm.connect()
keyboard.connect()
# Init rerun viewer
_init_rerun(session_name="lekiwi_teleop")
init_rerun(session_name="lekiwi_teleop")
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
+3 -3
View File
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
transition_to_observation,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
@@ -137,7 +137,7 @@ robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="phone_so100_evaluate")
init_rerun(session_name="phone_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
+3 -4
View File
@@ -26,7 +26,6 @@ from lerobot.processor.converters import (
transition_to_observation,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety,
@@ -36,12 +35,13 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
@@ -84,7 +84,6 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, Rob
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
max_ee_twist_step_rad=0.50,
),
GripperVelocityToJoint(speed_factor=20.0),
],
@@ -143,7 +142,7 @@ phone.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="phone_so100_record")
init_rerun(session_name="phone_so100_record")
if not robot.is_connected or not phone.is_connected:
raise ValueError("Robot or teleop is not connected!")
+3 -2
View File
@@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -81,7 +82,7 @@ for idx in range(len(episode_frames)):
# Get recorded action from dataset
ee_action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
+2 -3
View File
@@ -33,7 +33,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -67,7 +67,6 @@ phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, Robo
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
GripperVelocityToJoint(
speed_factor=20.0,
@@ -87,7 +86,7 @@ robot.connect()
teleop_device.connect()
# Init rerun viewer
_init_rerun(session_name="phone_so100_teleop")
init_rerun(session_name="phone_so100_teleop")
if not robot.is_connected or not teleop_device.is_connected:
raise ValueError("Robot or teleop is not connected!")
+3 -3
View File
@@ -34,16 +34,16 @@ from lerobot.processor.converters import (
transition_to_observation,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
ForwardKinematicsJointsToEE,
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 5
FPS = 30
@@ -138,7 +138,7 @@ robot.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="so100_so100_evaluate")
init_rerun(session_name="so100_so100_evaluate")
if not robot.is_connected:
raise ValueError("Robot is not connected!")
+3 -4
View File
@@ -27,7 +27,6 @@ from lerobot.processor.converters import (
transition_to_observation,
transition_to_robot_action,
)
from lerobot.record import record_loop
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety,
@@ -35,11 +34,12 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.utils.visualization_utils import init_rerun
NUM_EPISODES = 2
FPS = 30
@@ -101,7 +101,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
@@ -143,7 +142,7 @@ follower.connect()
# Initialize the keyboard listener and rerun visualization
listener, events = init_keyboard_listener()
_init_rerun(session_name="recording_phone")
init_rerun(session_name="recording_phone")
if not leader.is_connected or not follower.is_connected:
raise ValueError("Robot or teleop is not connected!")
+3 -2
View File
@@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
InverseKinematicsEEToJoints,
)
from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
@@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns("action")
actions = episode_frames.select_columns(ACTION)
# Connect to the robot
robot.connect()
@@ -82,7 +83,7 @@ for idx in range(len(episode_frames)):
# Get recorded action from dataset
ee_action = {
name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"])
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
# Get robot observation
+2 -3
View File
@@ -33,7 +33,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
FPS = 30
@@ -78,7 +78,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
@@ -95,7 +94,7 @@ follower.connect()
leader.connect()
# Init rerun viewer
_init_rerun(session_name="so100_so100_EE_teleop")
init_rerun(session_name="so100_so100_EE_teleop")
print("Starting teleop loop...")
while True:
+1 -1
View File
@@ -20,13 +20,13 @@ from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import ACTION
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.constants import ACTION
def main():
+96 -16
View File
@@ -94,7 +94,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.52.0"]
transformers-dep = ["transformers>=4.53.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
@@ -119,7 +119,7 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# ] # TODO: Currently not supported
# Policies
pi0 = ["lerobot[transformers-dep]"]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -147,7 +147,7 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
"lerobot[pi]",
"lerobot[smolvla]",
"lerobot[hilserl]",
"lerobot[async]",
@@ -162,17 +162,18 @@ all = [
]
[project.scripts]
lerobot-calibrate="lerobot.calibrate:main"
lerobot-find-cameras="lerobot.find_cameras:main"
lerobot-find-port="lerobot.find_port:main"
lerobot-record="lerobot.record:main"
lerobot-replay="lerobot.replay:main"
lerobot-setup-motors="lerobot.setup_motors:main"
lerobot-teleoperate="lerobot.teleoperate:main"
lerobot-eval="lerobot.scripts.eval:main"
lerobot-train="lerobot.scripts.train:main"
lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main"
lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main"
lerobot-find-port="lerobot.scripts.lerobot_find_port:main"
lerobot-record="lerobot.scripts.lerobot_record:main"
lerobot-replay="lerobot.scripts.lerobot_replay:main"
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
lerobot-eval="lerobot.scripts.lerobot_eval:main"
lerobot-train="lerobot.scripts.lerobot_train:main"
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
# ---------------- Tool Configurations ----------------
@@ -200,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
# N: pep8-naming
# TODO: Uncomment rules when ready to use
select = [
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
]
ignore = [
"E501", # Line too long
@@ -266,8 +267,87 @@ default.extend-ignore-identifiers-re = [
# color = true
# paths = ["src/lerobot"]
# [tool.mypy]
# python_version = "3.10"
# TODO: Enable mypy gradually module by module across multiple PRs
# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations
[tool.mypy]
python_version = "3.10"
ignore_missing_imports = true
follow_imports = "skip"
# warn_return_any = true
# warn_unused_configs = true
# ignore_missing_imports = false
# strict = true
# disallow_untyped_defs = true
# disallow_incomplete_defs = true
# check_untyped_defs = true
[[tool.mypy.overrides]]
module = "lerobot.*"
ignore_errors = true
[[tool.mypy.overrides]]
module = "lerobot.envs.*"
# Enable type checking only for the envs module
ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.utils.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.configs.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.optim.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.model.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.processor.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.datasets.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.cameras.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.motors.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.robots.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.teleoperators.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.policies.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.rl.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.async_inference.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.transport.*"
# ignore_errors = false
# [[tool.mypy.overrides]]
# module = "lerobot.scripts.*"
# ignore_errors = false
@@ -18,7 +18,8 @@ from dataclasses import dataclass, field
import torch
from lerobot.robots.config import RobotConfig
from lerobot.scripts.server.constants import (
from .constants import (
DEFAULT_FPS,
DEFAULT_INFERENCE_LATENCY,
DEFAULT_OBS_QUEUE_TIMEOUT,
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
@@ -22,16 +22,22 @@ from pathlib import Path
import torch
from lerobot.configs.types import PolicyFeature
from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.policies import ( # noqa: F401
ACTConfig,
DiffusionConfig,
PI0Config,
PI05Config,
SmolVLAConfig,
VQBeTConfig,
)
from lerobot.robots.robot import Robot
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import init_logging
Action = torch.Tensor
ActionChunk = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
@@ -46,7 +52,7 @@ Observation = dict[str, torch.Tensor]
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
_, ax = plt.subplots()
ax.set_title("Action Queue Size Over Time")
ax.set_xlabel("Environment steps")
ax.set_ylabel("Action Queue Size")
@@ -66,7 +72,7 @@ def validate_robot_cameras_for_policy(
def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]:
return hw_to_dataset_features(robot.observation_features, "observation", use_video=False)
return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False)
def is_image_key(k: str) -> bool:
@@ -141,7 +147,7 @@ def make_lerobot_observation(
lerobot_features: dict[str, dict],
) -> LeRobotObservation:
"""Make a lerobot observation from a raw observation."""
return build_dataset_frame(lerobot_features, robot_obs, prefix="observation")
return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR)
def prepare_raw_observation(
@@ -15,7 +15,7 @@
"""
Example:
```shell
python src/lerobot/scripts/server/policy_server.py \
python src/lerobot/async_inference/policy_server.py \
--host=127.0.0.1 \
--port=8080 \
--fps=30 \
@@ -38,9 +38,15 @@ import grpc
import torch
from lerobot.policies.factory import get_policy_class
from lerobot.scripts.server.configs import PolicyServerConfig
from lerobot.scripts.server.constants import SUPPORTED_POLICIES
from lerobot.scripts.server.helpers import (
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
from .configs import PolicyServerConfig
from .constants import SUPPORTED_POLICIES
from .helpers import (
FPSTracker,
Observation,
RemotePolicyConfig,
@@ -50,11 +56,6 @@ from lerobot.scripts.server.helpers import (
observations_similar,
raw_observation_to_observation,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import receive_bytes_in_chunks
class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
@@ -15,7 +15,7 @@
"""
Example command:
```shell
python src/lerobot/scripts/server/robot_client.py \
python src/lerobot/async_inference/robot_client.py \
--robot.type=so100_follower \
--robot.port=/dev/tty.usbmodem58760431541 \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
@@ -57,9 +57,15 @@ from lerobot.robots import ( # noqa: F401
so100_follower,
so101_follower,
)
from lerobot.scripts.server.configs import RobotClientConfig
from lerobot.scripts.server.constants import SUPPORTED_ROBOTS
from lerobot.scripts.server.helpers import (
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
from .configs import RobotClientConfig
from .constants import SUPPORTED_ROBOTS
from .helpers import (
Action,
FPSTracker,
Observation,
@@ -72,11 +78,6 @@ from lerobot.scripts.server.helpers import (
validate_robot_cameras_for_policy,
visualize_action_queue_size,
)
from lerobot.transport import (
services_pb2, # type: ignore
services_pb2_grpc, # type: ignore
)
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
class RobotClient:
+1 -1
View File
@@ -31,7 +31,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"
import cv2
import numpy as np
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..utils import get_cv2_backend, get_cv2_rotation
@@ -31,7 +31,7 @@ import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
from lerobot.errors import DeviceNotConnectedError
from lerobot.utils.errors import DeviceNotConnectedError
from ..camera import Camera
from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig
@@ -29,7 +29,7 @@ try:
except Exception as e:
logging.info(f"Could not import realsense: {e}")
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from ..camera import Camera
from ..configs import ColorMode
-4
View File
@@ -15,14 +15,10 @@
# limitations under the License.
import platform
from pathlib import Path
from typing import TypeAlias
from .camera import Camera
from .configs import CameraConfig, Cv2Rotation
IndexOrPath: TypeAlias = int | Path
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
cameras = {}
-3
View File
@@ -16,9 +16,6 @@
from dataclasses import dataclass, field
from lerobot import (
policies, # noqa: F401
)
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
+4 -2
View File
@@ -27,9 +27,9 @@ from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_STATE
from lerobot.optim.optimizers import OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
@@ -71,9 +71,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: str | None = None
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+2 -5
View File
@@ -15,7 +15,6 @@
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol
class FeatureType(str, Enum):
@@ -36,10 +35,8 @@ class NormalizationMode(str, Enum):
MIN_MAX = "MIN_MAX"
MEAN_STD = "MEAN_STD"
IDENTITY = "IDENTITY"
class DictLike(Protocol):
def __getitem__(self, key: Any) -> Any: ...
QUANTILES = "QUANTILES"
QUANTILE10 = "QUANTILE10"
@dataclass
+19 -26
View File
@@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta):
pd.DataFrame: Updated DataFrame with adjusted indices.
"""
def _update(row):
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
row["index"] = row["index"] + dst_meta.info["total_frames"]
task = src_meta.tasks.iloc[row["task_index"]].name
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
return row
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
df["index"] = df["index"] + dst_meta.info["total_frames"]
return df.apply(_update, axis=1)
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
return df
def update_meta_data(
@@ -126,27 +125,21 @@ def update_meta_data(
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
def _update(row):
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
row[f"videos/{key}/from_timestamp"] = (
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
row[f"videos/{key}/to_timestamp"] = (
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
)
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
return row
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
return df.apply(_update, axis=1)
return df
def aggregate_datasets(
@@ -23,6 +23,9 @@ Please, update your dataset to the new format using this command:
python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id}
```
If you already have a converted version uploaded to the hub, then this error might be because of
an older version in your local cache. Consider deleting the cached version and retrying.
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
+481 -31
View File
@@ -17,6 +17,179 @@ import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
class RunningQuantileStats:
"""
Maintains running statistics for batches of vectors, including mean,
standard deviation, min, max, and approximate quantiles.
Statistics are computed per feature dimension and updated incrementally
as new batches are observed. Quantiles are estimated using histograms,
which adapt dynamically if the observed data range expands.
"""
def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
self._count = 0
self._mean = None
self._mean_of_squares = None
self._min = None
self._max = None
self._histograms = None
self._bin_edges = None
self._num_quantile_bins = num_quantile_bins
self._quantile_list = quantile_list
if self._quantile_list is None:
self._quantile_list = DEFAULT_QUANTILES
self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]
def update(self, batch: np.ndarray) -> None:
"""Update the running statistics with a batch of vectors.
Args:
batch: An array where all dimensions except the last are batch dimensions.
"""
batch = batch.reshape(-1, batch.shape[-1])
num_elements, vector_length = batch.shape
if self._count == 0:
self._mean = np.mean(batch, axis=0)
self._mean_of_squares = np.mean(batch**2, axis=0)
self._min = np.min(batch, axis=0)
self._max = np.max(batch, axis=0)
self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
self._bin_edges = [
np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
for i in range(vector_length)
]
else:
if vector_length != self._mean.size:
raise ValueError("The length of new vectors does not match the initialized vector length.")
new_max = np.max(batch, axis=0)
new_min = np.min(batch, axis=0)
max_changed = np.any(new_max > self._max)
min_changed = np.any(new_min < self._min)
self._max = np.maximum(self._max, new_max)
self._min = np.minimum(self._min, new_min)
if max_changed or min_changed:
self._adjust_histograms()
self._count += num_elements
batch_mean = np.mean(batch, axis=0)
batch_mean_of_squares = np.mean(batch**2, axis=0)
# Update running mean and mean of squares
self._mean += (batch_mean - self._mean) * (num_elements / self._count)
self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
num_elements / self._count
)
self._update_histograms(batch)
def get_statistics(self) -> dict[str, np.ndarray]:
"""Compute and return the statistics of the vectors processed so far.
Args:
quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed.
Returns:
Dictionary containing the computed statistics.
"""
if self._count < 2:
raise ValueError("Cannot compute statistics for less than 2 vectors.")
variance = self._mean_of_squares - self._mean**2
stddev = np.sqrt(np.maximum(0, variance))
stats = {
"min": self._min.copy(),
"max": self._max.copy(),
"mean": self._mean.copy(),
"std": stddev,
"count": np.array([self._count]),
}
quantile_results = self._compute_quantiles()
for i, q in enumerate(self._quantile_keys):
stats[q] = quantile_results[i]
return stats
def _adjust_histograms(self):
"""Adjust histograms when min or max changes."""
for i in range(len(self._histograms)):
old_edges = self._bin_edges[i]
old_hist = self._histograms[i]
# Create new edges with small padding to ensure range coverage
padding = (self._max[i] - self._min[i]) * 1e-10
new_edges = np.linspace(
self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
)
# Redistribute existing histogram counts to new bins
# We need to map each old bin center to the new bins
old_centers = (old_edges[:-1] + old_edges[1:]) / 2
new_hist = np.zeros(self._num_quantile_bins)
for old_center, count in zip(old_centers, old_hist, strict=False):
if count > 0:
# Find which new bin this old center belongs to
bin_idx = np.searchsorted(new_edges, old_center) - 1
bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
new_hist[bin_idx] += count
self._histograms[i] = new_hist
self._bin_edges[i] = new_edges
def _update_histograms(self, batch: np.ndarray) -> None:
"""Update histograms with new vectors."""
for i in range(batch.shape[1]):
hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
self._histograms[i] += hist
def _compute_quantiles(self) -> list[np.ndarray]:
"""Compute quantiles based on histograms."""
results = []
for q in self._quantile_list:
target_count = q * self._count
q_values = []
for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
q_value = self._compute_single_quantile(hist, edges, target_count)
q_values.append(q_value)
results.append(np.array(q_values))
return results
def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
"""Compute a single quantile value from histogram and bin edges."""
cumsum = np.cumsum(hist)
idx = np.searchsorted(cumsum, target_count)
if idx == 0:
return edges[0]
if idx >= len(cumsum):
return edges[-1]
# If not edge case, interpolate within the bin
count_before = cumsum[idx - 1]
count_in_bin = cumsum[idx] - count_before
# If no samples in this bin, use the bin edge
if count_in_bin == 0:
return edges[idx]
# Linear interpolation within the bin
fraction = (target_count - count_before) / count_in_bin
return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
"mean": np.mean(array, axis=axis, keepdims=keepdims),
"std": np.std(array, axis=axis, keepdims=keepdims),
"count": np.array([len(array)]),
def _reshape_stats_by_axis(
stats: dict[str, np.ndarray],
axis: int | tuple[int, ...] | None,
keepdims: bool,
original_shape: tuple[int, ...],
) -> dict[str, np.ndarray]:
"""Reshape all statistics to match NumPy's output conventions.
Applies consistent reshaping to all statistics (except 'count') based on the
axis and keepdims parameters. This ensures statistics have the correct shape
for broadcasting with the original data.
Args:
stats: Dictionary of computed statistics
axis: Axis or axes along which statistics were computed
keepdims: Whether to keep reduced dimensions as size-1 dimensions
original_shape: Shape of the original array
Returns:
Dictionary with reshaped statistics
Note:
The 'count' statistic is never reshaped as it represents metadata
rather than per-feature statistics.
"""
if axis == (1,) and not keepdims:
return stats
result = {}
for key, value in stats.items():
if key == "count":
result[key] = value
else:
result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
return result
def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
"""Reshape statistics for image data (axis=(0,2,3))."""
if keepdims and value.ndim == 1:
return value.reshape(1, -1, 1, 1)
return value
def _reshape_for_vector_stats(
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray:
"""Reshape statistics for vector data (axis=0 or axis=(0,))."""
if not keepdims:
return value
if len(original_shape) == 1 and value.ndim > 0:
return value.reshape(1)
elif len(original_shape) >= 2 and value.ndim == 1:
return value.reshape(1, -1)
return value
def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
"""Reshape statistics for feature-wise computation (axis=(1,))."""
if not keepdims:
return value
if value.ndim == 0:
return value.reshape(1, 1)
elif value.ndim == 1:
return value.reshape(-1, 1)
return value
def _reshape_for_global_stats(
value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
"""Reshape statistics for global reduction (axis=None)."""
if keepdims:
target_shape = tuple(1 for _ in original_shape)
return value.reshape(target_shape)
# Keep at least 1-D arrays to satisfy validator
return np.atleast_1d(value)
def _reshape_single_stat(
value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
"""Apply appropriate reshaping to a single statistic array.
This function transforms statistic arrays to match expected output shapes
based on the axis configuration and keepdims parameter.
Args:
value: The statistic array to reshape
axis: Axis or axes that were reduced during computation
keepdims: Whether to maintain reduced dimensions as size-1 dimensions
original_shape: Shape of the original data before reduction
Returns:
Reshaped array following NumPy broadcasting conventions
"""
if axis == (0, 2, 3):
return _reshape_for_image_stats(value, keepdims)
if axis in [0, (0,)]:
return _reshape_for_vector_stats(value, keepdims, original_shape)
if axis == (1,):
return _reshape_for_feature_stats(value, keepdims)
if axis is None:
return _reshape_for_global_stats(value, keepdims, original_shape)
return value
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
"""Prepare array for statistics computation by reshaping according to axis.
Args:
array: Input data array
axis: Axis or axes along which to compute statistics
Returns:
Tuple of (reshaped_array, sample_count)
"""
if axis == (0, 2, 3): # Image data
batch_size, channels, height, width = array.shape
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
return reshaped, batch_size
if axis == 0 or axis == (0,): # Vector data
reshaped = array
if array.ndim == 1:
reshaped = array.reshape(-1, 1)
return reshaped, array.shape[0]
if axis == (1,): # Feature-wise statistics
return array.T, array.shape[1]
if axis is None: # Global statistics
reshaped = array.reshape(-1, 1)
# For backward compatibility, count represents the first dimension size
return reshaped, array.shape[0] if array.ndim > 0 else 1
raise ValueError(f"Unsupported axis configuration: {axis}")
def _compute_basic_stats(
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
) -> dict[str, np.ndarray]:
"""Compute basic statistics for arrays with insufficient samples for quantiles.
Args:
array: Reshaped array ready for statistics computation
sample_count: Number of samples represented in the data
Returns:
Dictionary with basic statistics and quantiles set to mean values
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
stats = {
"min": np.min(array, axis=0),
"max": np.max(array, axis=0),
"mean": np.mean(array, axis=0),
"std": np.std(array, axis=0),
"count": np.array([sample_count]),
}
for q in quantile_list_keys:
stats[q] = stats["mean"].copy()
return stats
def get_feature_stats(
array: np.ndarray,
axis: int | tuple[int, ...] | None,
keepdims: bool,
quantile_list: list[float] | None = None,
) -> dict[str, np.ndarray]:
"""Compute comprehensive statistics for array features along specified axes.
This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%)
for the input array along the specified axes. It handles different data layouts:
- Image data: axis=(0,2,3) computes per-channel statistics
- Vector data: axis=0 computes per-feature statistics
- Feature-wise: axis=1 computes statistics across features
- Global: axis=None computes statistics over entire array
Args:
array: Input data array with shape appropriate for the specified axis
axis: Axis or axes along which to compute statistics
- (0, 2, 3): For image data (batch, channels, height, width)
- 0 or (0,): For vector/tabular data (samples, features)
- (1,): For computing across features
- None: For global statistics over entire array
keepdims: If True, reduced axes are kept as dimensions with size 1
Returns:
Dictionary containing:
- 'min': Minimum values
- 'max': Maximum values
- 'mean': Mean values
- 'std': Standard deviation
- 'count': Number of samples (always shape (1,))
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
original_shape = array.shape
reshaped, sample_count = _prepare_array_for_stats(array, axis)
if reshaped.shape[0] < 2:
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
else:
running_stats = RunningQuantileStats()
running_stats.update(reshaped)
stats = running_stats.get_statistics()
stats["count"] = np.array([sample_count])
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
return stats
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray],
features: dict,
quantile_list: list[float] | None = None,
) -> dict:
"""Compute comprehensive statistics for all features in an episode.
Processes different data types appropriately:
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
- Numerical arrays: Computes per-feature statistics
- Strings: Skipped (no statistics computed)
Args:
episode_data: Dictionary mapping feature names to data
- For images/videos: list of file paths
- For numerical data: numpy arrays
features: Dictionary describing each feature's dtype and shape
Returns:
Dictionary mapping feature names to their statistics dictionaries.
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
Note:
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
per-channel values when dtype is 'image' or 'video'.
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue # HACK: we should receive np.arrays of strings
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
continue
if features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3)
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_ft_array = data
axes_to_reduce = 0
keepdims = data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
return ep_stats
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
"""Validate a single statistic value."""
if not isinstance(value, np.ndarray):
raise ValueError(
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
f"is of type '{type(value)}' instead."
)
if value.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
for i in range(len(stats_list)):
for fkey in stats_list[i]:
for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray):
raise ValueError(
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
"""Validate that all statistics have correct types and shapes.
Args:
stats_list: List of statistics dictionaries to validate
Raises:
ValueError: If any statistic has incorrect type or shape
"""
for stats in stats_list:
for feature_key, feature_stats in stats.items():
for stat_key, stat_value in feature_stats.items():
_validate_stat_value(stat_value, stat_key, feature_key)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
return {
aggregated = {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
"count": total_count,
}
if stats_ft_list:
quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
for q_key in quantile_keys:
if all(q_key in s for s in stats_ft_list):
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
weighted_quantiles = quantile_values * counts
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
return aggregated
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
+4 -3
View File
@@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
IMAGENET_STATS = {
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
@@ -54,11 +55,11 @@ def resolve_delta_timestamps(
"""
delta_timestamps = {}
for key in ds_meta.features:
if key == "next.reward" and cfg.reward_delta_indices is not None:
if key == REWARD and cfg.reward_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
if key == "action" and cfg.action_delta_indices is not None:
if key == ACTION and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
if len(delta_timestamps) == 0:
+5 -15
View File
@@ -31,7 +31,6 @@ import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.datasets.utils import (
@@ -79,6 +78,7 @@ from lerobot.datasets.video_utils import (
get_video_duration_in_s,
get_video_info,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
@@ -848,11 +848,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items():
item[key] = torch.BoolTensor(val)
return item
def __len__(self):
return self.num_frames
@@ -1032,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None):
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""
Batch save videos for multiple episodes.
@@ -1158,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
}
return metadata
def _save_episode_video(self, video_key: str, episode_index: int):
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_video_size_in_mb(ep_path)
@@ -1263,7 +1258,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
@@ -1396,11 +1391,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
@@ -1431,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
+9 -11
View File
@@ -17,9 +17,9 @@ from collections.abc import Sequence
from typing import Any
from lerobot.configs.types import PipelineFeatureType
from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.processor import DataProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR
def create_initial_features(
@@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features(
# Intermediate storage for categorized and filtered features.
processed_features: dict[str, dict[str, Any]] = {
"action": {},
"observation": {},
ACTION: {},
OBS_STR: {},
}
images_token = OBS_IMAGES.split(".")[-1]
@@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features(
# 3. Add the feature to the appropriate group with a clean name.
name = strip_prefix(key, PREFIXES_TO_STRIP)
if is_action:
processed_features["action"][name] = value
processed_features[ACTION][name] = value
else:
processed_features["observation"][name] = value
processed_features[OBS_STR][name] = value
# Convert the processed features into the final dataset format.
dataset_features = {}
if processed_features["action"]:
dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos))
if processed_features["observation"]:
dataset_features.update(
hw_to_dataset_features(processed_features["observation"], "observation", use_videos)
)
if processed_features[ACTION]:
dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos))
if processed_features[OBS_STR]:
dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos))
return dataset_features
@@ -13,67 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import datasets
import numpy
import PIL
import torch
from lerobot.datasets.video_utils import encode_video_frames
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
def save_image(img_array, i, out_dir):
img = PIL.Image.fromarray(img_array)
img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
signature = inspect.signature(encode_video_frames)
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
}
def check_repo_id(repo_id: str) -> None:
if len(repo_id.split("/")) != 2:
raise ValueError(
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
)
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
+2 -4
View File
@@ -21,7 +21,6 @@ import numpy as np
import torch
from datasets import load_dataset
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
Backtrackable,
@@ -38,6 +37,7 @@ from lerobot.datasets.video_utils import (
VideoDecoderCache,
decode_video_frames_torchcodec,
)
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
@@ -298,9 +298,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return padding_mask
def make_frame(
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
) -> Generator:
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator)
item = item_to_torch(item)
+1 -1
View File
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
+13 -61
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any, Deque, Generic, TypeVar
from typing import Any, Generic, TypeVar
import datasets
import numpy as np
@@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import (
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
@@ -66,18 +67,6 @@ DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{fram
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
@@ -218,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
if isinstance(value, (torch.Tensor | np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
serialized_dict[key] = value
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -382,12 +371,6 @@ def load_episodes(local_dir: Path) -> datasets.Dataset:
return episodes
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
@@ -645,14 +628,14 @@ def hw_to_dataset_features(
}
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
if joint_fts and prefix == "action":
if joint_fts and prefix == ACTION:
features[prefix] = {
"dtype": "float32",
"shape": (len(joint_fts),),
"names": list(joint_fts),
}
if joint_fts and prefix == "observation":
if joint_fts and prefix == OBS_STR:
features[f"{prefix}.state"] = {
"dtype": "float32",
"shape": (len(joint_fts),),
@@ -728,11 +711,11 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif key == "observation.environment_state":
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith("observation"):
elif key.startswith(OBS_STR):
type = FeatureType.STATE
elif key.startswith("action"):
elif key.startswith(ACTION):
type = FeatureType.ACTION
else:
continue
@@ -1196,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
@@ -1270,8 +1253,8 @@ class Backtrackable(Generic[T]):
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: Deque[T] = deque(maxlen=history)
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
@@ -1345,12 +1328,6 @@ class Backtrackable(Generic[T]):
# When cursor<0, slice so the order remains chronological
return list(self._back_buf)[: self._cursor or None]
def lookahead_buffer(self) -> list[T]:
"""
Return a copy of the current lookahead buffer.
"""
return list(self._ahead_buf)
def can_peek_back(self, steps: int = 1) -> bool:
"""
Check if we can go back `steps` items without raising an IndexError.
@@ -1376,31 +1353,6 @@ class Backtrackable(Generic[T]):
except StopIteration:
return False
def reset_cursor(self) -> None:
"""
Reset cursor to the most recent position (equivalent to calling next()
until you're back to the latest item).
"""
self._cursor = 0
def clear_ahead_buffer(self) -> None:
"""
Clear the ahead buffer, discarding any pre-fetched items.
"""
self._ahead_buf.clear()
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
"""
Switch the source of the backtrackable to a new iterable, keeping the history.
This is useful when iterating over a sequence of datasets. The history from the
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
to the present.
"""
self._source = iter(new_source)
self.clear_ahead_buffer()
self.reset_cursor()
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
"""
@@ -0,0 +1,260 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script augments existing LeRobot datasets with quantile statistics.
Most datasets created before the quantile feature was added do not contain
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
1. Loads an existing LeRobot dataset in v3.0 format
2. Checks if it already contains quantile statistics
3. If missing, computes quantile statistics for all features
4. Updates the dataset metadata with the new quantile statistics
Usage:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=lerobot/pusht \
```
"""
import argparse
import concurrent.futures
import logging
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HfApi
from requests import HTTPError
from tqdm import tqdm
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import write_stats
from lerobot.utils.utils import init_logging
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
"""Check if dataset statistics already contain quantile information.
Args:
stats: Dataset statistics dictionary
Returns:
True if quantile statistics are present, False otherwise
"""
if quantile_list_keys is None:
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
if stats is None:
return False
for feature_stats in stats.values():
if any(q_key in feature_stats for q_key in quantile_list_keys):
return True
return False
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
"""Process a single episode and return its statistics.
Args:
dataset: The LeRobot dataset
episode_idx: Index of the episode to process
Returns:
Dictionary containing episode statistics
"""
logging.info(f"Computing stats for episode {episode_idx}")
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
collected_data: dict[str, list] = {}
for idx in range(start_idx, end_idx):
item = dataset[idx]
for key, value in item.items():
if key not in dataset.features:
continue
if key not in collected_data:
collected_data[key] = []
collected_data[key].append(value)
ep_stats = {}
for key, data_list in collected_data.items():
if dataset.features[key]["dtype"] == "string":
continue
data = torch.stack(data_list).cpu().numpy()
if dataset.features[key]["dtype"] in ["image", "video"]:
if data.dtype == np.uint8:
data = data.astype(np.float32) / 255.0
axes_to_reduce = (0, 2, 3)
keepdims = True
else:
axes_to_reduce = 0
keepdims = data.ndim == 1
ep_stats[key] = get_feature_stats(
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
)
if dataset.features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
"""Compute quantile statistics for all episodes in the dataset.
Args:
dataset: The LeRobot dataset to compute statistics for
Returns:
Dictionary containing aggregated statistics with quantiles
Note:
Video decoding operations are not thread-safe, so we process episodes sequentially
when video keys are present. For datasets without videos, we use parallel processing
with ThreadPoolExecutor for better performance.
"""
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
episode_stats_list = []
has_videos = len(dataset.meta.video_keys) > 0
if has_videos:
logging.info("Dataset contains video keys - using sequential processing for thread safety")
for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
ep_stats = process_single_episode(dataset, episode_idx)
episode_stats_list.append(ep_stats)
else:
logging.info("Dataset has no video keys - using parallel processing for better performance")
max_workers = min(dataset.num_episodes, 16)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_episode = {
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
for episode_idx in range(dataset.num_episodes)
}
episode_results = {}
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
for future in concurrent.futures.as_completed(future_to_episode):
episode_idx = future_to_episode[future]
ep_stats = future.result()
episode_results[episode_idx] = ep_stats
pbar.update(1)
for episode_idx in range(dataset.num_episodes):
if episode_idx in episode_results:
episode_stats_list.append(episode_results[episode_idx])
if not episode_stats_list:
raise ValueError("No episode data found for computing statistics")
logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
return aggregate_stats(episode_stats_list)
def augment_dataset_with_quantile_stats(
repo_id: str,
root: str | Path | None = None,
overwrite: bool = False,
) -> None:
"""Augment a dataset with quantile statistics if they are missing.
Args:
repo_id: Repository ID of the dataset
root: Local root directory for the dataset
overwrite: Overwrite existing quantile statistics if they already exist
"""
logging.info(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(
repo_id=repo_id,
root=root,
)
if not overwrite and has_quantile_stats(dataset.meta.stats):
logging.info("Dataset already contains quantile statistics. No action needed.")
return
logging.info("Dataset does not contain quantile statistics. Computing them now...")
new_stats = compute_quantile_stats_for_dataset(dataset)
logging.info("Updating dataset metadata with new quantile statistics")
dataset.meta.stats = new_stats
write_stats(new_stats, dataset.meta.root)
logging.info("Successfully updated dataset with quantile statistics")
dataset.push_to_hub()
hub_api = HfApi()
try:
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
except HTTPError as e:
logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
def main():
"""Main function to run the augmentation script."""
parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
)
parser.add_argument(
"--root",
type=str,
help="Local root directory for the dataset",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing quantile statistics if they already exist",
)
args = parser.parse_args()
root = Path(args.root) if args.root else None
init_logging()
augment_dataset_with_quantile_stats(
repo_id=args.repo_id,
root=root,
overwrite=args.overwrite,
)
if __name__ == "__main__":
main()
@@ -26,14 +26,24 @@ This script will help you convert any LeRobot dataset already pushed to the hub
Usage:
Convert a dataset from the hub:
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht
```
Convert a local dataset (works in place):
```bash
python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \
--repo-id=lerobot/pusht \
--root=/path/to/local/dataset/directory
--push-to-hub=false
```
"""
import argparse
import logging
import shutil
from pathlib import Path
from typing import Any
@@ -46,7 +56,6 @@ from datasets import Dataset, Features, Image
from huggingface_hub import HfApi, snapshot_download
from requests import HTTPError
from lerobot.constants import HF_LEROBOT_HOME
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.datasets.utils import (
@@ -71,9 +80,11 @@ from lerobot.datasets.utils import (
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
V21 = "v2.1"
V30 = "v3.0"
"""
-------------------------
@@ -143,7 +154,19 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]:
return tasks, task_to_task_index
def validate_local_dataset_version(local_path: Path) -> None:
"""Validate that the local dataset has the expected v2.1 version."""
info = load_info(local_path)
dataset_version = info.get("codebase_version", "unknown")
if dataset_version != V21:
raise ValueError(
f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. "
f"This script is specifically for converting v2.1 datasets to v3.0."
)
def convert_tasks(root, new_root):
logging.info(f"Converting tasks from {root} to {new_root}")
tasks, _ = legacy_load_tasks(root)
task_indices = tasks.keys()
task_strings = tasks.values()
@@ -185,7 +208,10 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
num_frames = 0
paths_to_cat = []
episodes_metadata = []
for ep_path in ep_paths:
logging.info(f"Converting data files from {len(ep_paths)} episodes")
for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"):
ep_size_in_mb = get_parquet_file_size_in_mb(ep_path)
ep_num_frames = get_parquet_num_frames(ep_path)
ep_metadata = {
@@ -209,7 +235,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int):
# Reset for the next file
size_in_mb = ep_size_in_mb
num_frames = ep_num_frames
paths_to_cat = [ep_path]
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
@@ -236,6 +261,8 @@ def get_image_keys(root):
def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
logging.info(f"Converting videos from {root} to {new_root}")
video_keys = get_video_keys(root)
if len(video_keys) == 0:
return None
@@ -254,7 +281,7 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int):
episods_metadata = []
num_cameras = len(video_keys)
num_episodes = num_eps_per_cam[0]
for ep_idx in range(num_episodes):
for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"):
# Sanity check
ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)]
ep_ids += [ep_idx]
@@ -281,6 +308,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
duration_in_s = 0.0
paths_to_cat = []
episodes_metadata = []
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
@@ -374,6 +402,8 @@ def generate_episode_metadata_dict(
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
logging.info(f"Converting episodes metadata from {root} to {new_root}")
episodes_legacy_metadata = legacy_load_episodes(root)
episodes_stats = legacy_load_episodes_stats(root)
@@ -397,14 +427,15 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb):
info = load_info(root)
info["codebase_version"] = "v3.0"
info["codebase_version"] = V30
del info["total_chunks"]
del info["total_videos"]
info["data_files_size_in_mb"] = data_file_size_in_mb
info["video_files_size_in_mb"] = video_file_size_in_mb
info["data_path"] = DEFAULT_DATA_PATH
info["video_path"] = DEFAULT_VIDEO_PATH
info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
info["fps"] = int(info["fps"])
logging.info(f"Converting info from {root} to {new_root}")
for key in info["features"]:
if info["features"][key]["dtype"] == "video":
# already has fps in video_info
@@ -418,16 +449,36 @@ def convert_dataset(
branch: str | None = None,
data_file_size_in_mb: int | None = None,
video_file_size_in_mb: int | None = None,
root: str | Path | None = None,
push_to_hub: bool = True,
force_conversion: bool = False,
):
root = HF_LEROBOT_HOME / repo_id
old_root = HF_LEROBOT_HOME / f"{repo_id}_old"
new_root = HF_LEROBOT_HOME / f"{repo_id}_v30"
if data_file_size_in_mb is None:
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
if video_file_size_in_mb is None:
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
# First check if the dataset already has a v3.0 version
if root is None and not force_conversion:
try:
print("Trying to download v3.0 version of the dataset from the hub...")
snapshot_download(repo_id, repo_type="dataset", revision=V30, local_dir=HF_LEROBOT_HOME / repo_id)
return
except Exception:
print("Dataset does not have an uploaded v3.0 version. Continuing with conversion.")
# Set root based on whether local dataset path is provided
use_local_dataset = False
root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id
if root.exists():
validate_local_dataset_version(root)
use_local_dataset = True
print(f"Using local dataset at {root}")
old_root = root.parent / f"{root.name}_old"
new_root = root.parent / f"{root.name}_v30"
# Handle old_root cleanup if both old_root and root exist
if old_root.is_dir() and root.is_dir():
shutil.rmtree(str(root))
shutil.move(str(old_root), str(root))
@@ -435,12 +486,13 @@ def convert_dataset(
if new_root.is_dir():
shutil.rmtree(new_root)
snapshot_download(
repo_id,
repo_type="dataset",
revision=V21,
local_dir=root,
)
if not use_local_dataset:
snapshot_download(
repo_id,
repo_type="dataset",
revision=V21,
local_dir=root,
)
convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb)
convert_tasks(root, new_root)
@@ -451,24 +503,26 @@ def convert_dataset(
shutil.move(str(root), str(old_root))
shutil.move(str(new_root), str(root))
hub_api = HfApi()
try:
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
except HTTPError as e:
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.delete_files(
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
repo_id=repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if push_to_hub:
hub_api = HfApi()
try:
hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
except HTTPError as e:
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
pass
hub_api.delete_files(
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
repo_id=repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
LeRobotDataset(repo_id).push_to_hub()
LeRobotDataset(repo_id).push_to_hub()
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
@@ -495,6 +549,23 @@ if __name__ == "__main__":
default=None,
help="File size in MB. Defaults to 100 for data and 500 for videos.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="Local directory to use for downloading/writing the dataset.",
)
parser.add_argument(
"--push-to-hub",
type=lambda input: input.lower() == "true",
default=True,
help="Push the converted dataset to the hub.",
)
parser.add_argument(
"--force-conversion",
action="store_true",
help="Force conversion even if the dataset already has a v3.0 version.",
)
args = parser.parse_args()
convert_dataset(**vars(args))
+4 -15
View File
@@ -428,7 +428,7 @@ def concatenate_video_files(
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
@@ -437,7 +437,9 @@ def concatenate_video_files(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
@@ -585,19 +587,6 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.
+10 -12
View File
@@ -19,9 +19,9 @@ from typing import Any
import draccus
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.robots import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
@dataclass
@@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig):
render_mode: str = "rgb_array"
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"top": f"{OBS_IMAGE}.top",
"pixels/top": f"{OBS_IMAGES}.top",
@@ -93,13 +93,13 @@ class PushtEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"environment_state": OBS_ENV_STATE,
"pixels": OBS_IMAGE,
@@ -135,13 +135,13 @@ class XarmEnv(EnvConfig):
visualization_height: int = 384
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels": OBS_IMAGE,
}
@@ -193,7 +193,6 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@@ -203,7 +202,6 @@ class GripperConfig:
use_gripper: bool = True
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@dataclass
@@ -256,15 +254,15 @@ class LiberoEnv(EnvConfig):
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"
init_states: bool = True
camera_name_mapping: dict[str, str] | None = (None,)
camera_name_mapping: dict[str, str] | None = None
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
ACTION: ACTION,
"agent_pos": OBS_STATE,
"pixels/agentview_image": f"{OBS_IMAGES}.image",
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
+3
View File
@@ -63,6 +63,9 @@ def make_env(
if "libero" in cfg.type:
from lerobot.envs.libero import create_libero_envs
if cfg.task is None:
raise ValueError("LiberoEnv requires a task to be specified")
return create_libero_envs(
task=cfg.task,
n_envs=n_envs,
+1 -1
View File
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
elif isinstance(camera_name, (list | tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
+17 -16
View File
@@ -26,6 +26,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
from lerobot.utils.utils import get_channel_first_image_shape
@@ -41,44 +42,44 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations = {}
if "pixels" in observations:
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()}
else:
imgs = {"observation.image": observations["pixels"]}
imgs = {OBS_IMAGE: observations["pixels"]}
for imgkey, img in imgs.items():
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
img_tensor = torch.from_numpy(img)
# When preprocessing observations in a non-vectorized environment, we need to add a batch dimension.
# This is the case for human-in-the-loop RL where there is only one environment.
if img.ndim == 3:
img = img.unsqueeze(0)
if img_tensor.ndim == 3:
img_tensor = img_tensor.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
_, h, w, c = img_tensor.shape
assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
img = img.type(torch.float32)
img /= 255
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
img_tensor = img_tensor.type(torch.float32)
img_tensor /= 255
return_observations[imgkey] = img
return_observations[imgkey] = img_tensor
if "environment_state" in observations:
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
return_observations[OBS_ENV_STATE] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return_observations[OBS_STATE] = agent_pos
return return_observations
@@ -182,10 +183,10 @@ def _(env: Mapping) -> None:
@close_envs.register
def _(envs: Sequence) -> None:
if isinstance(envs, (str, bytes)):
if isinstance(envs, (str | bytes)):
return
for v in envs:
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
close_envs(v)
elif hasattr(v, "close"):
_close_single_env(v)
+1 -1
View File
@@ -22,7 +22,7 @@ import logging
from copy import deepcopy
from enum import Enum
from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement
from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
+1 -1
View File
@@ -17,7 +17,7 @@ from copy import deepcopy
from enum import Enum
from pprint import pformat
from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address
from .tables import (
+5 -11
View File
@@ -32,7 +32,7 @@ import serial
from deepdiff import DeepDiff
from tqdm import tqdm
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from lerobot.utils.utils import enter_pressed, move_cursor_up
NameOrID: TypeAlias = str | int
@@ -99,12 +99,6 @@ class Motor:
norm_mode: MotorNormMode
class JointOutOfRangeError(Exception):
def __init__(self, message="Joint is out of range"):
self.message = message
super().__init__(self.message)
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
@@ -348,7 +342,7 @@ class MotorsBus(abc.ABC):
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)):
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
@@ -675,7 +669,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -703,7 +697,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -739,7 +733,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
+2 -2
View File
@@ -22,11 +22,11 @@ import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.constants import (
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.utils.io_utils import deserialize_json_into_object
+1 -1
View File
@@ -22,8 +22,8 @@ import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.constants import SCHEDULER_STATE
from lerobot.datasets.utils import write_json
from lerobot.utils.constants import SCHEDULER_STATE
from lerobot.utils.io_utils import deserialize_json_into_object
+2 -1
View File
@@ -15,7 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -25,6 +25,7 @@ __all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"PI05Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+11 -16
View File
@@ -33,9 +33,9 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.constants import ACTION, OBS_IMAGES
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTPolicy(PreTrainedPolicy):
@@ -394,25 +394,22 @@ class ACT(nn.Module):
latent dimension.
"""
if self.config.use_vae and self.training:
assert "action" in batch, (
assert ACTION in batch, (
"actions must be provided when using the variational objective in training mode."
)
if "observation.images" in batch:
batch_size = batch["observation.images"][0].shape[0]
else:
batch_size = batch["observation.environment_state"].shape[0]
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch and self.training:
if self.config.use_vae and ACTION in batch and self.training:
# Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
@@ -430,7 +427,7 @@ class ACT(nn.Module):
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
device=batch[OBS_STATE].device,
)
key_padding_mask = torch.cat(
[cls_joint_is_pad, batch["action_is_pad"]], axis=1
@@ -454,7 +451,7 @@ class ACT(nn.Module):
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
batch[OBS_STATE].device
)
# Prepare transformer encoder inputs.
@@ -462,18 +459,16 @@ class ACT(nn.Module):
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE]))
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE]))
if self.config.image_features:
# For a list of images, the H and W may vary but H*W is constant.
# NOTE: If modifying this section, verify on MPS devices that
# gradients remain stable (no explosions or NaNs).
for img in batch["observation.images"]:
for img in batch[OBS_IMAGES]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features)
+1 -1
View File
@@ -17,7 +17,6 @@ from typing import Any
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -29,6 +28,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_act_pre_post_processors(
@@ -33,7 +33,6 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import (
@@ -42,6 +41,7 @@ from lerobot.policies.utils import (
get_output_shape,
populate_queues,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class DiffusionPolicy(PreTrainedPolicy):
@@ -81,25 +81,25 @@ class DiffusionPolicy(PreTrainedPolicy):
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
actions = self.diffusion.generate_actions(batch, noise=noise)
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
@@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy):
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
actions = self.predict_action_chunk(batch, noise=noise)
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft()
@@ -199,17 +199,25 @@ class DiffusionModel(nn.Module):
# ========= inference ============
def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
noise: Tensor | None = None,
) -> Tensor:
device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self)
# Sample prior.
sample = torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
sample = (
noise
if noise is not None
else torch.randn(
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
)
)
self.noise_scheduler.set_timesteps(self.num_inference_steps)
@@ -234,7 +242,7 @@ class DiffusionModel(nn.Module):
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
[
encoder(images)
@@ -249,7 +257,7 @@ class DiffusionModel(nn.Module):
else:
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
img_features = self.rgb_encoder(
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")
)
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
# feature dim (effectively concatenating the camera features).
@@ -264,7 +272,7 @@ class DiffusionModel(nn.Module):
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""
This function expects `batch` to have:
{
@@ -275,14 +283,14 @@ class DiffusionModel(nn.Module):
"observation.environment_state": (B, n_obs_steps, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
assert n_obs_steps == self.config.n_obs_steps
# Encode image features and concatenate them all together along with the state vector.
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
@@ -306,10 +314,10 @@ class DiffusionModel(nn.Module):
}
"""
# Input validation.
assert set(batch).issuperset({"observation.state", "action", "action_is_pad"})
assert "observation.images" in batch or "observation.environment_state" in batch
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1]
horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
@@ -317,7 +325,7 @@ class DiffusionModel(nn.Module):
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
trajectory = batch["action"]
trajectory = batch[ACTION]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
@@ -338,7 +346,7 @@ class DiffusionModel(nn.Module):
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
target = batch["action"]
target = batch[ACTION]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")
@@ -18,7 +18,6 @@ from typing import Any
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -30,6 +29,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_diffusion_pre_post_processors(
+25 -10
View File
@@ -24,7 +24,6 @@ from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.envs.configs import EnvConfig
@@ -33,6 +32,7 @@ from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -46,6 +46,7 @@ from lerobot.processor.converters import (
transition_to_batch,
transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -81,14 +82,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -132,10 +137,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -253,6 +260,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
processors = make_pi0fast_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
@@ -261,10 +276,10 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
elif isinstance(policy_cfg, PI05Config):
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
processors = make_pi0fast_pre_post_processors(
processors = make_pi05_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
+49
View File
@@ -0,0 +1,49 @@
# π₀ (pi0)
This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action model for general robot control**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2024pi0visionlanguageactionflowmodel,
title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
year = {2024},
eprint = {2410.24164},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2410.24164},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+21
View File
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi0 import PI0Config
from .modeling_pi0 import PI0Policy
from .processor_pi0 import make_pi0_pre_post_processors
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]
+70 -66
View File
@@ -1,4 +1,6 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,19 +19,40 @@ from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -38,94 +61,75 @@ class PI0Config(PreTrainedConfig):
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5
# Optimizer settings: see openpi `AdamW``
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
# TODO: Add EMA
tokenizer_max_length: int = 48 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
@@ -1,82 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
torch.backends.cudnn.benchmark = True
def main():
device = "cuda"
dataset_repo_id = "danaaubakirova/koch_test"
# model_name = "pi0_base"
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = "lerobot/pi0"
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=1,
)
batch = next(iter(dataloader))
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
warmup_iters = 10
benchmark_iters = 30
# Warmup
for _ in range(warmup_iters):
torch.cuda.synchronize()
policy.select_action(batch)
policy.reset()
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(benchmark_iters):
policy.select_action(batch)
policy.reset()
end_event.record()
# Synchronize and measure time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_per_iter = elapsed_time_ms / benchmark_iters
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
if __name__ == "__main__":
with torch.inference_mode():
main()
@@ -1,131 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
from pathlib import Path
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy
def display(tensor: torch.Tensor):
if tensor.dtype == torch.bool:
tensor = tensor.float()
print(f"Shape: {tensor.shape}")
print(f"Mean: {tensor.mean().item()}")
print(f"Std: {tensor.std().item()}")
print(f"Min: {tensor.min().item()}")
print(f"Max: {tensor.max().item()}")
def main():
num_motors = 14
device = "cuda"
# model_name = "pi0_aloha_towel"
model_name = "pi0_aloha_sim"
if model_name == "pi0_aloha_towel":
dataset_repo_id = "lerobot/aloha_static_towel"
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
with open(save_dir / "example.pkl", "rb") as f:
example = pickle.load(f)
with open(save_dir / "outputs.pkl", "rb") as f:
outputs = pickle.load(f)
with open(save_dir / "noise.pkl", "rb") as f:
noise = pickle.load(f)
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
norm_stats = json.load(f)
# Override stats
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel":
del batch["observation.images.cam_low"]
elif model_name == "pi0_aloha_sim":
batch["observation.images.top"] = batch["observation.images.cam_high"]
del batch["observation.images.cam_high"]
# Batchify
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].unsqueeze(0)
elif isinstance(batch[key], str):
batch[key] = [batch[key]]
else:
raise ValueError(f"{key}, {batch[key]}")
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
from lerobot import policies # noqa
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
# print("losses")
# display(loss_dict["losses_after_forward"])
# print("pi_losses")
# display(pi_losses)
actions = []
for _ in range(50):
action = policy.select_action(batch, noise=noise)
actions.append(action)
actions = torch.stack(actions, dim=1)
pi_actions = batch["action"]
print("actions")
display(actions)
print()
print("pi_actions")
display(pi_actions)
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
if __name__ == "__main__":
main()
@@ -1,84 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import GemmaConfig, PaliGemmaConfig
def get_paligemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
image_size = 224 # image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
return final_config
def get_gemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 1024,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 4096,
"is_encoder_decoder": False,
}
final_config = GemmaConfig()
final_config.update(text_config)
return final_config
@@ -1,437 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
```bash
cd ~/code/openpi
source .venv/bin/activate
```
Example downloading parameters:
```bash
python
>>> import openpi.shared.download as download
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
>>> download.maybe_download(path)
```
Converting pi0_base:
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
```
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
```
"""
import argparse
import pathlib
import jax
import numpy as np
import orbax.checkpoint as ocp
import torch
from jax.sharding import SingleDeviceSharding
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
get_gemma_config,
get_paligemma_config,
)
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
def slice_paligemma_state_dict(state_dict, config):
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# fmt: off
# patch embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
3, 2, 0, 1
)
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
# positional embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
# multimodal projector
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
# text decoder (gemma)
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
expert_dict = {}
final_state_dict = {}
for key, value in state_dict.items():
if key not in [
f"llm/final_norm_1/scale{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
]:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, num_expert=1):
# fmt: off
# text decoder (gemma)
# no embedding vector, the expert just has the decoder layers
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
# fmt: on
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def flatten_for_memory(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_memory(v, new_key))
else:
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
return out
def flatten_for_npz(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_npz(v, new_key))
else:
# bf16/f32 here?
out[new_key] = np.array(v)
return out
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
params_path = pathlib.Path(checkpoint_dir).resolve()
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(params_path)
print("Metadata keys:", list(metadata.keys()))
params_name = "params"
item = {params_name: metadata[params_name]}
device = jax.local_devices()[0] # Use the first local device
sharding = SingleDeviceSharding(device)
restored = checkpointer.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree_util.tree_map(
lambda _: ocp.ArrayRestoreArgs(
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
sharding=sharding,
),
item,
),
transforms={},
),
)
params = restored[params_name]
# get params for PaliGemma
pali_params = params["PaliGemma"]
del params["PaliGemma"]
pali_params_flat = flatten_for_npz(pali_params)
return {"paligemma_params": pali_params_flat, "projection_params": params}
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
"""Update dictionary keys by adding a prefix."""
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
# Process PaliGemma weights
paligemma_config = get_paligemma_config(precision)
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
initial_params["paligemma_params"], paligemma_config
)
# Process Gemma weights (at this stage they are unused)
gemma_config = get_gemma_config(precision)
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
# Instantiate model from configs
if "pi0_aloha_sim" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=2,
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False,
)
elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config(
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True,
)
elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=0,
adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False,
)
else:
raise ValueError()
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")
# load state dict
torch_dtype = PRECISIONS[precision]
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
pi0_model = pi0_model.to(torch_dtype)
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
pi0_model.save_pretrained(output_path, safe_serialization=True)
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
# assert that model loads properly
del pi0_model
PI0Policy.from_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
type=str,
help="Path to the ocdbt checkpoint",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
# tokenizer is identical to paligemma, it appears
parser.add_argument(
"--tokenizer_hub_id",
default="google/paligemma-3b-pt-224",
type=str,
help="Hub path to the tokenizer to save",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Path to save converted weights to",
)
args = parser.parse_args()
convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir,
precision=args.precision,
tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path,
)
-141
View File
@@ -1,141 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
if Version(torch.__version__) > Version("2.5.0"):
# Ffex attention is only available from torch 2.5 onwards
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
_round_up_to_multiple,
create_block_mask,
create_mask,
flex_attention,
)
# @torch.compile(dynamic=False)
def flex_attention_forward(
attention_mask: torch.Tensor,
batch_size: int,
head_dim: int,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
scaling=None,
):
"""
This is defined out of classes to make compile happy.
"""
original_dtype = query_states.dtype
num_att_heads = 8
num_key_value_heads = 1
num_key_value_groups = num_att_heads // num_key_value_heads
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.to(torch.float32)
key_states = key_states.to(torch.float32)
value_states = value_states.to(torch.float32)
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
def mask_mod(b, h, q_idx, kv_idx):
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
return precomputed_mask[b][h][q_idx][kv_idx]
return mask_mod
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128
q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
# *CRITICAL* we do need to expand here, else we get a CUDA index error
pad_q = q_len_rounded - q_len
pad_k = kv_len_rounded - kv_len
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
mask_4d = create_mask(
mod_fn=mask_mod_fn_orig,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
device=causal_mask.device,
_compile=False,
)
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
block_mask = create_block_mask(
mask_mod=mask_mod_fn_padded,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size,
device=causal_mask.device,
_compile=False,
)
# mask is applied inside the kernel, ideally more efficiently than score_mod.
attn_output, attention_weights = flex_attention(
query_states,
key_states,
value_states,
block_mask=block_mask,
enable_gqa=True, # because we shaped query/key states for GQA
scale=head_dim**-0.5 if scaling is None else scaling,
return_lse=True,
)
attn_output = attn_output.to(dtype=original_dtype)
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
attn_output = attn_output.reshape(
batch_size,
-1,
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
)
return attn_output
File diff suppressed because it is too large Load Diff
@@ -1,420 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.version
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi0.flex_attention import flex_attention_forward
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
class PaliGemmaWithExpertConfig(PretrainedConfig):
model_type = "PaliGemmaWithExpertModel"
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
def __init__(
self,
paligemma_config: dict | None = None,
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_implementation = attention_implementation
if paligemma_config is None:
# Default config from Pi0
self.paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": "float32",
"vocab_size": 257152,
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_fast",
"torch_dtype": "float32",
"vision_use_head": False,
},
)
elif isinstance(self.paligemma_config, dict):
# Override Pi0 default config for PaliGemma
if "model_type" not in gemma_expert_config:
paligemma_config["model_type"] = "paligemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.paligemma_config = cfg_cls(**paligemma_config)
if gemma_expert_config is None:
# Default config from Pi0
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation="gelu_pytorch_tanh",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.48.1",
use_cache=True,
vocab_size=257152,
)
elif isinstance(self.gemma_expert_config, dict):
# Override Pi0 default config for Gemma Expert
if "model_type" not in gemma_expert_config:
gemma_expert_config["model_type"] = "gemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
super().__init__(**kwargs)
def __post_init__(self):
super().__post_init__()
if self.train_expert_only and not self.freeze_vision_encoder:
raise ValueError(
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
)
if self.attention_implementation not in ["eager", "fa2", "flex"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for params in self.paligemma.vision_tower.parameters():
params.requires_grad = False
if self.config.train_expert_only:
self.paligemma.eval()
for params in self.paligemma.parameters():
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.config.train_expert_only:
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
params_to_change_dtype = [
"language_model.model.layers",
"gemma_expert.model.layers",
"vision_tower",
"multi_modal",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.paligemma, "get_image_features"):
return self.paligemma.get_image_features(image)
else:
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
inputs_embeds: list[torch.FloatTensor] = None,
use_cache: bool | None = None,
fill_kv_cache: bool | None = None,
):
models = [self.paligemma.language_model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.paligemma.config.text_config.num_hidden_layers
head_dim = self.paligemma.config.text_config.head_dim
for layer_idx in range(num_layers):
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
continue
layer = models[i].layers[layer_idx]
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
query_states = apply_rope(query_states, position_ids)
key_states = apply_rope(key_states, position_ids)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
if hidden_states is not None:
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
# TODO: first dropout (by default 0.0)
# first residual
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
# TODO: second dropout (by default 0.0)
# second residual
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
if self.config.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.config.attention_implementation == "flex":
attention_interface = flex_attention_forward
else:
attention_interface = self.eager_attention_forward
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
# value_states: batch_size, sequence_length, num_att_heads, head_dim
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
+1 -1
View File
@@ -19,7 +19,6 @@ from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -35,6 +34,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
+49
View File
@@ -0,0 +1,49 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
+21
View File
@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi05 import PI05Config
from .modeling_pi05 import PI05Policy
from .processor_pi05 import make_pi05_pre_post_processors
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]
@@ -0,0 +1,153 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05Config(PreTrainedConfig):
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
tokenizer_max_length: int = 200 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
File diff suppressed because it is too large Load Diff
+171
View File
@@ -0,0 +1,171 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
def make_pi05_pre_post_processors(
config: PI05Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
@@ -6,6 +22,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0fast")
@@ -99,7 +116,7 @@ class PI0FASTConfig(PreTrainedConfig):
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
@@ -57,9 +57,9 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe
from transformers.cache_utils import HybridCache, StaticCache
from transformers.models.auto import CONFIG_MAPPING
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
PRECISION = {
"float16": torch.float16,
@@ -18,7 +18,6 @@ from typing import Any
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -30,6 +29,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_pi0fast_pre_post_processors(
+8 -3
View File
@@ -18,7 +18,7 @@ import os
from importlib.resources import files
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TypeVar
from typing import TypedDict, TypeVar
import packaging
import safetensors
@@ -27,6 +27,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@@ -36,6 +37,10 @@ from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="PreTrainedPolicy")
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
Base class for policy models.
@@ -181,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk
@@ -190,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
@@ -19,8 +19,8 @@ from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.optim.optimizers import MultiAdamConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def is_image_feature(key: str) -> bool:
@@ -139,8 +139,6 @@ class SACConfig(PreTrainedConfig):
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
@@ -225,7 +223,7 @@ class SACConfig(PreTrainedConfig):
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
)
if "action" not in self.output_features:
if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
+9 -20
View File
@@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature
from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE
DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
@@ -50,7 +51,7 @@ class SACPolicy(
self.config = config
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features["action"].shape[0]
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self._init_actor(continuous_action_dim)
@@ -157,7 +158,7 @@ class SACPolicy(
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch["action"]
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
@@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module):
)
def _init_state_layers(self) -> None:
self.has_env = "observation.environment_state" in self.config.input_features
self.has_state = "observation.state" in self.config.input_features
self.has_env = OBS_ENV_STATE in self.config.input_features
self.has_state = OBS_STATE in self.config.input_features
if self.has_env:
dim = self.config.input_features["observation.environment_state"].shape[0]
dim = self.config.input_features[OBS_ENV_STATE].shape[0]
self.env_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
nn.Tanh(),
)
if self.has_state:
dim = self.config.input_features["observation.state"].shape[0]
dim = self.config.input_features[OBS_STATE].shape[0]
self.state_encoder = nn.Sequential(
nn.Linear(dim, self.config.latent_dim),
nn.LayerNorm(self.config.latent_dim),
@@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module):
cache = self.get_cached_image_features(obs)
parts.append(self._encode_images(cache, detach))
if self.has_env:
parts.append(self.env_encoder(obs["observation.environment_state"]))
parts.append(self.env_encoder(obs[OBS_ENV_STATE]))
if self.has_state:
parts.append(self.state_encoder(obs["observation.state"]))
parts.append(self.state_encoder(obs[OBS_STATE]))
if parts:
return torch.cat(parts, dim=-1)
@@ -1060,15 +1061,3 @@ class TanhMultivariateNormalDiag(TransformedDistribution):
x = transform(x)
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params
+1 -1
View File
@@ -19,7 +19,6 @@ from typing import Any
import torch
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
@@ -31,6 +30,7 @@ from lerobot.processor import (
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_sac_pre_post_processors(
@@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig
from lerobot.optim.schedulers import LRSchedulerConfig
from lerobot.utils.constants import OBS_IMAGE
@PreTrainedConfig.register_subclass(name="reward_classifier")
@@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate feature configurations."""
has_image = any(key.startswith("observation.image") for key in self.input_features)
has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features)
if not has_image:
raise ValueError(
"You must provide an image observation (key starting with 'observation.image') in the input features"
@@ -19,9 +19,9 @@ import logging
import torch
from torch import Tensor, nn
from lerobot.constants import OBS_IMAGE, REWARD
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.utils.constants import OBS_IMAGE, REWARD
class ClassifierOutput:

Some files were not shown because too many files have changed in this diff Show More