From a5e0aae13a3efd0080ac7ab6b461980d644014ab Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:08:32 +0200 Subject: [PATCH 001/158] Fixes `@torch.no_grad()` usage (#1455) * fix: decorator calls with parentheses * fix no grad for normalize too Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- src/lerobot/policies/act/modeling_act.py | 4 ++-- src/lerobot/policies/diffusion/modeling_diffusion.py | 4 ++-- src/lerobot/policies/normalize.py | 4 ++-- src/lerobot/policies/pi0/modeling_pi0.py | 4 ++-- src/lerobot/policies/pi0fast/modeling_pi0fast.py | 4 ++-- src/lerobot/policies/sac/modeling_sac.py | 2 +- src/lerobot/policies/smolvla/modeling_smolvla.py | 3 ++- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 2 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 4 ++-- 9 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index ed911e9be..f66c8ae82 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -107,7 +107,7 @@ class ACTPolicy(PreTrainedPolicy): else: self._action_queue = deque([], maxlen=self.config.n_action_steps) - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -132,7 +132,7 @@ class ACTPolicy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" self.eval() diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index af40f7a86..6dad8fb89 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy): if self.config.env_state_feature: self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" # stack n latest observations from the queue @@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy): return actions - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py index 9cc94b929..119055873 100644 --- a/src/lerobot/policies/normalize.py +++ b/src/lerobot/policies/normalize.py @@ -149,7 +149,7 @@ class Normalize(nn.Module): setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad + @torch.no_grad() def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # TODO: Remove this shallow copy batch = dict(batch) # shallow copy avoids mutating the input batch @@ -224,7 +224,7 @@ class Unnormalize(nn.Module): setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad + @torch.no_grad() def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 241509d0b..badfb4b8c 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -260,12 +260,12 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" raise NotImplementedError("Currently not implemented for PI0") - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index d3e576d1c..0e53bd349 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -192,12 +192,12 @@ class PI0FASTPolicy(PreTrainedPolicy): actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) return actions - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" raise NotImplementedError("Currently not implemented for PI0FAST") - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index 54ea122a8..93cfe6c93 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -76,7 +76,7 @@ class SACPolicy( """Reset the policy""" pass - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!") diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 11bb8bf52..a31e1b078 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -413,6 +413,7 @@ class SmolVLAPolicy(PreTrainedPolicy): return batch + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: self.eval() @@ -422,7 +423,7 @@ class SmolVLAPolicy(PreTrainedPolicy): actions = self._get_action_chunk(batch, noise) return actions - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 8b70b265d..c27689387 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -110,7 +110,7 @@ class TDMPCPolicy(PreTrainedPolicy): # CEM for the next step. self._prev_mean: torch.Tensor | None = None - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index c045ccbd2..59c820a96 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -124,14 +124,14 @@ class VQBeTPolicy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.action_chunk_size), } - @torch.no_grad + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions - @torch.no_grad + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. From 039de254ea5f69ef7e7c39256cd0202e26814cf2 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Tue, 8 Jul 2025 15:47:11 +0200 Subject: [PATCH 002/158] Add Hope Jr (#935) * Fix imports * Add feetech write tests * Nit * Add autoclosing fixture * Assert ping stub called * Add CalibrationMode * Add Motor in dxl robots * Simplify split_int_bytes * Rename read/write -> sync_read/write, refactor, add write * Rename tests * Refactor dxl tests by functionality * Add dxl write test * Refactor _is_comm_success * Refactor feetech tests by functionality * Add feetech write test * Simplify _is_comm_success & _is_error * Move mock_serial patch to dedicated file * Remove test skips & fix docstrings * Nit * Add dxl operating modes * Add is_connected in robots and teleops * Update Koch * Add feetech operating modes * Caps dxl OperatingMode * Update ensure_safe_goal_position * Update so100 * Privatize methods & renames * Fix dict * Add _configure_motors & move ping methods * Return models (str) with pings * Implement feetech broadcast ping * Add raw_values option * Rename idx -> id_ * Improve errors * Fix feetech ping tests * Ensure motors exist at connection time * Update tests * Add test_motors_bus * Move DriveMode & TorqueMode * Update Koch imports * Update so100 imports * Fix visualize_motors_bus * Fix imports * Add calibration * Rename idx -> id_ * Rename idx -> id_ * (WIP) _async_read * Add new calibration method for robot refactor (#896) Co-authored-by: Simon Alibert * Remove deprecated scripts * Rename CalibrationMode -> MotorNormMode * Fix calibration functions * Remove todo * Add scan_port utility * Add calibration utilities * Move encoding functions to encoding_utils * Add test_encoding_utils * Rename test * Add more calibration utilities * Format baudrate tables * Implement SO-100 leader calibration * Implement SO-100 follower calibration * Implement Koch calibration * Add test_scan_port (TODO) * Fix calibration * Hack feetech firmware bug * Update tests * Update Koch & SO-100 * Improve format * Rename SO-100 classes * Rename Koch classes * Add calibration tests * Remove old calibration tests * Revert feetech hack and monkeypatch instead * Simplify motors mocks * Add is_calibrated test * Update viperx & widowx * Rename viperx & widowx * Remove old calibration * feat(teleop): thread-safe keyboard teleop implementation (#869) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * Add support for feetech scs series + various fixes * Update dynamixel with motors bus & tables changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (WIP) Add Hope Jr * Rename arm -> hand * (WIP) Add homonculus arm & glove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Feetech protocol version * Implement read * Use constants from sdks * (nit) move write * Fix broadcast ping type hint * Add protocol 1 broadcast ping * Refactor & add _serialize_data * Add feetech sm8512bl * Make feetech broadcast ping faster in protocol 1 * Cleanup * Add support for feetech protocol 1 to _split_into_byte_chunks * Fix unormalize * Remove test_motors_bus fixtures * Add more segmented tests (base motor bus & feetech), add feetech protocol 1 support * Add more segmented tests (dynamixel) * Refactor tests * Add handshake, fix feetech _read_firmware_version * Fix tests * Motors config & disconnect fixes * Add torque_disabled context * Update branch & fix pre-commit errors * Fix hand & glove readings * Update feetech tables * Move read/write_calibration implementations * Add setup_motor * Fix calibration msg display * Fix setup_motor & add it to robots * Fix _find_single_motor * Remove deprecated configure_motor * Remove deprecated dynamixel_calibration * Remove names * Remove deprecated import * refactor/lekiwi robot (#863) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * fix(teleoperators): use property is_connected (#1075) * Remove deprecated manipulator * Update robot features & naming * Update teleop features & naming * Add make_teleoperator_from_config * Rename find_port * Fix config parsing * Remove app script * Add setup_motors * Add teleoperate * Add record * Add replay * Fix test_datasets * Add mock robot & teleop * Add new test_control_robot * Add test_record_and_resume * Remove deprecated scripts & tests * Add calibrate * Add docstrings * Fix tests (no-extras install) * Add SO101 * Remove pynput from optional deps * Rename example 7 * Remove unecessary id * Add MotorsBus docstrings * Rename arm -> bus * Remove Moss arm * Fix setup_motors & calibrate configs * Fix test_calibrate * Add copyrights * Update hand & arm * Update homonculus hand & arm * Fix dxl _find_single_motor * Update glove * Add setup_motors for lekiwi * Fix glove calibration * Complete docstring * Add check for same min and max during calibration * Move MockMotorsBus * Add so100_follower tests * (WIP) add calibration gui * Fix test * Add setup_motors * Update calibration gui * Remove old .cache folder * Replace deprecated abc.abstractproperty * Fix feetech protocol 1 configure * Cleanup gui & add copyrights * Anatomically precise joint names * (WIP) Add glove to hand joints translation * Move make_robot_config * Add drive_mode & norm_mode in glove calibration * Fix joints translation * Fix normalization drive_mode * nit * Fix glove to hand conversion * Adapt feetech calibration * Remove pygame prompt * Implement arm calibration (hacks) * Better MotorsBus error messages * Update feetech read_calibration * Fix feetech test_is_calibrated * Cleanup glove * (WIP) Update arm * Add changes from #1117 * refactor(cameras): cameras implementations + tests improvements (#1108) Co-authored-by: Simon Alibert Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Fix arm joints order * Add timeout/event logic * Fix arm & glove * Fix predict_action from record * fix(cameras): update docstring + handle sn when starts with 0 + update timeouts to more reasonable value (#1154) * fix(scripts): parser instead of draccus in record + add __get_path_fields__() to RecordConfig (#1155) * Left/Right sides + other fixes * Arm fixes and add config * More hacks * Add control scripts * Fix merge errors * push changes to calibration, teleop and docs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move readme to docs * update readme Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * Add files via upload Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> * Update image sources * Symlink doc * Compress image * Move image * Update docs link * fix docs * simplify teleop scripts * fix variable names * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address code review * add EMA to glove * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * integrate teleoperation for hand * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update docs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * import hopejr/homunculus in teleoperate * update docs for teleoperate, record, replay, train and inference * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * chore(hopejr): address comments * chore(hopejr): address coments 2 * chore(docs): update teleoperation instructions for the hand/glove * fix(hopejr): calibration int + update docs --------- Signed-off-by: Martino Russi <77496684+nepyope@users.noreply.github.com> Signed-off-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: nepyope Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com> Co-authored-by: Steven Palma --- README.md | 23 + docs/source/_toctree.yml | 2 + docs/source/hope_jr.mdx | 1 + media/hope_jr/hopejr.png | Bin 0 -> 73277 bytes pyproject.toml | 1 + src/lerobot/calibrate.py | 2 + src/lerobot/motors/calibration_gui.py | 401 ++++++++++++++++++ src/lerobot/motors/dynamixel/dynamixel.py | 9 +- src/lerobot/motors/feetech/feetech.py | 15 +- src/lerobot/motors/feetech/tables.py | 2 +- src/lerobot/motors/motors_bus.py | 11 +- src/lerobot/record.py | 2 + src/lerobot/replay.py | 1 + src/lerobot/robots/hope_jr/__init__.py | 3 + src/lerobot/robots/hope_jr/config_hope_jr.py | 51 +++ src/lerobot/robots/hope_jr/hope_jr.mdx | 268 ++++++++++++ src/lerobot/robots/hope_jr/hope_jr_arm.py | 176 ++++++++ src/lerobot/robots/hope_jr/hope_jr_hand.py | 200 +++++++++ src/lerobot/robots/utils.py | 8 + src/lerobot/teleoperate.py | 2 + .../teleoperators/homunculus/__init__.py | 4 + .../homunculus/config_homunculus.py | 38 ++ .../homunculus/homunculus_arm.py | 310 ++++++++++++++ .../homunculus/homunculus_glove.py | 338 +++++++++++++++ .../homunculus/joints_translation.py | 63 +++ src/lerobot/teleoperators/utils.py | 8 + 26 files changed, 1922 insertions(+), 17 deletions(-) create mode 120000 docs/source/hope_jr.mdx create mode 100644 media/hope_jr/hopejr.png create mode 100644 src/lerobot/motors/calibration_gui.py create mode 100644 src/lerobot/robots/hope_jr/__init__.py create mode 100644 src/lerobot/robots/hope_jr/config_hope_jr.py create mode 100644 src/lerobot/robots/hope_jr/hope_jr.mdx create mode 100644 src/lerobot/robots/hope_jr/hope_jr_arm.py create mode 100644 src/lerobot/robots/hope_jr/hope_jr_hand.py create mode 100644 src/lerobot/teleoperators/homunculus/__init__.py create mode 100644 src/lerobot/teleoperators/homunculus/config_homunculus.py create mode 100644 src/lerobot/teleoperators/homunculus/homunculus_arm.py create mode 100644 src/lerobot/teleoperators/homunculus/homunculus_glove.py create mode 100644 src/lerobot/teleoperators/homunculus/joints_translation.py diff --git a/README.md b/README.md index 153a3a215..ff7a92384 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,29 @@ +

+

+ Build Your Own HopeJR Robot!

+

+ +
+ HopeJR robot + +

Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!

+

Control it with exoskeletons and gloves for precise hand movements.

+

Perfect for advanced manipulation tasks! 🤖

+ +

+ See the full HopeJR tutorial here.

+
+ +
+

Build Your Own SO-101 Robot!

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index ea80e8257..83777a3c8 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -23,6 +23,8 @@ title: Finetune SmolVLA title: "Policies" - sections: + - local: hope_jr + title: Hope Jr - local: so101 title: SO-101 - local: so100 diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx new file mode 120000 index 000000000..402422634 --- /dev/null +++ b/docs/source/hope_jr.mdx @@ -0,0 +1 @@ +../../src/lerobot/robots/hope_jr/hope_jr.mdx \ No newline at end of file diff --git a/media/hope_jr/hopejr.png b/media/hope_jr/hopejr.png new file mode 100644 index 0000000000000000000000000000000000000000..4186547a25db052d6bedf07ca839889f2c5c055f GIT binary patch literal 73277 zcmV)BK*PU@P)Au$g!Px(*!t0{E<(szPztsD}+W(fa-LA#$(dPf9zUS5H|Gd)qk*?ar+W)M> z?8n~!ywm)<)c?ZQ{l?q>y3qHZy5z9O@5JvMY=U9QUZ2?z*qTvbm$IblypabQ|$R!?0_MMpU^3=9sQmYXyxC0R*Ahi_$N zQ%jPt+&C>MJuxg05ECmSAR`+Xf@@(gCL*1amTg&4b!%{VW?)i7JsuYnCm$Rb6A=~` z8dXL>d2x4YT2*#wX?|y3jDUfgj*oL?WQ=oYLNzb9()^;Cof8fWmbBu2cYShgb(MQ@ zhI@UdpQ2(}SzT08k#}tW+si*TG_TD1d2ewW9wc>OSA}$OZ)tCAVq zcyUBKI+c-@Y-MVWgolrPc{VUGmx6n0U0nX-(Y3F!WnEyGiHLu0YL|q^PBefR~Jo;JmEEy}_@kuD!Rs zsKoK(#<|zDr7bEe%f`$hCo%KX#s1a9jHb@)&A-*q*3Ye-|M107 zc|}uaJxE&r_~SxAMAN>w!LX@me49{ScIM;jmYsMDY}MWLw!PA9b!GG8;ggMfZ>MU$$(+B* z;mXqLl*y5ww25Z#gXZwUrr)K8ta{YYtV3Bldck^Ey=u(tyJVGJtG}f;##o!|%va!X zvH7ou=af!&Nc-f)Gw5S~?617ixpdosF_lKJ}My{FJ-b{*gA`ioftD3nSYJPL`Iif7l-kOz`L;lXcL#`mipW!FGAuZ>^@E zL?1BBM<~bKr)?klkpX$~(Kp}qEub{ULBTMuDrmR+x2n(<)7?hg6B|F)SoCYIwV-D} zR`S$iT>J)uARO|D*!*U2H>(O=Rvzd_{qQgYA|&0N42Zg%xujYeBmJN^gec+zO~J*v zYzBYM{H#|V*D{wC7{L$=y_5`?IzrZ4bmo*M%=Cve%HfNKT0K$@r+>!!fPVS4m*p1E zGPK~EVu@Z~68QhhF!LKkouk^~pG?&q&iIU$^ z({zrId7h+(maF107zPP%0c|t5E#=7b>{b1QOSVP^WR8GG8ci#45HtDw!S%aV!#Jqj z6TA9cwt~;`eh_u*mqg1tgEjy>)7|8|P%c9%KZey$LdWggWneQ&aI>rm51H&1PSNky zAF|QyJzFb~_6S+*{{zGunA9T?W)&2BD#gen=38gAtsfPs0@Abt84#|Q&jaQ`;Y23Y zBcaV5PFuGs#fZw=F`xb{0yJ*S;)t$A{g9b0AUDBW(G3G4q9OGNP>hu8xjgGlHvLJ` z3k{kWdgv|#nqkWjxrrqV5ioZ+?H0#!HFBTnxpogK56BiGJZgnz6-)+95DkIQb21o( zL_x5w{PogZmZU#GKYsSOfNULO^7zofv}W>#WXpUQ7)eP(vDlC3;Y02p^_ zU{Y*6etr4#C0<@$$K(E`lmx#*ZSNE#`@ig*Imi=17=~qGBCIUPcBvLS1sfICqZA7p zI~$7#b{2|ar$SrDu^ybyf3`PLh!%~5fQH+6@J_~)_2*N%hk?;hT7zHc^}FjA`ZY%)I%8G`>PNuFhqlq?|`h*1^iuxqwqDA7Fa2tN{^FgCCE!c2JUl8!%&q6d z)~_do@=UH&c~$Otgt(DNLKcD2{F!ld5+GK~0MS<7sR6cvO#Y}CF}0q}?uV$alAcF^ zo=IRZTMi#!jD6c70z-IUd|vF2j)(QFXS4l@xnCrLyo&HVfu5W)C`|fZPaMgUDFhfo zFhp=z;_eTj(FFGl`|)s(*0T@%`X-MJATKEZm_?vPG7@n2;i2o>2>9z!sd)YQ83sF* zj{VW`a8~OX_Qzj+JKYfUIhJdX#L@sD5Y#hBtfM=Y&R&r+K!@lctd;>Gc*A=E9}9Gi z>0y3Wb@v{<{eijuNLrso+ze%G27#Y|k*EOP4e07EOCP(~4Bp*v^2h_%(y6!+!t&Jb z&&k;z9S>);o^3%$ar``!^?|-@5u+^e$1VlJmq9@5vt}Ny#LLc0+L$x%qJpZ?2q)&x z#Be~#puQt>W(bRs;|BXH2R4)#i8#6{4sriUg#ED#LSIIvUmmOe7m40n^*FxgbZN8p zY{8a$ut1#=rXkE%rPxUT+1nUIfAp*>YdzbB;)6}ON%S8Bs3hS0KF9(asaSVBJT%3KvGol63T0l%L(Nf0y&W*V z(jG{z1;M6M8!5eUBY1t$8NbE&Yet%O{`Nu?CZ!pnopldi;nbL*JuJt=JzLM9FOR%P zY*ek!L%XGz%nnG@!?~A~UYOw=rNOSZ_rG-7KQ2e`1s7RH7ZKRQ#G@)|juHrTwfxCJ*#(mAG+1=+>Mr{Bekbg-JIn&FKDt}0 z?btN-;`!$vSA!iEfG zNkIlBhxl0;U$ugZgIkNB^{nr`6iB_Ez)wol9A>U`WP(?NjL*i9ln0V)%{BYp{|Y&P z;Sc9!lztmDVEi#l>_~P~B{VTG2m}iVXCoRDJUyrlnPp|{{=67EjvXQ%@JS?I6(zo; zQU~_f^j84(%)=H~6e8Z0+KIPcI-I}uk7)gJtww9Xt2Zsa(h<8rG2qYqc%fG?6Qdy3 ze|>ZD!NIJC_wHF-M)sB1d3Y1Mo}8-=;i$2f-s2`|AryD=TA-DTVP%z5`%e7)^!isT zY(;+O5bM^yHS5OS@s!-^gyY&0THn@&oT8bSU%fjri}ws-DY9GJe6jQ9U1^2!_=ms< zKaj>fO++FPHpfSK6Q3*!pr!S|ms`6}_wWA}Mv96Mm)xoEJn`|@7X~1IFF(ykKnD{n zJrm<;RW|915qsO zN==Ce1#ew=(Te!v8F%y3`nq#J8&}`77?5p;BniZS@enxbRPhS_uq7B&3H1FlgH?)5 zy=eP}$NuuvyTFi557V_s#<1{q5>%^s*}%vgL7qwLFc?}wd`RUItl zZoTDhl$O}$7ASpF1Oh^2a+4fyV!@}T#UNgOsv~%Qkb2T{PcQa!@BTbxA?J${n4t7~@P} z_e>GP%eo!-D?0j8*f8c|)P++{A>fw{oReu-q(S_GNFYPd*%yBP>h<^BIXl~}|L}Wb z6UNRVMicxs6I+u%zWeTXd#dFYV*FV=;eSj917Z^RA%ymd%%ZND(IklSw>0sAWf%Gp z{D~95uafIq_)I2!f@#z=)?y;X&l$iz-Tdh4dxk_B2*mA5DC6I$ThS9|1V021=DlkuMr17tkm3P04`mk~X<-(|4^esiZ-yDBTUBtI zPRumdJrH`VQdDI66hl+eU+x>?F7TyPJfS@E1sr-I5DYsXO5dZiIYFBnwC&4HXEA(s zzWkMnU%<;@#{%d=;hH=-~{N=)r>+O8Y3BkbMYmb_-blXiMp-REj zsGd~fUk;1wAS=|A&KDbzHPqDdXM$7;rYE5?7+|BFz?rcx=g@r+@U-2AS|cC7V400W z=j;SOTyCr%{Cv5d9N6?2rv;R+0zr4(1~_H*wvyEK6W+<>5GZc{B(tkw94@si6TqE0 zdm=a!_T^S} zzMw<;H&lkQIwWlx$8X5EHNS=sNCynbyVTx&%mD>80LN^}6N3dRF|U%_-d;=oIMF$TE|3zZnZ}n~hlGd@(bR!Aha+w`phzE2 zQ=$_@n`!%U(0$-}^H<{dN=yC_Hvu3S=^)l%VlM$BoxE9rAiSO2_LfZU*r?q$ObJBg z8&}hpHwL-T1s!6goLEGu>{s+)+!M8x1ZZ6%N;eIVRs;5=$n(mCm1+qa( zAj0Q84r4KzLpaU%<*@nQoj)wTpgFoie|53!MX~@1ywHMTAt+{k08<`*;)ou4Adr)` z=UHlh{RKMYU<8mTUrmjlJ1h#>+euc~{$h2BPl-bU>8HdvjQ9}%r}w@bG~dhk!>S9N zv9ENJqAx;(Ue3XC;l#`M#D55Xhmlw;g3JLj{>HR*8$pKz*&?Bg=nz9$a#=R3qdH{4 z%0nf{IS?2sAVgp+r6|a_;|J4vU(TWVzE$e}F!C0p&@YL;=&b033+V9B5IKHizXrd= z@8BGn#ORy{V9)|!ka@tZ$*sG``-g&bnWF?cL^19fVnLVdkp4FV_yc%fo~K9D2>hx` z7bC!+KvA|Yx6jnIWzO=qq9A@~jx|6yZ>)|)qS0t`bEKfg$Oc|&**Q{*p+7`dX@%Uu z7&BG^DS#WtCUzmt{N(>mSskqQR}6ReuG#F*6biVgBJ2Uxo1w=;!wiwkh0?-& z72@CX%RhU7z}31Q108}_P#t2LH*L;v#v;%qqsEku$17k3ZzZKKCkm_C$Ubs4;$4g& zfYN=rDf;l3RkeX)`2)kKD<%OQJOe7w)dkSKv-agxPd&Zt?2|7&T(#QQu1brrP$-O4o|tbG}v zTYUC0bHxvzb;?PHEI9u7+XaCf29clf*W_6cXZ&S)fZ{}2kM1*LJHOvL_}D-H5dUT< z$t(OAUny~h(z#)HB*gd{RnpJLZ8%$SiAD&#UiD(vIlSO`_4J)2*7Xt{KMv75yCVKq1(&oA>YwRw7IFF zp{52tFgV%8@ZJLa#gy?KgV6}3%}g*I0w)|m&>;ymw*30b3CWNXD!SUaU`hOt?u_>y z2$yHJB7?4RI3Os5$(_J=_L)+4cA@Tr3Mc8y18>)wRf}!lF~7&{(o;`ba1J;>O5b7E zfemGXIWh77vnY@U(1mA#`FuY1fnylz!dbhR4%tKgr9s1P-&=<|>2zofz)K#p+{5TXMVv%46{qDPOUTf>W{R5j)p2mdE+ zdhs^`ebP)8+~29EowNYx1N;bmH-ar3j$+_NM>tuSSx!SRs*l_#b22*{k_i>W$PQqMty|HQ9Yq*=Mfq&ZfI(5j%K2x zs;VS!O6V7C&Jd_ROli7u;|GTp--AO&cONLCaj{|q5JLdn!fbJbOE)zt;Dmq9^+Y znF~}whb%ePS4YA%1@`q+!c;nYOpO4T`6Q7tlPO36izCqTG($`EyWpceJQM|yJy!y| zrpmtX%dM@R`189buu%9r?jpk9@h1R%0za^p!D!KuH%i`Oit?qxhU8T-_Sg#!*a-OX z5R7yZEa7Fbi?^7uSJ%I%$5z}?L%*tfT}h~k+XnHv`cWmGO5FG4!x0*FA55%HBvNfJ zfX4D3yaw;?>M3XG$bomlMU6LEJwK_|zMtl0>BTQN1mSf1Fz6iQvas_V+Y0}Y%& zSp8~s!blBV*5nH}z#e1hlh=t3S+KG>=+W{TS6$jw-JKxbr0GuJqVB2Co%M8=jlY1L zzs9a*?&%-!cgW#TfB^}7!gYurv5{!81V)|Q9)I>M>7C9=3XkjGzzGyz;yjQk+{9ZyOyP=m?E@4}YbDx*chhi@40 zV;!mM^Ko6l9UvtTO=Np@EC<9@IIeTZ0_&m=D#{<@WASx;lsLF{)HJ9L5d?16Q+gtm zysVl34IwMkOy$TAu;n7+i9op&DYQA)C~<;C`oNWCXiSh704sn$X#oDft)z3tnaCdn z@VYF27ZTfaXu&ut4mg5{AGHWy;rOwvkGnDvQo#XMA%FZk%mo2K#ws)Zg`L2=6MyWQ z`0cNW^h04E;g7gL!&Lf@g+qcLA@J=~+*Epo06>O3(IMJ3T~SW}ay24-a%-hzbI$LC zpN+o?If3t#X-fnPfZiT_>UaB3=29wukv~!eT-d_-qju9^)KPD62pM3MH{{QX9=bm2 z@{qEss(OY4eIaESiY@H%vpgO0A^e}Akn2BT ze*i!<6*ttDw;H09JkhP~QHyxKb7}Uz&QTbv(XKua97>rC3xIa{TOmmR7t4PTLeFjsPAhOjZGRz;j3jA)!f zE`CLUnvYC~*u$_7_#^h)j2mIgc|?op<5cn!kpE5Vs-K|AddYFp%i19M{L-i ze_hkm9KoBXqzLjpLYF2j{>**36ba;r5O(5DU~QS6nP4vfwo1h>qW30cPiYv;}q5U|3b5dkW_+~$xa|<_+XpG?)4kZqdnxsJn zU7$pY1aLnX6|qZ;YM@M8BFnm3?fO?vU}LwWY9!?kTp%N`r;YH}etsm1A=hqM9B~C* zADYj{T@S*>JRqTvtz#M}I76$)6{`y*hrh5}lC-f2BmK3I8K6UySZv_DKEe9_M2Ez$ z{e%z*=KCt88On{HB7m^&Yp^7QFh~Qk2PGok7Rd)rMqpVA(u=znd3yNJzKyj{7p8A6 zq<{MJvKQ%3oV&5AX-}8PAb8Z|LpN1F}Cuk|q<@&G^g?N4+_{OFPWMGnBg ztbtIRn-I#)O^UdSVc`PZ2`pUJXGv{Ro`198Wc% z2TQwtuu24bp+k7+2OW}7J6<4vS_BX*jJjCn0!5303xUpJIAASAlM(eu0U*oSqkqgA zkNyu^5FGe}@SEii128&3uzBsodOABUY~e1@VhEz$nj2;@ZGq&$i!eqxs}dd%L$Jjd zb6E029gG*`Fet!*u zKm>4h?J~!04-mf)eu?}*|L5QUNk0Msu%olHqvgi-j*bp+fi}N~2Z$d_{p#^wCM&_v z14boG;*}+kVrh4QNMdnDe*>hiVh&K@OCqeHD*dN7AzPqBOrk^j_RG3JO;_fBFo&I0 zk!7@S6W;Pnj81X{^B*fn@GIgOGRPnMxEEuJ>CrDo0g&B#S<#zcD5xPP~>djbksp7W3@O)X+>QM;;l&FGB_kiLTfT)-b zI;4;O=r&U|Qi&(w1brK%eY(2j0dYEK)J&?{&BKg91TZfWWf}n!E;u8^-hu{+xHP|qobMGDg znYlBws94VKPT7a%89x7W|8vj1BcoPCWl34dq4bp_f4TtfUdsMA)et$NwCW%t@HLh| zFZ?nsnKrmcU(gUjrGr-33^;+LjYY=Fv1W}hrT3gqVu?`lTJhCrG#pEWlRK>gx}*<@ zU+|89UBiboM#INGovD5V1llq&0;yQDA}Z3E5kR<=1A)lHeEwe)q?cL%#NI<0Js9pu z{;I31sp+bk?s)iT@Mi%K{0RW11YR0xxb)bT z{7Fv?_Mm_^9kkSAbQN@&8P0XlK)~lkpm#y~UoQX}+r8LY5P)JW@khfy#a}-7!z3^S zkfM|UOf{sAIJJl*Fy$nHa2_zc>J)?m8X{<9VM{3*VbwhYLeR%oZVHGX>k)1`WK#i; z^dZ-6E&!CVF8h$gY3~K-8~%C~`H*nK16%VAg_eZ1OpH`)iZqMnfS*+Xl@-At5N5=~ z3;)YK`n%n?RQ`;}&2rJ^0w?~I09I)NSgLYx!=<|yc}Sq8=nTm8U!yF=XoVR0htfLh z#TJYcO*-cZAl0AQhw#rCKIA=y%ZQv$$cHqY@SNWOl6K3A6}=Lmdl^pF9hum?*QP6j zOMA{zHocV{Hm9yY^#o6jXPzmHC9{`U!@ELg+2x>HgK}P_o6)+Z#4oIWB&I{~A zE*vr7BH5EZB=Wuhj)CRfvRcl4`~_ zz{beL4mCk{AA-;b_1Aj=s?f|m;y_9nB zNmYe3D;F^&5JNx4&rdH9s5~$;XOTdF#^jkq5IAYIKpik$2=zZ@9}-I3(?TN1dPowO zh`f>Gg9vWf8jsQvxlH2=D#aw~L&>t&@?;uzrQI?Z1b|pOoYn~IC25Xar-x|*!~e2J zzb9P*1`w(pU~;joHP-xAOa3Ype`)||aUgmI=AJ-G!E8fLEvzb~9DKvp*0PJAe^-z| zKXNd>GDpxz!p6tuVAFaG&#@+gl5lgxOB{0<3MBB@p`t-== zl2ff7h%ylk2C-R=R*tcU>#B9&XC+YcApcd5K9*p5>K{{QoCSWjK7Pk;wzNw{zY_YN zgob~TKQbUo6+r5TkrKE&{o$plRNHe~oADDNeqTPRfC9sjpAooCr0ZntH|Eh~RIt!4~M(vWKlGt?ykV zgiuE?V|_#WjI%d}YkBtQ|BD4cY&*&P-D%r@s9=x$m33t$m9-}Ryb`byK;?lMK$gH% zq~RditL@lGDn;{g%6SO%3;_K>f6!cVfebR37#6Anig`A^{yKpcZphe%Zf}W)k^`eY zzLP#AAOZ97iviRu$!!SJ&1adZ?K{lLNL`U@ZevNE>oQVPVy& z)R-`^+15h+d{X@UMlfhr%d1Q;tOuw>DNx7jPw@w;;c@xH0EVK;1lOZ&KE$u$g-}yC zNJAP_Ft4j@Q$B=DapNEdq) z0I|hA&e<=TcKxF`v?hO-lXh8AQ%jQ|ia$93)CCYDKpntTipVP@0;w?}LdEEFu$VA z0Ujblj0G?UTv@WRsBq;T)eiij1n4e+3cw>$9EDX?MWvBQ>ZmO?4@&1@L=Xg;lduBL z7!|^XEDE4D29((u@uvZV%Mg8m!{==vm}DOkXCLB6_^1treaHyqTy`O$P&nMw{(D}I zp94hO83dG>Di7-jrrl0et42!(7h44NhJ>{Jsp$RumkNN$yol__7ugl_zcpps@341S zbtU0v@VDrImnMJ=AW0y_Nkv6PCq*bmByTB>=vQTe%%cFS)&EYt7%JVrZvkVc;V^#c7V5h=CZNd?tC z@-UL{(h;-^Abli(8O$Bn^W3w8f5=3~^7kb?l%@0M{jkg24No-B`&QS`hHK>-e~toJ zHTD=F&u6k9DB~UYJO^JD!%uM{PN!xq(Ul>WlzHz zk(sI}sApmNEE&OU2LjQb>d+d4yMHVR-{n+*!F;bqUq<$ZkM|T{2*?ds6b01=f2IU% zbs%N|bO$i99|2exX=sStn8Uf3spB)!Fkhh<=odgq82xqz)#4CTqRnWc@jw7*K~MoG zDFBpCA$`aV()#B3R)wVxky%)9_(?b({iHR~K5)jSFU`h6t&#G& z=RE?{Gcf1i2sP&KuyWf6=Ye3*&11k4JLnuZ@BomI zi-lvuA_=qNu@7lHf0XzVK}d$e@#{ugL(!)3cNcyLhr@{-%gev|X^sFy*d&lO(W>R# zs%1RM$!oplMp~1!8-lr#Kcrj?1;M07;BqJ7Cv5x?bXwCD@VoPgTNHg2Q!JxBEt#Zpor!+Rp^|JKqOs!4FZ86 zm~ak1*H#j=HOB^g!@RY|6XZh-09gS;v9=q@eN0R~H}^IH*qS`<%HKM>G8tC@C<37C zr9ovz+?UHSg6aDMvr&zakO4Z)5NjHUB>d0JfMkI_9l(rRdx{PG6mWZ#P*hY>SopI! z5Jv%20PYU{7(kZ5)Ipavh`83r5=fyFh-!=wwWRHwWv9ZYn8%_(->4 z(Bd};WC;v~l5a=={p>@+@$nI}1cu`De3AK?FM=<>%@P>8?3k~9`{`%S&;Srq!;3lD zp>2N)M*tD)Y2|1S1anm_8Z|(t0iaG`*0Lqx0&4|2OX!}P{@r;Cj7K%uUBZtuE=2|S zODljL8JL?=Sprkk8A!o4gntue0KuPA1l7`c zXi%tphzLgV87}_CCcfAp=e939rij4C{L{bstxNqW2;=~O4g_+SjPrs`EE!G*AdCoR z<*;58@Zqcxxa+cYr#@uiZ&v{xoUcv)PE`0&P`@eRCh@DID68Q5i}L^w0;o>|!-8;# zg(1uyP05M4F`%Y^0iTjV79=HrNW;7>yjkNNE}^G7*&4~3r~u8#E3!XLV@{)_;!1hN5E z01jmuzZ$auash^g(Zzfy1|fn-xq#ARK3G~8d4@lBVb|nC62l)ePMlVG-)Qn7Odlf| zs}If3&rDB$`s+d)ywDu{arIfzCx(3aWs*N#03B<9ZU9~?v#-#> zSILG1>;TFsVgC3q2n2;pClC~{fj@xf8UPLm;?6C6+!shbWE2vZ2n}CyloG*DsQNcQ z`RVkUuYV5EH>(gxRIyET`+6UfOMIsCN%2wXHKdJ;Bvxu(fEHvpl60MKz# zm_Hy81sQd825^~^aF2z*wXBcY*Qb9AiVsEUg-vh?@FU49*Ap;CfG!q<`O^VJ++h4B ztS-bTkoeJYpbmci=$)hxvvdLCY%iZ9fLM0UKIA;VDnIjnZJ}_odsN)5nwp&t&$Lgk z`Sk0ByquiCON|55t2_F-CO-K3({}=y?>|-jI=3n)0);nr?Nczes?kNzeXX3A0`fwU zEW?TL8Le9XSQ5q}%v~)77_S7(lW&qg0LXiuGHwhUBUdN%4<%^u2LbfXz_Ke^#jE@p zMGiJAAb+E@2Na4kj_MM!x3tjE^GPcC+9Nm#$A@@H67v)IBzAkKH8K8PKxp9HRJb)f zIX%7R`#vr|=RFyXzTMZ+_r>^kKh3Rrw44mwyLpE55D43CiQX?wX2ipY&^rx+=+07$ zqgb!&t%2@sJO#SUN<2a+}e+a+))h$KuW z@d1GnHh-EPQvA*~`SS{hnG!yEOCT>2R;Gk{PLFyp)SOKAkCF@i^6fAYI59oB=KHRN zoV-s$^w`%QNCG=MzpQV*y>WXY@O_d%4(VGA1jU4i(>O4g$s*`wMbg)!L7;t&2+Xj< zs-@-Hwrwv<0=Cl!^m0`TYJubr1Tug$49u7K5pfjxl0PJ1y#%9K5C|XuM0}={0UBuK zPY>NX0>Dj{1|TRj*%EIT;**PnRZ~JvqCVPoeW*U(()%2_jG2xfTB-Ol$pp@=BK#;4 z9esTrGxg2nh+nO!sCnwSfZ|W4V7DZU2R`xiu;xYngca5&&z03!~RZ8wh#>=-!y1*u!!qE&nb{5=H{f@)kfSU@icxtf*xEii?X0KS>}D z16L~i7`Yt@za>7!d?^C9Fl`n?omnq6Q|ohd zVsiRN<8QkvIylslApR799I?mx``f$Yu~_4Zl{@6Wm1pA53ZTzJK4up?%VgR2l-mQ4I>e}550GU7vPQoR1+wZ){9}&2>LkZAJ z05$$nL?2X8*$_w|@-I5qwQ;(W`bh#MgnufmGkUlLf0?5G!agJ%P7ZA9d!~76Haa;u zIlX2sM@_J;`s~3F^QVH-Z~xeM``8CVv3M*}vg6JVY>gH!TP;ur9y>sI3#q464}>ZAK0n zC_Wl}F7gKoQI=UUfF3oN(E~1Q+w+FePr`}8kHgK4RD4;WI&@3`__O-UzTMF=iws;J z8$5sfa>8${d$=(kj}Mpaz2il%fk3&Q#90H)w?Q#2vU%EB;q;9$XiUJ98~~uJ6Tand zCtxk@<}UIF1*B#y29Wtv_!aO=Xkf)2oP@z&Rnd+}z-a|gSHP4_0zVD_It@VZ=OTVO zegZrL!zCk#S|A-Y0Zg5>?{x!LKNb!(J)dZ9rq<`Tx$SRea&k>qmy|$1J>KT+#3Fyu zwxciqh~;m*rD-@y1a4cl0J37`$gAiH8%5$!?Oo&Z7uK_Hs5y!I^m8%em> zF##i(JY-*Wk#!eKALB5E4Nz@v}mR zqBLKMTcC)I$}Jz1rsdmb&4Y&w3=Zsc!1-ss98OSnjm`Y>`In;ed2(`M&FT)Hh&M?B zf8e1Xr{AWw3lIE&`0E~OYHCR~5r9o~Wh+*^o2MzD-kD{WMDbik)TBKD^wvPE)x@qb zT93Y#gqJM==b8Xi=pizC_G0)DhjIu4X-1Bd@D7ZdQoLQjpREL=6`^Gj#ta~z9A-l> z3bX?V{L;%Ge6Vp~K7pcRVB2U%6XF|54PA4?wTGO4`FY_m>D$Z~3pu}Z%nAvdAq`v* z|Id8>O9v(47z&nX=``BA8Of~SAyeqP9-zxZc!9dA!Df6;-$; zi*?>?4!2ItOiWCCG3WQqJwv-u&JceTKfXbUGiRM25Z2{p1<=BD&aD6Rgyn90o>gJph8gL2ML4?L0QZ%BaQlb@~4oK z*csPG=o=)cO**j6xxF)Hr?Y4X*#91{QJeYQ+tNSLVM`-29 zGcx^-jjQzA+uPpWKhV-LF!n)v6KUW+HG8Z$Jm4yUFkzYm$|s(nCJg|cjj)a&2!s(v zcg!+5VLbs`$de}w|G-~1F)rx?dd7O-l1jeVgC0mtK7@Hwz*WHky9i*ZWeA8+HU!fl zksHvWg#CPkAh?U42Y#?6iw;ECm`%3(aHIA3caf2ikxTu#a#NmLIqtkQD|?%(gc@q$ z8@ch(!GY=iEoe{Z*?342_%r?e9&8%=mO7t1CZ@k!^Sf`=CnEj!c7Gt;M=wV|$>2~M zCF8La0hlbSDbIh$4}>^n@YU~d59Xmub?GyEA}P{l>Bk%0=S=505t{V4+lHQL?@S;OefqRU-Wk_V-_Y%+oA?GP=8a`@27C zKWkvWBoTN;O-1?BK}`dJpjG|-w&hBhwAKF>&_8{tCREUKT9~jyKn6jemV}@93(G)| zfHC7s0r(JWKT%Ks0)ZX~tgI<3DJdr(QdM};sRBPbxhG&5l0V=_2lIqnL1}Wf= zeI1#+iN5H3{_*?Hu3479_|VwcNB3WObUX7mu0)UkY>UP(Jhm|!ZOW(a{I|BWcO-y0 z%PsSVrUcHx6tE(mQ*th~7);732%&+?V2<7Za1R3T z?6XTa2^STXmYzx!GI51k0(Ke!s$LjALID+k4gBV~)(`|^_hl5GH33W;qbV?Rj~=T3p@Wtm^$GX;=s-WY zk7JH*7v!DQ-QUmrwO^5>GT`(3y-E_8qSn+$M%{*i+Mv+f`ppU`5Zv5h7jqkcX(#1S zi~)My<>)2;5A2+Pp@5D8NRI=lWdJ3A+!V+G07?DGg&bE_Mm}Vh!h;UluW*Mdp?D|( zx(XoihkzO8SX4lI4e;_ZdqpCB-A@nxskJw}m(uMmd;PtRa)8+L_!omit2+t6b*TA1 z#Ko5x(!jY-Wal&2pA&pB_)-6nV`BruA2EJxfyWuZkIrd~a}4ZGKBRF6YE9)SVF)0Q zv+N|JBS?`c5Tu&4cO}NT8VwCpkCg+1?ke~vlWns0+Zqa3WOU;4Cr zJnb!qbh}KkmjU3x>#Q?38;;!`r6fGV(1qwT+o96$_!zYwbPp##>g^}^dVAYQ0=qAu zc{q9zd{?j!d3BG9inB-ZoCMHsAreg(=2~E;K=2Rh&O9crDhlIc>u#F5??#P*vNJ+) zhJmrN3zUjAT7^0mu#BJ%ph6)!Fu2xbibc_kMjccdwGmARjWxzdM8M(_qj4cNP2AV0 zars02!|&XC-#hc>y*Hw9Dc_qnGhj`O{`}58=bn2Wry_^g_69Dj9LVdxKPyKcy7liL z1F`l!Xn(p9fp_BJmuU4aQ!;tG)Y;~a#= zi$G<8tlP9&w5W$z8sMMOchr6ly!;^G`^VL!0Q`2J73ZG2^p$%@H-GijSAb7yWSHJ3 zo_lWfZYsXaTsSYSk2D|HG7^UFo!mS!jXY%YmUPrqIT2>$z@K4zSQX`swHOtImmaAt z1e2};?MKEekwKVG%4r-X3Buc)f8!X4UFOU5;off>dQWBll?2d<(L@X&3NiI5DS5EK zUJ#g)ziR#ZmFZ|kAmCb!{BZl(g`-X`Yfqr2ta2e4}!fAP8Ki!V+^Zw?cGn@3~! z?%ERWT(@QJvbp-aR~2i_pw)34U4=tb&DXvgz!Jg|(y??mNW)SBN{6JdfHW^s5>mUg zNUOATDIy>t3d$}ZC5_U`(ntu>vB1agyMMqvGjrz5Jm);~%<#7B2hcu($D4;yxO7fZ zx_^Q{BV9U7ET~*zUWR?TTSw8@ z^YOu@qjecK*EP2(8q6Ig-Ef*tF-8a2HV?2|x~IcMHRRP0Y4ftpK|*=*j{~sx zoulYf{c7)UCVN3pao2h=cfHhe7Xl@Y`3d+DzUj`P0;Ny$3%cK>4lQQrURkYU-)99I z;dx?7o9^AGAz48q;}9@?)}lNLWGYTsXW@oKCM8_xyd?W638%!y7JVCd92JDU%oYu; zH1)>%(SiYyAE0hN-9K*$cG=%o|DB#AQ9c&u=SQ{jESU@XeKXAXduwSmW1HUfhIqz& zM9;hG`E0R9Z_9}Fza2X>wH~Qa76-J_VP7vEUQijI?ievgQ9UHU?P(9h?MW~4d&eKM z{+1hkz!rPC^M`3L^s(!sgN?PY;1r*)j}MTAQEMNrKBuKh>{U^|?Y;E=S&j`&EKUnd z3eTe18p%e7EZ&hNS+X(aBkVgU5PwK}t8y@$aEzmHQbYBZ81+vI|Af>sJB{XxpO!7Q z|Bap`&;{Gu>zo-sb_i41PhUD}CaTBkbG@G|z;p*YF)>8ng=}xLa)V#_fTNE8s|q@- z#nSEwv0Z=m1W;vH6CbJ62%m;I;ub!i{Yvi(>n&Rz*&(I;NhOQ0 z3vzk8+RVxg#}(WF3%3AXdapN$WizZv0-2kDr%?*%)5?5V_8;RTKR)j-{DTfN;AcOyb$2iK4OdZ?}(xCEb*( z-sXmQ$ubu=z{&*lv5z0FKWRml#5+Kze)xYYy84l;_9s+zquPZ^+3Q6Bp$)j=QZ(X- zPVDG^^TxLP64QbCWdfJ)U?F_Tu^FYwp)wtEW|5r}@ZtRL`;)Jyz zk5X9!WjTM>ZI?$3mZ4In_0&r&(L~2J=pY{6H9Y$dL^vu!Ro}IaJPf7Wc+ll6G>V+3c8iN~WyLf6k)83p%sW`hAX%0|NXK1Lc5AMyL)b}w=af+PQHpxa8 zSeSto?LzqZ4IZrw+r{qA+#|u(RHN4wzL}s(slZ9Yl>=lZ>HVVa4`()q`+QeJe|;@U zRfXr++SE(s(_<=Xp?0-TpnQZwt@2w_pYLEC(DO!kl72WV^;CWyL~NzI*H4^B^!{ZV z@bqi%@!U2{0TuG}Z5a2qQS3qs25atV8OSQ4vAi6j9@gUhb@){n)Ar+unmx6L1wmG{ zuw1V(c$B7B=b}Cy9OJ7T4dv*?^RV~Pw~hku+&H#xw6!r=Faj%-{q4~0VHqUJJhq4{ zeM3|FM)NJ~K0psKefWi~@HZT$CCF!+{71k_FO`Jw&J<-BI~RCO{*FBH$f4>f4yYgk zW9Pgnzk)t$5hx>$)uh#=(Z(dTV20* zx90?bM#i#xP<#x;HXwQY4`v}RUKJ@E+g}4$GuN7u=y#n)eOK{rg8eG8srZiws_B7daOjk0TB|r)x?iY`(~yj%SlpXHWg*7D~@*~ zJv{10mdWGJO-=V3$@f7hvF(ggeG2TlFT2JmClsRQfquQk6G?`X?5q=Qb~NMBNQ;fn zyj}cn`QmWos;K5045hsei(A+TA3+I7@Y*=wqsj{lWe>U?JRj)M9WyYZiGqF$@9Z*8 z9O&%15*7bv)H&A9{3B`n`yz&1rE2&zcl~6k#eqn}8>2^&4yk|n>7TiUJzZpL#`aUm zlOL~17(^ajv8GFltFaLDnq5!*YmoT#su7M&gZ@4{i}V*}77q29=B$s=rUmh6RYnB9 zjC?G6kvrqM*}^gy2j+emCn9TTd%qyGemj?h?YemoIOpT|rJTI9fVi?SA4$Hey+sa| zIJJdb89%3_fHtn1WXw$5;do7Z{E?%yu7Xi*ZDhI}OZ8lCwJSt4;d;>q`wxCMUBx#_ zjF`{+djq`_hN{rBr}}0VD6tx?b2hi|MAy$*>JaDO#-HtIxZc_Xd3uAA?~P=+o-Sh$7*h!m)l$=_SP zOpS5y@h7Nx<|av?t+}aw1v{L6I>Y0q9OV9{LW9Ujc&?&Z-Y~*Q<|kgTdCvS~io0|? zlDk+`y2v|SMta}ISjR&GDh;SNthy#HZ&A9FmiQx^B|8rkr;n`Y>8;%7;=VD}tMjw$ zQ|(RhoOY3aR})uB1Od@Vxu=CHNaH2T2Sr7{SRE|&vVr5EmnY|%5lj`iHF?terg;@V zYh_r9t8wWkix0*u>w9NqT1Y9t?ZfZh$+CRr&X5BYhyG5V1V{Y+Wp#e(FbjXS{ZHv7 zm!&~!YU;J=uTjRQk5!qqf5pMJIX2#KUX%e8d(1&)LqT#g4yT1>Dg6n8G}0L5*1jZc zSF#s!Cf|3(RK0kpz4Vc(dQKIcwc3~X-t{>SAy#vKQxc$EV{dcS97`T(jTog9B zojL$z>?rjQEPsUF-gUM;K0FY2CDGr0b*G5Im0T;~?U17Hvxk<+t*x@~a3L$As*D#PpZ?!3Q za&}oVK3-r^`>h#eWK(=a;@5bsM+%|nZs4yoSYiR@p}+-ojJV?H2XM>?@6C7@Db}s8 zsrY$eH>YJj+g2JD%}J2}xV(DozcasZk|>@||0}8Kg%>*H?%N=`T|Qt)=T5VDJ2Bmy z{|8AB-Z}&^mhCOrmn$LHU4M%Y-AA88K+LNqsZJ~z`CNWQ67*h$9)v$`zZjmY>$sXc z&}4sUos=|OMp0*vCn55R@q<=A&;v`i7s9Dw42Q95^G_82#vH{PpLj)_3P>T(=L6+ zp{u4)&Hm)kpxf1keSbul$uOB%ZYt%mN~KvOw3`!|6>k4wPG*E8yRvoDu}kp7fZ>q^rP*FY3 zZIf(q5t9{?HPji65w%t)LJW+!cYi)Y{`#=-f?8%dM;E0LtXlM1n)Stmt+~+{zEQ2h z=aXn}Ek=rn-k}!sh+IDi{s(A#?^beYUo1`e;a_oQ`>@GdtvAiY9>gcU(9Gs{6Rnq5%zKtDT886)Qc9@~=u;>kb0l!N-l zg1!aU97bJBFO#Sw!i96YC|TLk zER;A_tF@I7OLgV3r{qgqB;6I$Z--2vYn?Y$-BrQQ5Xy9-_uHVz@;eo1nz!OKYA<-u zavzUQNMKm4i?)l%-rvQtzsx3RX(uM6%`%Ynp9kccp3<8Iue-YGSim1(TDv=J_jO3{ zPK(caEX6;q*xh36870gxdnKsAsh*bZ( z(QW``yBOtvF?6OVelih58dH3&Z9cY$j+mod!F>xtz+Gbzcd!tUhbUumpIoMchxpb$Nu1nO=GmiZ)oO& z@Zo`x_InBB4(?md(c5)?{ygI$C%nR93YldRULYtld+iOR8vy)7i7h5XJ5Bh6lN8f zq!II{*#<_8K4cnkKr$DXj{pv_YW=_)CQkUYIfEnct#G%6ge~mZ@Lb4|aGi3f8^yK< zJQ}+9)rZvl!2Vs4li(%9ICwp-p(_y#D}Gg z2Ng#BA(ER9yht%pJzMdnh&;8+*-xLDKAkV%g2QMcF=KnW?_8vm8$9?}3zl$y>XtL#S}Ed2msS1om>Z!hRdw=P+;^ zZ9qi^5vi*?fuj#O+C+HUnzaOSb^GI{Qfrb1mWu~GCBm9?*AM7ogSXE@{V;ja_B63A!zG$QBS<&N$BA#(Xz5^ zJ>Vp9kL6iP@?dMjLxw_7BCV))YEu>Em&k9tI+783UQ)oQ#+u$+yCiB#fl>+B50H${ zv*{ueDa!XR8TY44iL8nP=5Sg0zJu&uq*tm&)b-JDO#gY_z(M=r8y18)%V4jyzL8Ft zPK9np#7@MUXzc9Z+i-Q%-CpDBh72cPy@r7>40{y^$c{OxM&S&O?U(!fGwk7OVthUs z_6y%xy~X0=jviN?yH0IRjVGyvjrp+o0HzA=`VL><3SUt8x>ZhjM#1WC&0%)-!%n{( z(a~|p4k&hT=S>srC4C>eyFGYzU{6M2B8b(8g?wcr@EZ3_7JK>R}8zWNcj9S zi)DTJJJin&-f!K$e$%2F#JsdO9s)4@D8MEuReM-Jvwq4$I@ig=h!AU0yBHg54MkCN zq&F58jy&7L8RcKpPXi}%llF|rSiDSKFHjt7tc7Cjcu8=%DGT4v8w*V1H-C;+SvhE? z#dLxX4*J0nN}4oSgd2P1UK{gRaPwMZ%Uj6P4LxJP>FsDsZR!{Mg-Gg)tLGWiE1Y$M z#_;&3rXDbz?aBP%O>9*5Rlk>!j+AoVrm1j6TU|CU)1*d_c4~b+HikwNu1H4N2W952 zzx}jz2DBQc;aqGWYs~T@y(hxzF2wvNLtiXj>YhgXZ!%KUi$2V1fpS$MR$+)&<7`M^ z({V%n_OR>c_G)X(AT-w81A|Y4ZU36l0VrKx9qOLHJjWE#M_6v1KH;Y=NnqLiNF5bs zx$n4(enGns9+zv`=Gs+uVk-OeYk>cwpySE8x-Z8Wclo0Odxr-;`-%hw7{(>digN_3 z@S&KC;tu(t`=n09qd6o8KCp2V->8iA(TIxgv%KQEFvQDpg=@h8j!9hmp))Ln@|0o{ z{ay-GT&=O3c{yV4SuL)6o8V%?LjMyKFw%8TqU#n|1ySGhzzXq?Bra%sFQx-01~yi7 zKz6K6R5DIcjG}xpx7Hnz~Wk z&dRbUal#lb4n;5io2nZqnbL!eWy_b2QcVV68ZQH4y^XdKr z^4PzV=k?Q|v+0D};b1IDK`NJJXbIT(iJCC*-M{gEEScT=ao1TALrx$-`oPzo5jg*_ z_*^$u4^{t|#K~7v)6z5DG7@d{?9ap^B=pJ0`SFPR?cH@-n^?RkbkQIzVjs7eYM`Z( zAMfTI(2%gk5|h!Nk>Rnh{Lwya=clDZ#NzTmw;UE6eEa3rRt1BhEv_F1H0j@6naszJ zcw=I#N`#Foa7k)};sMN>Y14naaB)G5!+tzemB)LtOEKqPS$3|*y(cY-Li4}Lh>Feb zfvjGBxCc(OMu#%hnm-fkb>B!XAe74RTceWdv0xC&@Kux!2WT1zcn z4K6mgLhhal|Dz8%|6LT`4QZ`v7q}x8wreLcE)BWxs_`xV<$Lkdav^ALHe~=TT-nvo zf5hZrCq2i(7gFxQ%Sgq!7~JjHS~2*F5rWx{GEXRZyfJpoq9YTEGTHZ(f%HXKJtOX0 z$pz+jt#z35t_qBK&*3s);gvfDP>E(%;QT=b(fdx$z=Fu{+A}s*D9IlMkSoHKP!vp( z5z0ke0~=hc;$^oCx=*@C!QdxA9V@=X0LHx$X6V{9@_6FSC9$%>+(YB*!?69h)Drq@ z4!iKPL0FcY0teQqdFbcH|94GdkYY*doWsZe#<8SR-S>)HxTB56*xVZ_n&_4AnJ+{q zBZFRiKT;LBgUo8a1|bjqL+1<0&Wz>%!^X_qS_7GbAWM1r6fruWa)`p{ zug1>fJH~79*^r~3Q$4QyUh(Fmwd6*F$u1ANMcFyp9)0s*jb~DKrl7A;wA#5FCb1=b6BC6AEH?mm&5ZD%aMgm|0UxASm+zh^S$T*T4On<&`7!C1GE-3ZNhP3u zfk@R16R2j=0J~|Q#8d{zjMLv?|18TaUb3)B3>uc21>1EOj$>K*k<2`Ci1w zkpDQ4eDH+x=FD(`h5i6ID&^{JPtVR(PyGA$DYEjtTDhS_r*lnW95kKn0hT3 zj?KqT!_2MF`mdGDW|~3Z_`Nz8|EK6q38io!2FU|o7} zbDv%)Ga$n@LPwB9%U70Xms!s%iBMNd>{t2?C9U}0h%T8y&MS;A0EJL%6r;)LcEqrZo`=XYnK-9eaOE+ z7cl^3GLgs6a`u*v5fh=l9nRWoPIdl}pl;P`#CHDa&+R`FJbl)D1d_%)YvasBK9}L~ z3;LSgq)})en)HeJzg={XMe%0x+K@a8c;}*i@bufCZWpQtYo891P*IexqVd|mgC7jK zB;ATsfRM{mmA@1>BsE0dSPBN#j7S0M8`OcER~8VkEa8r1B=(-Y#*;&-&xkL72%>iZ zHa71$p$#(fZXx5h;&-GaKM#1#-|Nwk;H2Q!Lb*imA#31+)`dH+mbr1{e8Fr4aPT*1*6g*q#X%_yN5f`kvJ~kq_|u&RbjsHhO*_ ziN-l?w>`-qIzB711bnLhY?{);nC7V_-QcHP&(CNdRBPj#8O~Ur-*qdYSZ-A96cq85 z6)_f`);CCf^?)a^VZpE8BU{8^2~Dy{M5bQc43UcNL5YYd3b&6|y|Gp4x@*dB)pV=# zS%j5H1Jl=^OI(NIe%-w=Q+(}1*}3oJ2mShHq`|#O@(y*w*>y#5R?}=z<@K1L z3YDX*3ih^oS1JroOU+0@^WPeycM8fb>}}-TM{0Kcie+cv5?k^__#%rkHly5Gcq=uo z2TbJzvz zPE=%AbqNNro1%aB9)$KB(WI%^m?&o{Hsrj|gg3aMAjf*9d4Q%#K1Ym~avtw6f=Yqk za=*`CwNb+Y`bioys%a*qS`OiZSI>{d7Qc}Mb;c)3Qd1w!=u;#o7(X^}fs2CL=RBE{ zWuus5`PuHCvQctU;L<|p!ZxR?Jw@*;P?yMort(Q+-wor20Q zM2a~b(wVLZR!lbQbC`Y|BrS6N4P|-f@6y&>w85)Y8zqeg&E#J#$*zhADGAL%z_k7N zM42th(Bpm0<=GA!us43|5)@eGTKmSj_Fu58|5c90=RipNMoUVnb{R1+y;X0LU5`!C zoSoQwuLb`*SY6NP_5Fy^1r2N8@3G7&0Sx&i93LTm56OD0VpwCBT{>VgR)}mzZ)zy| z4zW8&w2=)ZM$LUevv4NYa%DJXSCB#9i+R8Gppgf|a+#wpSh*la6(8jD-|P%IWRRX% zK%~B+iViq0Tu_bG##gWmJERP?F8#<8&M?aPXDbXmf^PAIZ?+9{fI{9lk5x-Dp@*41 zV6TXk%@GDRONWD6jdPBs6D?atN;^+qt7XORF+P29+vA`g?j(Z)&p@9c8#tY|qK=`3 z-ptltjr;+lI|g_Od@r%?-Xb4c#8M0+utU!#>CaQqSXNdREMY;DXcHh?d2B&-{_*ib zLn9mEY(EC)^b2GwWj6mlwfAF~1{f<0+2|Dqp`VQf*LJ3fWepdeV+{zu0_SE@wYk-R zrjgumbBcmLWO2-2giuBXzU)p`FncgK+vdatB>8VR?cWk?@V^nX;oR+_u~!)y95Qh;uj@@Qv6 z8{oT&ri)pfIr#C^Qm1!WNhnxG#tsD3LD9lqK6<(;lbQ0+<)J(c7qr~Nh!W>90VtqJ zsU#Z#A}QGN*gp7pe^@LpCr_E5cW9IJWDjXV-Nf{p?;8}X+vOlNLSfFynAyxo9nYlJ z&fAqHDshJQCt*QJ5pp_sV>uup_{-FT`J>fLwM5TTrJzF`xZ(pr%RwMEcLiUkBF!ix zk4Lsmt!~xo&k2IsdC?BUt$P}WV--pW`_txj8ucF(dzCeHfg;RtPPBS}2w{6_=Fg{L zuf1KP_k>wz=?UqY%`e&ypyTfoe+pJ~(ZE(12Q?3^SuvnujlGagadu?ZFDx=M)~Nkb z0Zp=7Btr5`4hcagsI2;a-ui-1>)j9X-aW7s2FLJP$8-|;U6sj5zuN1nrJf$oV&2@@ z5)S;^zQzpRqdFc|uHF<%S~hCx{NQaNToGqt{K$lxnpTr~6bnyq^03n2R!1ag=gJlH zbOW+@{Vs~-A2O0I-m9ZtpgDo)2eX5f|IH3Ic$av1&R%U4gzfGIN6vkrYO3U1IQhq3 zS6NwkcBpa^_V?3?rFlbbdGy9IVUel=PFWJ~{I;J8CxGdw{7~924YxWv^=)hS3S63Q zD?AP|@|i3IkEI!Muz{Ffsl5O|R4T z7i8XdQ;7Sg_#q)FBsqEnflk*zE;Ug25<1$Hyma0fKarJrq~YbI$}&zG5&D0`z?n1;PBnrF$6z_I8RPgT!WwaZpMe$@oXDcR>C?mR2j@^K)XIa^NoX zI)s2L+hM9pez_|+{Vn7B0oHC3by4b`&{04QgZ7_wi7TSVV>Zv-lN(^R(TQ|wHALg6 zx6IhaFaP-A?cS9{SZ)}8|4Di@Q_N`r9*^%F9v-Y!#K71|Nud|t6EZrS=6!_?=f~?n zu#Mcyk%1y4{EBs-q$A;!qt~};5JkK(oTJQr7FRem+sIn;*wN=Az}Z3h6lFcL9K(VI zPkE+vWvGvAH3;Uj!NPN>VFELxzzXZ0*)NsZH^k^g)> zh~2BPl$useLd4)wHPQD{z6fEM_p-Y$|^yGx!@ud%8656 zC9>o0fWBD{4%W|G%?UJFit68CO(Q(@4A0nW^!&P)&@ymtW#2W0%v{DJsnPFNs-TcN zV+m^GaQ|G4O8D8xQ%)*SGA#jwKK1TwZYG!h{_*j&u_!BCHP9k!Tf}hQrf?ER#1n)h z`>i?v*hLd0%Nhmaqoj^w#gjs!Hyn>{6EPaTtPM1vf0Zf^v~np?8#+-|rKVi&bv065 zEqX+`Yd+{`MR}bT;xh9s)dCG#e}Vm)CJYgd)!#j&dN$DW(0-6ffT$Pp8Iflxci!I3 zd@Jhu(C>=Y;_Sf-wB`zEz(7!F*<=FUBcBe{k00`+VI49A=5gZl^EPSPiE;UuXJwb6 zp}znW7X*GLYy(y^Jmm~$RR(c$ zVB`nhoAvvrL##`=5rup4{z;8Q91NI}s?~?xpEfi8sv`M?uL$w(o!-66m6{WKHa14a z>B5|xkcs2{;j1iVEc|L3Pl8y233c|c`WFToSaeKvmdo-9YMW+!#`kO zS)7GD+eF}}=4NLnCT1_bH&U^)1!Y&b^mwgCR;c)eZXeccr%_^gk*oNUl#&8UFJ z6-I+Lj@=pYk?-=;%(;emnm#w*DULp zP|nUBUYYI*>&Hq$_dUPusjK%(w8-?FpzoNcW7q?)Rf z6b;#ZBok&HvP!}Am>3Y%&MD;5pE~v*;=C#k{sw>A#8#U1oCPscGyodGa5BCCaiMn` zbhxy=CxJ(VZ`y9+4^!nFa2Rd`xs;Bdsm`$kGOe z4K}P`64;wrOB+~cJd2_8cVq`Mwlj4rV!W1f@XsE`HCt;m=jN%ZgQ;;{X&6#)^kseU z{FH(O7DTsAjgdS};v|GXTMiO{eJ};_^z2R^&#bhJjCrO5%j>v9%!TBx*S+nu(dE58 zVP}_=S{XZ4GVC5v;VWJ13 zespj2gM9W&te`@;uFaV~%j&B)GB1>(;`#P}anc<>^--8;ES19!6l%hXl>yQxrW&%JE%VzM zdv`t?Ml5E(Jy;i8CbqJivi+MUvgyMLB|Wz>)3M#|$qHVIy#JyeH}Q9dO>E)sr)#qR z^c|?<6r#qJ%;F%PvN^M3cLjISo??Cj?HXUeKj7`~jR5c-^)JvAR}2N#qIfkf46W9B zD%8UK@t3g#DCh8BC^JTyl4=8ReYMaH%-hb7{S#`3si(UevK;H^b~{La9RqCV_(qWU z2lk<1RBWE2%!nehnr{#V9{)G@4e`09WJVkvM22nHN0=-nvrF|B)}@!x{lHKSa^}gH z*s2MSKhg%YLebhU;#8Zrk$?tM)? zS$Mi`VNN2o0dR7@Df}^75dvhP5pku zJ^ANIcK3Xp(H?5~$ehHTA>`^G6b;Iabx_WE<3lp1d>!at0!uMAZ4%wePzvc2m7qM; zEFt=fHxoabgW<$OsS9!Xdj?v*XczeQ{f>jx^sw@r|}a{!*7 zots?D6PZ2QFDZ1LCi6GY-I78i%roD9hZr?_2l*lnS)FQr#)}woDIpT#;B@nWn2H-) zMAsd4a+Dd_f|l~{)gyh?K#9=K|IoS|0-l!YlVBeBr#U{v{6m5ymP8zmVBA~F$iJ(j z^Km}V%!wIK5~YaW|1N$KwOo%%??*+ulFRH*O>el~Y-V8(UuaYRvORxL&^=D=udB2I z|3q0>!b=^hc9-Q|y<#!D`g9eX-?y*{I&l`LSin=x3=7Oni|(U}DJWUiftLRNqFy@9 zWekA_^xZW??Ds;ot>h6ji7M)W?c6%Q6_QUmqZlw^9fW9x-(d-iGTezY08P}BlsydL z=$P)))jHXGG^_uu8@ucf0E4U`9WW%Fph^FlSOPu!!?QH{mCQV4#etIq{>#Xgcv$n) zJH@?FLj=&c<$8)KhPnCBS>!MhV#Og$k56|tYuxr87pRRF$l+i_F}J|#T*i#Wg({o! zQ2;7_>sqW(^T8*&&VSf~`=h{Wsm_HhDbF)u2(N8tZ1L*9+DTzEmoS`L&8Bz~MSA}u z^h=E!Ab)1KCxmmjM@GZ=UD=m~H>14ny({|$&(h;NZ4Q$h$;J{^Jd99Mz<|))q5-;l zHTyJ=Z$$U7-&1KV{p{_h#?!D2tcNGqLuq)?m+z`_(fq&r)AwH{24$ZES_;^*#BW#uT;P`wsnFcZsQt!JuP*C^!`n z!<7o7r6aOhrW&OsG#|C{|G!P5Re0{-4y8n5T661Zg*bSk0COsPuT^o1`nM9Exr5Qy z;;XM~E{aO1t35i_(88Af;AJ>+Mf~P*6EYax7E*=6Qj8SK>`hAySLe7*Ji7fyTJqJX z4IK7ky(qtVW2{HO8v^A2&y-$`r>}dtih3x+`U@0hI@yHQ{ZzcRRfyQ0F7}z)j8Eb( zu#2FTDMs%RS}{CCl!QxHPzw-%b+ae{YETK~a8F5 z;HSfezoKl4(EwsO?dVEABv}nw|F?HFR#2#pDw6TJMBz&_+Qew_Vaso$ zLvQYo^PGEgb_?w&Op*aOLOJDzT*7n3F+LzIkwzcr5p_9YkmX3KN5I*PCoTp-6+~8B zhGPYzuN922YJ6X#5GmA&#O`Z#RaFe#fcI-FLJhR$0HWIZ;27sBK?48gtV?!d&UNU6 z$UM#aJC=)iUG2A=jj#eizIW+}!i{cy4(lk1uvPq=y_itzhFd7P5xXr@nQtoI>h|yJ zTpj4XrVxXO9~%}EBYv{(BR@;NimrDOu|c@$btS89>e_1|`0eIFp(fW~mIbNwuS->$#MFcgqReY+K*rMHEiUK{sGN&JZa?$gL&hK5ZWK34 z2t+;BS1ZlST6gFCXa$pu&0|KS6Gptz-5la%`Tl_U3!?<iFxi@X+p-cJB2HY+lV0ew5ong|i~N8b;2bbS{CU^73YgN4@<}FW ze^}zT8GFhtLD_8-CqhNCAKRHj^w&z?qLujPKR#+k6m&H6do;Mn+&njHEV+fjA)Fe3 z3J>PUkYrwwzS5)=B>m!=>&bT7eVT=lVQ~I{WFC?#O6P{MxzZRL%1(2Vv?Zz~uu$V) zP>c!^KmU?$O+t#@*pN4gDWms0HRjAHz*-^FEm4Me{`Efu3=iQ~F2rc^flysz_0gMk zhi$*$t82N%d8#pTd^gtA#Pv9j51^-(}C zwg@f?j)U9{pz1c%*T|`a*CV+xVC-S492AqR#_0v3{2aI-FXb3zGhtBmeJRa9kw3J1 zSy0cqYBuxW`@he;a{09G9U`NN#gA?aaOlf&hR_?_ReA^7q3lyc?}f&bs!q7F8a(>t z22Bq`u{I(pIfOE)Tdhuz>iFN{-{2IB!5L>wS3ioey>_Y!US)3b3a%j`m!+$?-z_WN zP{YPSV&>yQRxwezSNP}aZ;Mzo5~TGk-2ugTxCdC_K|0XJ- zB_Jgw6?l0b_dEl84}$RsQ|sn5pT=9;u$kfexg{!&LZ_QM9b_|jEA(9|J~PXjdlap=1tg7H`lnT_o&aBq8pf6~xOw5xC_rN40Tm*d-Cy{89%#<_BN zrYy$Y@!$9|h|V*9I9t03o}-aPkGP@hsGg6oCP*TE-^?-)&xH;GFHpfn-I-$~8BwDa zBo9^fUj1eIovZxAp6#!x$F)T}Lm4XFvwqo`gU^;bGIqX>fWTx<`<)Kz&cI5^tAA;!4AQiLawozWtqE zSiHeLnGD%-CTI;AS4M#3Eq}tw>Ac_Kk{qG^ViGQ>iwCTIqonh0P-D0*J>x(#GoS@# zkN|r@G4Qxs^wlhQ8Nv)*$nGBrn6Kt8_hOuQ{CD=gIg8 z?zcDeITVUrq#VXbVzVb$qce~Jm%Jq1FZ5brKLDN^KR#0f6(iu!eHF=~`@B=s0)!DH zuJCYEfzt!Gwy%P3RQ#5@jnpC!hU7;2b!e?^aXBa1&4M&;u7lFI8mqeLxy^QW#XbO1fG1 zemCf^T^bM+ad4KVcK6=y{pnyI%A)5bXXenM>GTX4FlXE}n=kDp=o_^o0j?AF0+2(Yi*UMw2@UGhyu5#N7i<gXe60fOd;t?1K%KuTcJ8!+U*Ip5b`%c)a{6D4E2MT{{BBI+3zj3-HX1gdjnEeeW| zaq^`W0PH@$gb0YF<8%Q3MQ@7Ri=}^LAYHL+ImOv_s&9_GrsVo>-QoZ+YL~}n>qXaA z!e7KWcggCJgPeIM2DGqDrh6^Xf2gl2EsBL;v`Uy~`jJHyv?>>0V82OM zC|k`Nu@&FDreYU^g&_776pD-5EHHjfb#u3@vHSwbs5bfNX*RMx{ z;iPo*@(xybjczJlVC6eDuq2gg2=wUunyB>v2TJRU5JhgW5ltIeu`*5a(xi_%BcwvF<>H5P+^N#qC@6@RJWy(6(6Y8 zTB$EqTUKtdsA@w9f2xXlu0o#r<%lG(La(}w!BteE=~|$f66zf)B~F4;_*h^?(K*xA zvHf8z(C#}t|4E4CP|kCb$jI~-F`$5Xgdx|Pzuwkn6iN5W8v%Nk(cFbNMfuEg|Hq!D zfCaQ zBPE5qVVu?f+Cshf4Yb%G(;C`;&{j?S{qtB$2IC*CriYQP*IgWf=sv_xlr;JQdHbsa z9jJ{+`s=RtxsRmV7Q%URHj8$V2hcLz=Cu?v)}{{^-EDylWKs+v#F=es@9eSEjL zg!{k9kd#t@HnimuFF5%H6{07!XZM&=7cix~3p@c^T8sEf!V_11d49wOA4(FCMW9s8 z2=OW0w|ZyiLfnTRKO8=PreuX&4a|4Fa+D&U4<_Eddq3_lm(BP>D&^_;!)|S5&hFOa0C@x5Ns2XZ%@8#9l z8#DMRVJj_FwhR=#d7$|y;pM*A5UIIMF)4CA|DejoQ({Q_{>K#a@jcA(!8y_nE=mU0OHDT zKU!9}qHzr4ju8TrYL}+5rtg(WZ@NeCYZvb<`k9F&WLkk+I74!IAvL1D^c$se%i5Hw zYxJ%2u9#MU=L+62Cqu5U30pW|!pG9h{~%A(!I95WtF5Ggrp;-D+CEENZL@CQ4dWd3 z5VP3T$*-hYGl69pxt!_wODGnH)A@?^v8vNOWaQb~N5`x&zhiLxfTTq?5V-Hq2k0_S z;<+b~DD+p!cYbd)DALZ&q%eT-Yd zi3a;m6aatJj676hQ{FMNLTuLlQH2zwFR_GWASn^l(@mmUL2*~O0%U800j?-QsoP(5 zngBDv6w#^Eu#Gga>StS^kfW~XeL5+Yze*9Ap7l4#B#h1qRgR}B(OhI8>)b$l=+u*J zNMlLe@hp-fW)TfW(t&h2bB5y?z3RP5-S^c7aCZR7SOKZ}r!iW$e1?0)F+je2M($cL_ zaCWNP9tXn@=BNwJ7w%FPDm0p>dhf{s8^oGkt1%JQEG#aOm4N< zpae*ad&MWcG0%)+7QX+h1^6qo>}9Zbf5-T3(OWsNwU{gBGH-kXSLta;-0GUP!Hlyg zZdE{Z;974yiGX#o!pCugvg5 zV^NsLEC1}9ZX@5}B$&TMHuN~i{={2R$4&@gu&Hek#PaI3z9)C0TZzVn>b-3r1=VlA z{LDZe=-o9qe!NIntazA?_^XQ~US-J27gV00(HIUKRR|EJgl#KOK}?C)o{GqK+QsyP z8;Hg49?OhLr5}MO0)%m3stXF=(oLTluUL!X2HNRZg>prG|0kOf!HOh~T}6Ja>>w>l z_xCt7yF&3)&8&!JfINpTKe@RJN(lE3>`YQ-H=H5 zGujGca&?mL<@<6BTtY04+NO^AK@rQCNe=j&3^~5O*^z9Dr;{hh8rnib=FUyq_0_@x z?pu_E@?dVQbEAV`J5d9KCirQC{rIAe8Roa9gh4k9Fv*@3lF&&%HA`7GBkEzKaeRw2 zK^9Zu!{6b4zOrLR+ck*9qiG_h91ZN|9O&r139)$kdh?Q!+7;U(NF=_DtWL!Kq*CZ6L%du${*Qx~*FtRG>WiWb4>|h} zJEpdfCw=!?IFlt~l@L*jR0%^#c=HSAn0|m-y@hSPN}f@}jFc{BQ5%?T8{Mv$Pmre7 zB|N;eZ)d{N*~ex(-THmIh5hjg(hm@ogo0Ro&t&{|zHuejU|q%Dbusnt?Ci+09NZ7; zVjSO$YYXO%%5Z#s=b!QfcB*} zU~t04SoG=1D4k~M2MqE-EC}1F()^qWy~jsRM?=4HpOe4r^z#EAV_1hPodrenM@#eW zN8K;Ge?Oh8oeRp^wr6E#`rtigs?6HSzWD8zk78KNsijY7CP|euReJ1y$nr(jKa1r7 zY$H9O2n5vl4w-!K_XqU~(I{d*@^CjIYXmswbqch!lx*tJ&U7MB3PY-pG}0}#(eCji z&;$bi1WTXXx~18_)Bk4AZZ0{VF#W5a=}o`h&l+yc&dJy=%Kk3-jC4&(-9uazZ$cuU zOkKwcriqM>TXJ8lxF|wgtI3gdZz3Cg_ogm3Jx@~L_#@#&$Cv^HkSk!kr5{=1o&xp6 z-K1K;6N~k#lJGPa`A;R#d8)CqusIdzfIIls6#s7Wo4S!c-MY8-SP{5kvT|0JXFWhw zLVM^(FL=9);d`0JV7?#1B<7MQlcHh(q}=rz+Wm1J^aebqBnd-t)^vWiR41Ehu6h6@ z-mqYz=lI;=|c(Ub+GpZ1r~R znTc>&BRTB{`o~ylU1s!8Vv?IfElzv=f+?%;kSP$_g~`CBx}j_#YGT8O`Ql&%LUJ{5 zam4M4@iM*BR*0u1jaI%S8vhCd^o}Fs3)PVa?w9+yfbYleRWI5ELC@<$lEb;Zv^lPc z-iw@Hr7ic~V6xM*k41R0?*nlq8IGjSAAC9{iH^RLarvyXMxnHpT=Qf3EaS=rE~&)) zPnYe(aYy&>hKaY4zUpSgS+)oWT-o>Lg}|4uwm#GxKk^x@)w{O@B?J1w6UeWsgMt*i z5|C~s=Xrc26?=@WaFo($oZ_CeUCH^cv2~|hqWC&;`dXwIj`fVAj|Stx_%jtT{77$) z_R4LB+yu0d0FZ%y_GVOrZAq0sI1xSP1m6_Kw$B6w$MWw7@AS{wSes?O)S|CmCZSKH zwEDaMXQc^B`py|>Y0mnTrgpKRAa z;o`sAdy*;>Bd_+!*d$sq26oLYIEyyN=v|RyAnjuspK_dH>ow2Xki7XA-k|kSYow2V`ge(z4b|PdOTU6FWc4G^XC0XO! zZ+`a=c;C-GpYz`Pp65L0c_>8C78~bR2xWz=a4vpmU3F{KSUY|K`<)7GJgCM;Hz6&GM z#lYFE3?L0~H#h(gi1j#FxEvVLfw+{D%bSH-LX$Iw_QinXXebgjDwgFo~eB5y*bwlzkOcJcGU))d+>*Y>k zif4rE_<7ApK12GexiRR~K550c5{&{b zUPpKWj0qZ6i(sc^$WE2yP7O>*K_cQJ50 zIn&tSo>f%`o-ui|#$aol(cA%L!7ur#nCk5M^KaQD7k|D9Z>6cT7Ey@YJ%X;0u!$Jn zk=#QKL}lNUi&SXdY8m6?y+glM6!2$)F}{`NzLW7wLIN>LP!@0EL_pg|O(Y2-yPxM= zfuXti-5(WsbMKpH-1(jeByuZGUHq-rj)=ZS3OqWxee~)2RzmmIk+^RCGmIMnVjgw< z>lx2QImS&FbMs>7+MqXii1iJI9g}eb35hl)KuFFv_ItBdzH{TNU{QE12t7qgOw8}^ z0kg3)=Wfp>DO`VXg2vB7Y%p=Ch5gO?dmnSK&K$Iy`47#uC7&ODe`i6edH0rOlL9Ll z8Oj?QJ=cBb=KScXx`2j(hQ36QDuIR_8vPO6i=K)@4PR;VMb7NCkn zS9T}SikC=;S1+wk=nA&HT1OhXW{ zA7f`OQIZ&Vr78ULRxCSoOdOg{q4#&Jkx3%SEKc--m}R&+$#tyRoZ(AZ^%X7zrJ2vp zH6I@XphXmxW(b(4&YxfX`Oo#{WbxC<4_w$Rr(|AC-Z)O$b$1Da^*!sy5x%g#na}Fw zV=QUlK-*0`G{SuLVc&Kcd}9i1FoY#C`IO8=wgI8r{5^rn>`eD6ApCG4<|HCa3w;+2 zF6U_xm#LmupAg``jht>Nk*tkpm`>x+R`^Ieji^HObE6iY$bibxX`2%xE_da}WnbtC z^ld*Cvt!!yrR!E(m1poBN0b7;oS7FhkV{|h4Fb$iN3;O{lCjx*S{?!#N#RTWu!V*A zl$8dgbl?jVXq9Te{qoWiO%@Aq9|Nbe0{Arbr}F}Qhm)*s51{rdQcRNcb7DW83r&K= zo}$J+Fd5`ysk?dg;j9*H{T-=Lnd=&T8!xE;gA&^67|hbYAtC>dlcgl!gBg5S=PkJQ z?c0o5mdUFd5}%IUwz|MdvJ z<=}YRU&L}~@KhpZdQR+59q;zC0eRGgnVh^UnW2AQ`e(?&DnWhC@2L@D3M)I8Se$)C z;K?;+<3W&6mvRTU3l$GiNQ{+K?2L0Rs%1E>tw1F`8df8BF=kGF5V?)*qpNucr;l^= z7Nq<}xfW~WCy+D!RuQ5|8W#=MM)ZhQ58sDoy1KZ4qe6@+tk>({(M@pV zw6See5;kL>XnPOwADgBeR9Fnt{&xr^nuN7sn^kROTj5ljr2$c9#NrY$$Ld>#`dD8j zKRUuVxgFBp-X6A4Jf{s(YgC0k(w7Ir;UCRQ?)=s&ge}Kzhe50(-UWJLo@F_(XT87W zSvl-KcT2%j?uUc_6vIoOSl#a&+S(ptcFgpon{x1)bi`eOm7;2HcQKJ@851^&XXVUl zF#(^>5p3y!#7oU5T&13%j!L?1l5SNX+4N^r)lHsq-AutHFr+pY*)x*hdfB z?&D-{X~F4^{pp(r1lu+n-Pc3EOqkqYTjq9c9L;jx)`>wywhE0W`XItQPDz5}hi_n?*ma2GWU{OVbWSa<1Df>j{BMiTMH#v^wY zR`xD1U|rs=ZvFnLUY`1jnVw8A+x_vf{U#tSL`?o~dRLzkfrFA!N&qS_9eRUZf>TTO ztrBv8u6!bY#wI(V`{{hs!|(QX+iu=h*tcQ$`vJm6D?2wCWS3_09uwI+HAjz}X5{bQ z=C|o7VwpH+nQ(pW;R$o0WcJKX&rtJ{%#XjNgOP~N8;`}o(wslgb8pwmMQ-*%!<2Ip zbiv(g?tiy$_gc|BrEMGyOsMunWjygp{kJ@)=&=`rs!vqd`WM>hh$GZAF%1~1vPx%j z8ZQxIiReFt0np`#iKqVZwIDg@)dSaH{ub;)j3sC=Kf6RU4T|sJ`F>t6(|#HJ7wh5S z(X!>9)nfE`x96P^q%C>~Q%7_wP+|_tSV5Wc3XT1>9ueRhQzII^qm8TG>1}LU7&R)( zy+++_1T&UJmt2W^&u7dPIRieTBd?Ch(T|~w?Uk06wz=&TIi1p0$G>Q`o6`mU4JC+o z7k0oE-F@f?6x?yBE(_4ubJc0C-hqpzakd6MOgq@z`(r)C7yjM~@-r3c#}vB(I+8w3 zdgNVeV^t^$2S$%l3B+Z>$PKnERy>+EK>>s2pY&)hO0}gm3h<%Z0z!cOZ{eIsLF`#l z@Ob9^GkK+=5?rVhCcEjxS18>9iz|Dj_5x!i$;9OTCq5Uf)cJ^(*`9)61sxF{PcbZJ zTN#FO3C`v_Z}a!dm1tPuwq@iZmGxn_cVp#rPgOnrZCcN;h0E!72y`3?f#J;+GL}RP z>Q0YA#B;f7xFRVS@fclP$N<)?{h_YnA6c1ywx9no-ok5(j$o!4ErlK(sgAkQa%I$Z zynaq9=1PP(3Po_S25Etil~@b-dzhHbFFeK12j+&nSJpdrVkjbs&q_v(ps#3ZX%Ps2 zOT(`gk7aMp3zb+wDcYm5ozOK$)t_ttmWHM+omL*L2;z2tua`jfD#R6N&15 z>3Ui%OV`Z-G8YE#tupzvU+tOW_J~alTd(wOr68-gcf~T#&o6p*wad2E*DJPtn^Zg1 zg{$KHxxY+}3-UEbzh#jO|59Ct4NYNKlQUoF5c6z+t7lfbdBWvtP+y?jdk=Fxv$7@U|DtY5-b=u}52?8kM$V6d?fC#y#!xh}Dyvo!ev_cDpyt zWeyljB}>^^^_*R%%#5G8wDYFvWRRG;Z@8qETx*q$)3}KvJYB5yzW#e?mOvena}9l& z!OTr-10M8bho#Ml<*5#xi&gNv{Fx&AUgf>|gGfcp;hAMv2OIF=Qrh93N)sfVXL+UR zOFHoAiNl|siD%E@f_Gh+g8raiTr2NM;yeY2l%)68z;KRn2eGY}NB zDLku2RZw8q9M@CZE!c#R`Ps&@sNXT?f5oro3(?UJ4HrR+Va*vDjCIPy?UI}o+Z7J?4=Z@`zTPI-9fG_S+Dw_jV~R$!dbNzqAUxB z*Sd3|lKQarbJSf&G# zNBDofOpw+EZ@>Qq>CwJ9f*B_3qy!#*I{y}s4I27E5Q^Gx|84f&JZG75pS*J02uS#Y z7y{j?b=4;)Q>xE`{engDBBiXev!11zm(hSSo`r}dD&S_C{DQj0bZ?m!W)j?eHh9aQAPge&a*WVnfVengp&q4?k zPbj!N4j^fDaam0VI_s8YelOu*_caMK& zhM8~in_qCk^Y+vQVXQ63c&o!VcOvgnW{A4@a!e`aKk1nV?K#pk0%KA!?DUYk31?Bi ze^XO09!;gLcGoze1q$kp>Bb>+JWN?xKXMrp|1A?AalI5{H&L)L8Y^2`@xa0U)-w zya#?n{P?VdT@S!3H~Ts;h@q}{mlXR-Mk!G}cZ`H+kgPm#0Rg~if1z7y2ft!y#~GPv zmG)ZuYc+d6jWiRL#w!zQ21)TL;3sN7xh1EU5vxx##FF<3K15cR_aF#RMpJ||2lf$F1Lsto9()!9QQHsc@S}?AgP1j?9wp9_tJ*Iwbn&v$A zqo4~nv36+r>Rv2k;w8)Vz7hC^(K0=hRkJu~Qr^8FTfugDc`9V1&w{fDti*oX`GFt& zz?G4(1p)d8f>ovwM5db;QbxzQC$FuC(5aX~5yVz!GkE1w4FHGYc^im{zv252Q9tuH z4XL0X9hD@Xz3hI#{m3&Q!~z*~Ek|$qSGM|O6B}2=3e7Wh7aW28w$cbI|9=Jns;!G4 zqt9HP)5)u^L(O0@5`fzuM`>4_5}uFOmwkjO_%uK5I|%4d<~BCjGY*qp*0v6Mq_si` zupAi}C3|~$)l^s4co`ZP1s1SD>0D;yHIBFew|mWHKl7 zS(GcoG)NLO2peRGMj20$(gNe*eYTwnFS*5DY6w_Dp&?GRhM4$Di{<1x^7KVhccoFn z&7H=ieN4san;Rpq7IM|wWx2~J#B82Aa$tgI5ehkom|G0h0aj%g@lPe z&MQ?Ivf1KA@Oi*JcD?Z2|=RPb5B zWk=bnP(B*-TTwIrs$`;%0OgWB&|OHjYG{b}EA!npC&FpaLbf&>uR`FuN-F)zk$26& z8cj#F%Cg*tTL~Q>BK@5GgByBJ?49~6ns*&65Z9OcWf0u4_t0I;s2e4z+x_{#V5k2z zRFd!mbr%AEyRRUIPh=?tg`QSERIU~3(Lj%DgP(8lHzbIuon%W!d<#~k`6wsb0ize< z*1X!_!MSJ&2SQiAQik}#pAK#TPKOg6nN;WdW`4SFnkW+=L=Vvo2KMfRu*ZgcL*Tk8 zzhB9)!0pAO#rnC9wSiJKG#L-bMqaR(><91_X^yNHsZlGJLnQDgOKS^w89~_oa&6Km1rfD&bQ)U3vtV=g-p@i?A*9nqAqR*{87*cR zms`RUrM@Z&#SL8>k~gR9TfOuwa}(JXc&YyAVTB0A;kf_8FL$7#HA=&_gKU;a^FJ4}rv>V(Pg z!Qy`1?~S{nw)yKnvodAlnxAMI^L6C#WX3t9+w-%19W?Fsm^>wy3<&9XZh+eF%!L}i zgDK{pY&Na@hfjhTAJool)1NCnJvgvri)9wO9%XI-e`-YYmj_-s0X-&{IPpl1phsG>&4$A=tPXN0O+wx@a2ftftIKg~c={4^Up=m$Sb1)#sKsIec~a8>X~|b?OF-UF@p}>?aI5J>Gp=>5LW@sO*)AJ zk$mn0(q~ysab5FK86;bSA6>KbV9pc^cEbSKSdH#qTcAiLyZcP)?92vqUCQ<<{QEpF zkP&8PM$LWqx=dOm)`fbEe;z;fuc)mJBVMCbM0C{3lCGzfuBDYZ(>>W_&%(_B-_-b9 z(j%ysy`ZK)EYVKX=n}7PcF^Dnx6^ZfEZ;uQyusQ@3kvx`#IfUhl%&sU$3m16dqnq@ zefPh{_pPnz*`d*rJ^$4Vtx8zrIs9|_oB&_xTT;LIV5w<|NsbrLgcE+7$(20w)u~`9 z#>Exz?GE%Tz|%A-U?J^nMX-I=cMfjC5tNk1bj}QfZ@yc!Fi6K*1Ry4%Udv9`NS{26 zczlNrVL?h0i;nn_7reWbS+h=0Q@h8^+5<@%Lke+&x6fyP&HbTQ-Luup!UfR1{(iLt zk~1Hl)}VaxlA4$sa_g`KM$}G0Ah~+Km`om)Po-Pmq~xW65>!#Q=_jSWK^KmQ^uTkM z;2e==`oJ4EH$HIoy@~wccjr}3yM6XoJ;{_Gg3wzt-Zorx_^+Tj!?(Wiu34kyhZ@^SE;4)Itv8ETxv5POiZ>$__E9kS8 zUfp^X?fyZi^4-JxWob>Q>*FxayHT9Sa>MG`u@00TOmZcrpD;Dw804IBZDDpaT&&@;ZBhnrIKFEJig1UNH*G@k(P+=qs@Y3oFT2id z1@y1Dp*McOL0lvtEN2e+6|JjP%SyFM>d+o<2z_8ALq{xd_GUQ zXZgbF7$ms|6RdWl0gmj59>N3)8G4faKXsOoPr|3?ce9}YgqF4&4Hu$_F#dG*_OQ5b zTI=s-jNsurF?Al{dl>NFa}8vD8{waBKccXasLi(VBjYm7G>grq2T>^2Wmd|X?WN&5 z*!>mEWdaHEE{&_8rz6=@4t*N_^hqG-hS1t!>?3D?b+LJ5g(={_vbGKe$xd}PJRK=` zy=YUh1HlJ`ol=3nH+WoTnfbi+scNJMf685zXEs2tcAekd zNgaw*>Gd@6%%NhyI3+J|*sxSVrH=7N?!`*kLJ0(S66I?+TQ?+7gdVTX^qry?orzaDdWPVST$v(6PoZr5Cq9!_ir|N+AR=$!7jaEa#)l zGJVQs9DNIH=sMtcxEp4H+&rKhjZxnyr*57CV>d6(xHq(xL-s60a!Ad_Ghe@>^!Zt* zwOM6xsn=1q(kE_Y+J*HRs^Q7r;)w)RZ!QQav6P&?G}i?Dq#Z+G-MD}sZ#AGlN;3)B zx=MGV*bTlAKQdx&0C~Yvav$zD{-+-GL+xVg#}9rDR`^Uu_PhS$rmqXv4O%mTB56WWjCNve z@9|z_fK*9rEWj;7FJFknHVIJv-7-yOf--iDM>Pz}t(CYbtCLmlx&n$MG0FE_D8J1= zJ;Fb@YV3_4>z=~5x@q)+V}5tG!oW%nsyIlpn0YDE${c)Am?9PNuC^BFiTRR0EcK!m zV5h)8$6Arz+eOodD_B-Q^9~8QKBb`IH(eIqk%|{5tu%)jIiTAujgS zyI@kl{w^fte*TC8EoBC@^N;WClu>QqJc{k0vWa!-j*X2e_i3fb2a#Xi4?vgMN3B8k zzP$t@(x$2S81O~#InfYKW?~y08vtjJbMiS*a|LU*yt4NO9_vxTo|HNQk_Ch#i+;u=CHWa04db!mIXOm9Q`UM7WOK%nKSHhs zw12Aua>X_QPnTQ-X2EAZfnL-bM{*DcL_2(|F-U}TUKGVEnmD_rrbnIB-quf~-|l=i zT>iZxe`L?P8YVtJs}E?EzWRChqOo#4)WrNQ&Wuy&(U3?F<<~EyR8Hgl8-GariDJGw z936}M{nftwIq(dDX?P!HpY++(D%_V{A40NIq<70yZK2`^_1)kJJ12azg(1rvL^}M1 zQTuM|-@RYoN_IwV_e?*Gr`bjSTYYN) z1nDt&$A>D*f3SLT-Sif}y|=99m~(Cw*);P7XfbmJ^pb4d$I<7`n8uklzN%6$(p~m~ z-B+H1l;k?vdFq?OWnl#xTY~>raTlU|Cy~hiVrl)$ogykI6JYh<#jCBVqs(%0zC1vm zGn$l{{`Ylrb3XJPe-Ha@7QFQ^^pydmoamau1sRNcE+nFf`_EMK)PEKC<^9XnZc9NUNfG>s(0Qr%LA3 zIPZ~4O%E3l+;-PdlD{XYgK$MW5#o|7<7jBL*GEwNFI_SgxV!ZEBMc7gbt(*B`}eXQ z;*K zT^pP$^adR_@CMSr@*e;`i2pu>vv_;bHh)2XjrUNYpNUWksBySBuw#E6d1SsL%MHo0 zh065x%lL7{mN5|Td0inkH^dWhNCEb+twd@zzQ*ZW65w#!? z>GGwZSOSRmc=JZq^3ViaZl$;{Effgavite5OzTgW#gUiFtxa3B80~9GQ#(GODiao8 zG2e8mm`^9-9&Y#v(*U0|$H9f-kFswf_m9x&}^9SzK!hZ+HDwo~CEv$X!>Z z|G2b3wK?a?*XyI<4PW#4#>mM}-wiE1<-d2faf1y@Lcx=V%q@6>)#jYSNkrc0eU0$a zkYHMnA_1EBmez9_p9xfGfsHowC>0;GhGhq4ywUl}78?uu{{H!;{~erwm+0*MD7|~e zUi70cis^4(iQ(~Otf_n#$4JKSJQ(XT*1%9~4gSh)3=U$gB;GmMAs;61r>B=1I6-EF z#WUVWHwKY$r=hc-vy7%^4N&7_$DbbtsRfB4^#uc zE=-?W(2*nywAZf}9!YyjA&Mbm`vkSYa`1gY=<|6hdZ{cI+CfQz5q}u_)YL4i3)2N4 zc5Bm@4DVCp(RE;`WvK1xVrak<;a}`oS+@Kcz|G$A;f4EQoSdAG85SephOWp(8ZR^B zizn7@s~`286_uBdV^GzVe4d->2=qQ-r0C+}(*XmSl)x7Xp>A z4^=!lyUPwl5=Z#7E0-3g2jOZXHANEQ<&)C#IUnlo+k?cWB(*L#>Henu2p5fyJw|+TM1Iy+Kz|CprXLUD z1*`9#kw0n?aSstEAN)PtdOV(Z$zrL@{>K5pDTzqmJ zUjZG$tbHUuJ)1qBKmQJGfAd{(1s-G**D0I%fV&`NY$yDdFh20vSi;AFV5&iR&rQ6Z z_|)m^NG>LR37wj0gcfxxPhS5`R#rF|Qg7u=Bp3Qi@HcFc-;!SY^;XebhvH6fh;8z) z89vNmDqLNxAR#JSwI6f`OIk~rsl210t)+IYYjpXNUptR+O0Mu&C;rG~&?uGl^%aVh3|VEYPV+s!+XvyIzEv z{eWWx>!;hxAD+^MD6F9|Rv5#if6kJq;jk{n5fkop~CCoj+N0a;gv9O5vSM}oD$eAo@J=kK)Z5f?h~(0w{u z{0#|ZquwN7k!(&Fsm!Zdb&sP9`BBYHklMxY-(-ck!y{&DSqiQpAHNJCeCKaH-xzN(-C zgbyqSX;5{tX-5k0YK$52lnbY+@BzM5KbXM)SHv5uFdc`#=OZJZCl?Fru(j&?II`0g z8>@>en{rrbNqZm1+wyp3Rg*dJzx!}lSn#C*P$Md8@o}SV;Zq?OaBf#UM^AP`Nr4qv zXYR3lU9g&E4!yXw%2&AMBKB@TYmj9cz%C+U*9@ezOufI)SNT}Hl-)#};s}eDsyHPk za<3xr$5={@mrxWHBc7TNP={?}Ybiih$uX8uBJWzapt=&3r84nW*C_6Ga1IRpFyT<4 zfhF`vo=cO2Ez{w_z@!ZZIe>eog=TY*W2rX6jJddzMx%0vnxHrL9b{QeL8D!WfbIXpO7Y1GZ; za1zH4;oSN56WE#rptHgmKWR)Djz5%8TE2w~)!wwIF&zscVgKr6R_B`j_EZ=15Y>xb9rJPVy|ht6-W7`{&9V9 zhv3)0%EK0g8BaXW1FuWe>8+&&YikOZ@FDdK0>1~4j|!Kz#93^pA0YGNLrsg;b5MlX zCLO6upSb>?NGfOz@mvn#chxvI5>&V@xJ4ae8+nM>Pk;AzDg|p*=m9mQX3&;=*9fgS zzSkaqiS+pNiF?Ob7;@F_NKBP)Oc$yqES&JN+sdpgKgv6pN(YEkl5|eG(%vaCv#2D# zMh(I)^}}l0Cf|icJ{MCTgq+z%RI7>I+~@qMt~*UFZV!3&O&;wH%Y6@KVGj_`rh-Q4 zQ9q+hFA3tuBoA`Aze3T&X)>TUS*JH2S&U<@$8`DA`LHUrHcn@;qfooaGuqMKYPFp~ zQ#YaSC~E#)ZmZq-TQR!bwMmI*3nPyZv_+LVVq|S^wTm82ty!|F z@aOx`Mwt`a;NkpDJByv%wYpesy2q;heQN$cBv=RD^l5cpp3e}+;wC|J|A}U+1idGe z)S*=itrU%0q(NgR^zIgnDCJZb%C|(9S)v$_OxBv6R=KvBG#5n6|^X1AOr~R zRf`&ptQKXvaXXMjVytUR0I9p+-w(Iyh!G!db{*Y-VEWRKpv{0MUuL;vT^J+RE4PA= zPh&%!kPyK8-40kHEH=bvKGo6)zF2(^s@6!Biz`c$JYTR`8X4iu2iwyRqE?Dr3L*q?1Tb3IIahL3XVgD&H% zLG;1nmVeQfSxf>P8Vm7!ubDj4RFgS0=x#m6UOFO?6(B`nulVHzp)hI!3|jz*c@m5( z=WIr=a(fUWyXqH?dh(Cn^p7krPaBNR9@m8FNKREV z^=r)oCl*gR@Xgt>g@3T;VOq2#I|MG4()}2GTIX=K-d`vZBNjPLqbm&NDr5d_l4+kr zjsKR20TT+p|9X6onO6b7!I)5A|EtpfLCr8x;5|ceR*iq$&}(fa+miUx19ds zMOV0}$g%nAHOU5S*Y|DvNqJJ#%edDZnmbXOd_@=@%K1Atazey&dV3!HHw|4kzgdP% z>Dh*ooQzRld0`MxZB&T(O)w7>C&X$~#O5W^7k7VQ)rLul^Em0}zlOjX%&#mHRC8m) z-^=jPA^+QRLoTt`w+jr##UB=3&Hnlkmz&VysGe9Y5RpfKE$5cYMCmoss3lzC&=QT5 zC=OxP*yHBrj<^3)wD%-@^zr$^tAeRq)6vWYo|TCc#f`i@ziXQQuS|RGuyaI5;^Btt z;=SO2!zr7;c8g3>G@X0Vqihk)zZvh-t|TwDzPUG2r`N4WEq9NL?|hocv$xO(<@<`D z(uqpG8BHsGK8@b<(>pAE-`w2vq?K{%hC`ria#oDGt_C+3t%oqH`tnfU#~&opf$@m) zV)=|V4CXx8z0||TS<%w;zMi%i*GB~>UEdj$XfxqFh9t_tPe7Vtcp-ch{&5`E`H{GX zB9!`KQ{4jomm#_SR`ZJ+A<9y9zckv))zJhO@IpPi{VBbsz{!es_=B6wBZeNll$5EW z6$+y3a0tK8xQitZ`q?Dgd2JtSiZ5eXqz56TB7-_T-rBkuOB52y-g#ovU$)c!yI7O!)5m@ z!WzHs?zy+MrR0eHss`$EuBRp8EV1X)N?C?U2}l>fs7>oNo*XyBs9%*1HF33)limK% z5?%v~_}g@yM=2X@yQNB0PG20s!p|RtDS_BiTqI`<_^wjF>7!-HhES&F5)#GXo-sZn z#FhLSdYO*Pl-vApqxoNl$01#^>HEC{=Fmz>vF~WxX!b9-`=B9*oEZLtV*aHow~SWH zS#(}Ohmn4}?=ox42gxKlHZtJ~N6QPA2JVwYIa6J4^qa*p4+O+JB39<3pAnYI&z&Ip zS35}bhospldJ(6~RqknP&`i>G7P2+s%|2F@1A)~tWvKfjUrg0OnKy;0WSU1C8`Eez^R&L)|n zGbGtw=$1dYvc_^{V9MMQ+uIi7s7dK7HdmC=lw&oyoE;6qHPVB#DI&HE3_lc;9k%5A z!V#vcQZNucFS=6M=0UCJmiBG9Xyl&cU)?&e@~fz6ASlNX>DthX%P!pk|MoY$<;Xv@ z$7)}0m!j}u?<0q0fd9?sV*0qidr~Zk(k@?L{#IIEmDiAl%>*)O*{XmeRYKn_dp>Sj zqNLb4`Q+ST3*FA-mThMXg$be$2PSZHB7dA?lU^EDyHRm97YQIPbdAuXQNy}H$?0{5 zU|Q%RD|zzM`TjzvuHlz%a%AOR)mHj^Ccd~5F)2X_A6-bv*?k|F|%YdX~oDf>5fwlbS=>z+b8{5BH|$bau!$5g3Exjc#bpkC-noq1{@ ziHE5tuRMKzy6$eq(W8)f8va2}{v<#8#^0j!k2X2+g9kNqapLiZR7j|0n8LkN-`5>e zAu_Z01A1D({>|#>4?{i`6LkkFl4k6cXk{S+qpBq~&Qi)-+!e3}ZyosWYHr_ob0yQV zhu~n?6JDdP2mR{zU5iSC?&;zkA2Ra@J#w3UgqVRBe2q8jUuWpr)3~bD-H!gE7;n~M zLgADhwa+o;M74(S@DV&kTac0AKKdl}!g))>Sm?7T@7g-u9?^>DVt0CSRA$#D7}!5q zq6WIF@o?*C66DTQC`LT{ck^P3Bx$xuHNGJ583GF2|qloFL=x zO=o%ssdQvSc}I)RicqP1(QUDK7R|6w?{R0<^PCrzppp8$^rvs)j+#QMY74zp*nby) zs%{*-sb%Rj*!#g^jJ1)eaA)q3j-O$;?K63%Cz?6JJ!20(f84+>gp0 zMkCYmpXY!hv+EOhqw0RcbL!th(Kh2xFWtb#vAfG^NdDV|-@{j&MPt(+RlgCZ36!{h zkl>%0UN=VUcKS1&7CItWB4+rVVr)NNL5jq0%sb+`NHcTOw_Kv|WbFxJ>hH zw_qe{Uu*171=^D(#m-@XDZw2T96kW~aluUbA=|T)|6IgAzj!pl!D=o>Q0&38;d4({Nk3BGRjThe$ z+(6kzv0Y`>JyjMUOhLij$VL61Q3~?83Epd;19X<0%#kgjp@?rQDIc!>N`VhuTfHM3 zekAN=3n5yMW*sQ(o)_b?nGsap&3;I*s?hBrW_P;SOzNtSYFB*LHk_|-D)?iXwbu|^Ll`ED$S0#Of6kG+ zMPiCJwl$o=ca~!0Y^czbqpCrSj`Gl(=XeQPL3x6E2zpZzGnloC(+|^hz=ixb^S6cZ z*0dv1IIOr>sqY?trXPmNRSz)yV8f>iPVIObEUHXZXbO@O!QA=N9AiyO>rD;K9_bnP^Md`slykGGaia0RU;SzvJQPxe_J$Qsq8OYx>?mU#cKQ z2sKb^AwY7bR*94_c*=fD>u{pFLIMh5m`M_!ar|&mGyZyU@s%=IjF~QQ^5eVup&VJ* z@h5o=WVYgan5ANf7j5eD&R5ZJJ7?VfMJi1%Fea&(Tj2I@#ejsJb(M}&w_1)vavUS9 zA*(jsr{m2#+Cb17#VY@`cU%r#@GH#5nccqDqPnClYO0-YPYI=Z(@UuAdpY;=V*q!+ z<#Njd0ezfBL+-oyOFX$q&T>-|tnJmMQHFoR1l@iuRtU1n+YeVG_a~F(<~7~_;|+I| zYs664t>;9ovhcelFp zhs@y&4-E%;%46PM9fb1TP+ths?nG`WzA=ZWcrM@tG%a4(gSteR70;q60azG~)$`E{ z3maG55t{CN0BC^fset25L86Ldm^50di$}{bSE3r&QjB%bj>UP=VH$;_D&HoMoQq%I zvj=DtGh&HsH0p54`mW~U^sw~h-OgfVh9}$4u1JBZ%g}1)f4bqFNxcOyQeohV8MrQ) zWFdtAsobat{>+3WgJ9(mW$uAMD4E{qSej}X!Q%v-19yl#7CpH*s_ZSC>8Wc>~LoC5a64`lF;q# z=Uqqr)%{O-$vQbhF1@X`vbh0KU0$yV{AxYvujhreL;TjSMym8vjos=L2Fd2RFUcP5 z0gJftXcLuDG&Bt)#T}G$WcE^s3^M&?IF?Im8!5yW*6w_+@OzH!OByX^i>mzjG6Z{V zVC3t{3iqvdn$Jn1$%BzZF^1`|FP=vWJ=rW<@R;;^au98)mfXG66nLzJJC__02w;!n z8eMbkc(9v>0OVdtPpUrhBJDQ&fwLsI3>T-4E-rhdS;^2oTS7=v84IlEBZO+&IT$KE zySB6}^S9Z1(b_IJz{khM^&!75m_L=^kjTyn*{i+*F0P|^T;kY_a`yXqC?mh?B)DN4 zCivm^eGmz9fYUL19Q(`11l1^sv+k>xjrIDcNNgygUk0JXyQ< zf`c6}q|jC%qBXfXrP!$^WO=`7zglz1?`M*jIISzQ3k3D3MRExi)Jk37k6J>8+1~ewROX{8%G8gM6io z1#weidV3Qb|69{dy|FZVi(k8Yk{fbw&6c8|a!>8u^`%(i!|Hnv#OPOo*c*&!8=ql~ z2V=xuI({1aq{MW zK~ZPu33WPKw}4{FSibGXe0+GEmRBA?+Ba*wE3kR_!sN+ zaK3h@FKd;tGj_sVi>=}qR9(`6#IQ#M>;?*J+dtaI##3SJn)Y?#Ww5MdOGcS40~ z%5o3x1N63(esoi8J*NlFX5^B?nC0r#?tV?(Lb6l)do)4v?hqYwa5oMu`)V}x>lD#= zWZ^jn17nbl5r5oY-Q~{WR(_BY_C1ClR3*`JpbJv@gGgo@;2K?8r3#v>%YPpE(dT>K zkNK$OTZnQq{=!|C=bn-|t7;rsW^~Gh*JG*HDlw*D0XK_GU*6{BCjkuHaPYI}xIi@F z3f?c~04F9J{o1)RH4_%=yKr<=6Esgj6)i(E%C6P>teA?l?aWo!OI7zdgTTOrIdKOb zH#R@xfSuMesE!s0M-CF-QxE>u_UR3`N4UU5@8}f1yZ*y~R!yUoEk5gE-c!P1CHm1_ zo=4$VyLf!}heE^4WyHl4#lmyv>>_A@Ds7$*4@)Zhg>O0}>zQ(DH_%Hy)=K$Bg{-&{ z(GN(3s!^+C0(^?&$B!j?Gpg_(p;cDKPemFd=&AqSXdbWa5I;fX)g2hr~Oz|~u;Z)N41*ymLOzO6ruM()p+yH2yhJJmAm zG~@H!pzduQR*ZKh#Xf`n(S6z~3ECAX{sbWU0POCE;fc2pHlX1v<-U#w=n+!<)H*3i zh&hdOJQUiVcVU%iS=hsPKUl@gYj{J=ily1hgSq+Gg^5f{uX7E0_>kd`Nt#ITxvlW- zYi^JZTlDtAuK)b{ZYi(9+^ko%Y%tr}s|c+d_=+z&!93{jR_3nr`E+(hV3CQ{CA)!R zX&%+t`a3(j!%7?1 zsOJ>yb8Y8WLiVQ(+Ztsf4DvPuUJcwP`O;PfzN0Sl$nQ;yHCe&^u{pYePl~N8Bk~Ag zsSX;Y#H)gE+;g{lYfAB(X?e83iBf#JA5Iq|Sh(JgRG@OYfJVM)3GRx+W?~OV+(;X| z1XRAakf_M@L;~UL?aWNiw^QJ>H{GSOazbf%Sa|i&ulLiZPiZ>c`@Q~w1CC??xS{m! zR5qn&#kc4)zaV#}d^WYvzXcD{!!|)i_TcOXNd>~Je#W$+41o@)rvwRzpiB16Mpe>{cr{R5;` z#Jd`svvqUZ;&~8^gqA4CV_2(&?gT&g9ZxtT-?2J;SXEb z`G3&N#pt-VwzPa`QdXK7OUt!enz_4hwTG!@K@Q-*zQ4BOVbsWDb$Wb=!|lhf<63aX z4>YQ|!NZjdi$_5pK{#*aXIu|#T8lpTV=lP|5JNt(IvuvL*}`XknrtUZKd=S1eqh;D zMENK29X>mnw2v@5By+$Hn*)!zF>u^S*qkt2 zhqOQSbLA7@(V^~U&Q;%hwS?BVDQ!5L@Df{JaMIhS)K4-N1yJ5G99r`ALe==wem&4L zLOwInVKrX74EJMCLYPAEE+NsS(aV|@4Gn*DH775c<#Kbqf9v4&%2ef9>ygFi4z{Z6 z`0e!v=J`wUr5Uzk69h(aJJgARKiCy+UMNdbdC%?!gUVRZP}cLA?w08Nt+M({valSC0h z2uf%EP5AtyQy{=0T?wLTpzb;d z^`QyCAxgm=?PWLCRi-AO;&qVIbw(`WkBbnsDG;F=$B9uK0jSKuY6;{zM2KKn0&6aC zlBc^~Tq#CJ?z=mi*t%I4tHmI!SGh3Sf8m@5KyJ+zi^ai?QIfy_0E}Qu$3z~ls;3V5 zSmcee{@Jwv>4#j$!344w@Dlv|JaICpodJ{WZC_j{PJq`~79e&ZGccb5SpjVv2^Z~@0MvwwY7!tV8Bs{jVvXTkTtc$d}!H*@*cIe9U2NHP+V5+?xW*7pH z7a}l$_Fh0k{^c{Ekw6Y}4mOv;@)QUHo$FxdiSUKennfE^H3A)zaA$UYeR2Ng)h*z+ zh4%B(`nF)P;7N=JdnR)Nz}f9p!CxdY*z*J(1B1o+1X_VB!NdQa>)YxUAcc-U?m`1S zO#(3p!GFyF031J=2*C_fC9h-d#|y|$1T%j2XF$%t=2oCO1tYHJwM73^0?84>o&-eA zSz`{Tz=h)C`ux`R)Z+RUOo1?|PS3rHfyn&)+|)pCPbdrkLy4`aU=Yne060nj_63TI zpAU9cWCFA{D|wjjh`e}ejl>TvKt%>-056jn__A^ld}-yoiIWk+IM=Gzp)!jB~ZyiSFmS^_!n zk0Mu0ISJRi?j(WB#nq*mrTMv8x+@3%W|r1bgcQa{qmvZfp;V-Zn<9v0vIDK?P-t>8 zSlpZ2tyt*kJLEa|?MG@Pe+m=ArOcmI84LtJItc)W1b~_Lk#ngq5o6%}jp5`@EGAGg zs8&G4*_PLX^#rmd3prR1pc+S(K*ORKITJ2d!e@**Nnkj!z63Y)>TrShn_Jr1nT=-R zh;$|vf&nJF8|eUiNg&t}7>$NQsZcV%vt8_bC`tex;uo!7uW_2eNdg%_E&J*xAq3!% z2H<$SJ$`nzutHn(0EGmqGB9Uhvm1sbP<=lb!T1RR(F-i+U=w|62{hkP7t^>4%S(G< z&wt5*k7t*Yg>5qEuX^CinqS>3IGZMp@69{cE_X=+c$r{R*@Ck zdn1SS5cuNhb-+&{fVv?Z&pvRNth4>gb&JYv||!pO9d8jIXJSk-X%I!`_TydL}+38XB{Wr%|zfitiehFlEt2)}YTOCTf2&1mfjO-{nt zov)0$b_3(oheqKsJ2f|o%!}=8ATS-DSeO8Yu}FuefV(7u$rLGI;pX|o2b;Sc6o=HA za4Sq$6kLHgl_4_wR>>SJD-arlJAhhgc`exa7Oy7bg7ul(BidV4*(24MF0kT?YmXN2*?G9 za~seofHDBKmuby^%sL&6&{0m=kyXM?`f;KXMUu@P-DbIkv ziLtxJVpxzs1SZ1}#f?}`$7m#+LOM<*1A&qN4lY+5w%`K9#sx_HF@uH%42Mv_chR;N z0On6dz3m^IiMvP?R*h&?3BrDS`b7YE!5OMX%&Bxi*!IN`22fH+07?nuIQT@ESH-Y5 z6|RK+_39oF)pjfu8}EFKaT1V&QX_*$0CSgCLlD~dJ;p?@+4lZWjA8(Dx5 zeP{uyaxVjDY(L9MFbzWT`JvI^V*oIp^5U7KtC&9}d~8K!1lf!w06gn-CDFC>tYFy9byi2Kr+4$o3Bw7jt}v4P0qAx%kOGPSZyPj{st zfIX9&kuraWu@S=sh_wsV8b5rhwDQ+52q6GbgCGTmVTP$Zjnl8QAKK(YU^^B7yugGX zXJOU4Ub8R%pGu&aK!Q-361WoPLPQyNr!&EPayGTH(AfzBvt7-NE=gbn#RvTP(alse z*%KxHdNleD;WQ2f2=k|+&p3nweu6;J{*xcH0nvVu0K$v#*)`S)*hT?OQxHj?1n?9{ zpg99uw*pn3EK)!q$b~S+K_w9Q0YLhQO1MTEvz~cle+n2ZZ0r^@3uB!VYuN@r5vXP1 zzGyerBw#KQ%SDsPL39Ks`%3%`?jEWf{Ot9j08w8I`8J#pz&j0BGwm-YRqMk%4n~-Jfe1$lv|bxcaj{6k8iFiM4W@zH<) zuy@28KjUnr0-8%;Gk#nJBd(`rv?;91kn`;?sN)L)Z6r_~mH-NKw8r3r5 zGJr!P-nG-Y1~GcElmf_JKs|t`Fn%1$GQ^A^5vZF8D@JnaV<7ipXdr@++hU|Y`Xwi! zx;32+6+Ocno$2CIi4%a-mlNJFoe}>VUb^VLr zQ5-;|GFW6_3E;a7;Hb~LOz<&+jGvPLaur;*6(IH%h+vP^Pnr<~fFB|Cb_5EDQww{rwHiF1Hp9K0i3DZ)wbKCPS%Yq~rfzhKxxE0vEsxp1J_Jtb7%T5v)d_ z2Ek?`0`MsTUwkYXDwAo{2j}D^Q(-*^yx3OJxSR3Na7lh&TGdhJCdy zhC4Ax?BX1ViHMZI)1P;^s-8s?aAmoz^tgX3`hiWgej$UkzSi0JwE;B%LombCSsVbx$Bf_^rqf>u zJ`%tS=-S@HHf0C^v}pw@k}wmfmq0>}0hB!MXJH!%l*Sldmhem+8q&%f9Lr)~$MVJ) z2yANyfj+2UD3NUig3bO`FCo~{;Ocg}TDotmcQulQzq!eun+$FIuWu(k%lsKuK?~!R zat161!88h~*6*eSfS>o40StS+AJnP!kAns9EY?4c3+|+QQ+H}*i1E=b_5Dm6U*P4| zo&-H{oR)PA+B{-deohz~$nC)A#*}q14ACpgD{F$l6F@MV@`k1t3)$8t5ZGMn1%m!2 zm%Fi{$$eXkt1%r5yIl=Al*Q51{}@5ZpG5FR@-3Mh0Al(?r3Vp zpcLVD;0r2F?hhjXU`o@#6PZFGQrO#DozMD^iMtxRd|sh~%}q^im#fjA34=Wx2Xn~G zlc}};X*=fG3&5ZC1f0?yG$|0Z+HqKP{vHav1{lRL<0km_K|NK^8%|Bf-X*YK5P= z66O#SVfm(dU75fa>~*R$1-Dg4a&s`4-z{zdzU}Spb=n{}y*W}FPNguYGJ ztsP1G{k5*{Tn-snn~s3N@bdp_Y2b^e81;_&X{Io+L{ctxo zxc*nKTEBP(^H(SNt8~b~au0`OPbT3qfK~5~4~ggnfDK6i_!uc*u-^wWjJXJ50+j;D zW>oZG%B_qy9~O=T)<8#RwB6SH@~a_i%TflQG(gCf1VbSrH^#6x3j|(o+Yj8p9k_|W z%{%La-_A@v8f$i;%i@oYCUFe+B>hAW}wzxw7pOfH= z*IVgd%ke~QdQ=WV01P}1V**ttz|s_01PLIQ?{`uT zHfLZC8pvI1ZOu3B>?HtM0_}sRA!^^)i7qr@n4y30}Ynacpac_}eSS zkZApMH>5e&k=l(6Mv{YrAP@+SOss6Qd1*Q#0K^fnrQb{Q5md+7mAfX^HvSJK@$*-5 z{j0l!uH=t*VM+Sv*X7lyRUh>Z^~&#)wf)(SM+JbRFvFZ{GX+))peB6sP9_VWi1&m9 z@&#cNg62Ews#Qgovon=hGY1dn6viy0KbjlT+JAASGr>c3$5aec5JW>42*mj>m}7!k6oJ71_o<{rMcC2)@`v;7j%j)w#Dlmu^Bx0x<>&rv~%8 zyBdI!z+?sz=#>ak72;~ffY{|qqd)Cx>T-AA_Mh&`ym4+FAOFB#r5Qjqe;Rjeh#>_; zcr*aNL42A1_2oaqVh}Rv^X`6S`8umg0C@>Qs4Ko!c??gJI8BtRMU_gZ`G#Zn8T{{ILW^mg7DI5NzsdY=Kd~ z)^(dJM*V3rD*t#FkHfhDkq#J51hrWRf>Q8n;{fof@4iR;{_D#>XZs$* zASC4T!VF_R|DEP+M8!vhSpX4lhX|fKmf%oZB%g zqk=yM5dFSXVz>ng$P$Q1hU4H53=)Fw+O8H@SN+*AI@2w{PeU-Y@gMZc-1R0_@EQ$( zIuTS0plJNbj~s-^SzOhJKfXf%{_^v0h(BKbB3nr3$)`g#BGyHv>D7sOA{DrYdN#gYzriTdIlK?Tc=`dZuB>g zs;i=~RNRShJwXDeMQD<~gaCrSWa)zgo>kX0fj=;q2Lg#;Y5h%ENKWG{cv(tv^4R+N zT%ZUDl~U7k0=56Ihl z?I3{3knq6)f9H#tA`vJ}U?>t1pW4QogtBvT6CC5~8#5Ae6nq)UaK#RE5RUpE*}x$y z&*t$@@^^^7^&Ssi>kJHDbuW#;j*41W!aQ`&9su}zkE05`2w7PhKQZHHe*)xsa$x{j zK`>;aS&xg5k3NfSS=L^dHt-l*{_M|z>`vwO)Z?Sd_6mj3`Xzrc`t}3p5X}ci^I{}%1SEiWCTKo5 zH@76ufo7Of?N5M*hf6`LZd9ECw97^xZA=A5EXe8clU^8)btK%!bscJhbd0=7DYQH= z$Y9y*2TcHG8vz7?Y57=>n(7OgKyRMcshPl{JWow!;z2-g?2K}p22r8Q48=&1KT!TzCgHN z9{)7^r@36a0x+evQ8My~Rub2i#6e19A_;9@uypjMW2ZdB%sXs3uOdy^-$cqOA7UiMa zsIqogsG&w7m;eMXbL6C@VtOPO&HIQe4FatYJT-h6qC=2f z33g7ncLW+1#;68WO-RRtWH5qPWYYi!CChO|{#3ta0ZCwui3khIyNV%ynUjBStCMM5`+xQ^A-=n*`^{N?T1%X-alpE3_79Frv z2y$IS1ITPNh_jpyVNZg=2jgR$53xE15;ejYrYOgf2wpCYQK3kX!1#C?!6QOOa9B3^ zBZ4VN1$xd65J1Seq!W9%oj#-l^~vU|G%M)HOv}o0V65Gpd>IU&`rs_+yXv6*rt(E5 z_uD4`V`B6O>MsX5<*fGh57({Ru%WA~tGPrB0pBVC%yB#5gnN2$5bWNGW*CbQ#A+cg zzRfv6b2eHnjkZij$6AbrkUvfd+iYwLyTg(#`eq9lCoP~hDHo^GF@xfnkf9MwJdXfU z8Tl}DC8?DJ(CftQtu0<>tzTpW4JnqP7YYKClV@Dw#*(+OxiUGZ2uw;kcfXC#CBBU4 z*O>jI?C>*rU}fxr?Ju|s1U^?&NkPEpay5uRq^2&}0l9bX+_`7>*6!>N2q4>`ZHmB( zK73&kgtEv+_YYT$qX|6A+!YvpV_<}1pw8dY9UJH~%*jEzgmj!d1qO}aiTDo^<|KHL zTb%%507J@2DQu}6q1pnK0ele`ZPpY3KvYB)(+Ly7@%801&ckXCb?=Bm5N>#X0r>cn zWrK9Zk23bXNZIpe*2$^|OJDn_$tz6>G>s~CimA9elQR+ofjN$z9Xoc^8vwLS0n+`- zZ}4ZiVwu@!tNA3ze18||qv8y^I!+5e3}V%rMFPWT1m%r#sxo{C>T>a>aCk@-08}D@ z8HGzc%?M;4-~I#e!9{1w&8Ot= zZdQHeS6}#J7KSc3o%|L0e7-Auh&GrunnnPhmB51jWBgV@XK_&&&N$^UBz*J8D z%<9VOd{?L?#yM#q^yd9e4+w`Nr%Ed&JKD8<#j?Jm`zv|RyeV@%y5&KhFl9N?oh2SA&1Ta__8JZomfB4V@ zj+nZ>D1UG$6XPs_1zfd!$%qkNTxad~GZNMl2v@8jh=sqJBOdJC_eHtHrL%#b-u+&? z*WwKcVjtb_@gWeE2R@^@PzkBOrqWSp3ePxq!eHd>bGo~m9PUQUX1aT}I-dN*IFCgB z4zoG{;#;@W%sLxw#UFR38ZNFOe!5G%jLoTxIc)C|ISQf;G9|IQ{=#v&k4hQqM< z7t#Y$Id!c?7!C;~qYIie6F9TQr8(hOBYB7`4G7*@u+MvAo>~yk(g%B=C3Ez@y)FPB z5^}&Na~lF$@M|P|+T+TvYOHFyb3#LEPI4OTUU!DOs?jkj$K5l2^N~6L?Hw>TQeLO4 zQ(0%DZQ26mkT4OX2p@xCM#5(b(1@C8PCcvzB!P6pyj)$TV(5lp1R0FrOt|R5A=$~a zfy~6*(Y(S_5lTh}ni|9*aP=c2aetxg&Wd=lODqS1S;`S-4z8#=Xy2-QnN0l|dJlr? zQ~J;Bf!U6aL3AlaVM8F`s+cfw`1PXvxMSPe|9&!9;Fn)2CAYkKnwzSJd7X!Z#@!t8x=UBk>eW!E1J04cj=Sdu^`_RCp| zn>zYiv}JK+rL_UX!)ze>Vwm!?%1kp1#)n=w%;GE+4dapgDI-WL7=QO7oC3`VCf)CU zb$p7eND5c@m#h}2z|~7;HV6Wp{s^mySffK$IxREl+ZeG)v3te2qW(hTR~|)q!{=hB2C=coKx-`z>~a0WfDo z^9$>JsvhSr@OsnXfm>G(Ty;l5ffESgxNL)<*ma(9(Lwu2Ouazydq*D@^jZ5~x?E7l zAUJ3N=ifVST*v&Y&+nKU6bPJr{`G0y05G%uGwg4v3gkVD5{pr<}sV;$^ z&+93anHX^nTt31D0Dbn2_d~C|a-UY=hOIn|u`g?Hy+ki2g6rin1opznUcl&g=0M|8 zGA{3!BP9U5cmCw_le)Kdcc%eBsmgwlB9Mp~7zv<_1Gdjb6F@p)<$^6nLMb%)O0!u2ro*GZVsMdi}OKK1a z2Lis^4muEG-R(FY{`Ss2_s;U2xNcqh8K;~a6(g?^S1<+}0JRQ4{J=9H^tGUF*BgQ{NuY^8E^m_^f;uB8q-``Gn-CHS991=HaaC$TWpm4Om9T)G>cy)i zfU8<7e6>E8&vyXttr_;(Y=j#UKMLNrw*bbhtE=;RN@|+ZyEboF_hI|#(I;yMKo{fWaL*z0ZUX-wDocR_V^b$8cjipr_SWw1-QO(~1n%DP_Den8KkxpsTLNfa3~k>81Nu}N z&|;gdeKwjhM8aH@Kp7xpRfJ=5NA-3~g1InE{>-s3K5+uUP(}tuP}87;Uv(r9K@#Zl z2kMqAzh_d7(}!l%yrC_c30!g?_;UrE2jl|RA&)QUR~AqFxb~_Kv2RrRT@^Rpbz@L> zNgAPwhLS>{c$O;vN{>u0{PpzdQ+{}9_?Nq20tJF-LWQwV57fif3At9?`7yg{4l1`7sYhl1^%ew8;hX{6U5II zch$8h0Hftvaw$0r_f2Q#_8r5EcfIwz(py0BX9?ga48xH?*^6)OfVmb%bfMZh6Z=ga z=3EREEAf+`UKK=w$3Qy(>2|53EYdNsf$Bmf!5|&+A$JebJlJ}-9zYqAq*eI?-qo;p zt+!w-M3}%C)yqXiZqj`XE}tLS!2LXY`P8z*L7%biAyY4X1aO@WzX>-^0DX8TPb{7| zapL6n_GjM;2?YMUYwOmeoa7<`*vlFmx4|4|(S>T0ji$^M zk$!4xD(MT4An%dmh1R`@K*BXhLlFoJ4J)akbv=S*M`%u%l53-L{ejY%hMyvVq29AV)-K1KFR7Z)^O_m0=@*apHw5 z#$!Sdc4Vc2zznAoJIZ~oPH~CNq?TK}9;a8Bz!}R?k$basM8ja0s0r`uNSMBN5MA#- z_P&ZgWLwfDfW;fyKm20D)RF1;-3snt0O5b}y#HyGq)*vB&IF$r`_A+!KLWso;u!ez zuRs2{ZQJxYb2}yvzbs*WY@eE7csW1EK>B$1_66M3VZXG}fS-Jg4bqxGq4_RDOR+YB zkB&fnLPq8Txk*|y?1vK5Wz$4D+Tmb0>_;1*&*d)>*}u;F3kv|CG=VGNf?*S{sT+)@ zn0*c)eXl9Xfl07s(0J2J#@N z03?AN(XZ_!jTAMF2>xPs$Rne{ZKyocoIklY(pR(dSTNN%eWwd@WKi|A_#3N z+Onx%r9;3cOkpG9+@_`?pG^MUG2)g2*_#jCu6ktEn@c+HNG~}z;740rfi@xu?><_IA-2q(Pu-mchwjDhXv;MfbZ^q;#190Sps8V0}!$3Tm19XtrK zF@Xg3+5uAxhRqjTYw5rfw2?U~F6yMa@804FDyPyHNXhj1n`JBN>iY}4K@<2Yn)5sF znR;JcQ!S=*2Jf3dWXMx%o;++2>7(P7Ezj5X(%mZj@HwM8&+iN1qS{&?Zpv$(pbipM3`;wxOduU`|Jy>F zI{dY0sxPf<#W>tdcwuP*y|tod`P`}{oipzrSy)hjgmhpOgSh|D^v#}jHxGUk_!E(D z@9>lCDFkCSHHiO#@IV-+MBXKog;CPzn?G^>{ClEQ6e6AQiMSRWl{oMx^&*&)^hpBc zA#ilj#w}a66m5*Pt^u2GhO-$1v9Ag97--cOb5=C@)B9o5K~PrZlp`jAwp(JfU&9h; ziiP9UPr-I+Dhdcz%7%mlsI+uc6BcDvx5EC(p-f={OJxe6b$RE^)~S~^fIol0w<7Vs zHhukrrq68uVElStaib%SifQ+VTAPqezuY@_?OM^zGJigjFIo~1QUtb~s5Qta0HNp; zPmNV-e;D}9gWAM9_u{1drYaX|c}ebixt`ZL0;VjwQ2Tfc470Ata1cBt9D#(;fFPbY z)g2gtKhun+q%SOira&YvSQe2GY8H?>;e${d+*k|wJRbO9q1E@n1bRPM^LTmUK^Ni> z*m}oE%*Ph^np{P1{I{mB42>>?4=-&Wi*CIHP&~oFS9FwibgW$~0|%VmG|A4swLtOS z`5^?~?b@(m9S(#sila#+#HmpR;MN7V-)5o^HsvKRyD zLt~aW2AZ!DjBH!=|;uC?|Rc|0!; zlxnT><~a|)JHG$yWiJ<1Pwkx4(p>0;7sk4d^8d#4y|zGS`UHE-I1BYH?V`Kw114=M z$dl+nLH?!)2}cS6Mr9V)PQ0-KjR~SjCcR6*h@CJ;_d674z)y}m@{_K2X3eAKTzv8H zJ4Y8!o}7x<4Y?Sb3kmsU(S_Q_k{q&uHnDK1B!{mEZkVIO9F$rII`J>^*0buTJTP!kw1`_|Z17>UB0E-K3 z4DHcI0lu&nFqD=ynLn0A%oCcF0un~hib3U!bvCewKG^4z06w=|Bk-O&pWo%#tU+?x zz*U{C_m6DwmBIq%_0b7)bmXa}YZg6u#L@9(WnoW!1Td2CjWYiMkTcjoagie!{7`eg zdlM+7;imwcCDazOhvOa-MG|oxIwU}>@zgZG5Bz@Iw*A|0xBv9;thVCe!-tO^Jvu2l zEi-rQbDNH|FTtrXB;1|x@3k^`4D4^!n~GdaPrGJG3_?I6B~esi8Oqstf(~ z0OB>3H8BEHh!5gmcQ6V`b5#OBPo8I{^1%fN zgk5<@gA?fFB`qx@FRf`nL8P+iFnKhD;-?)w;JDe6#g|`xZOxjM%a$!#v|z!LPtKlw z+z|&4=zmzhY4H+0#U5pFx)MNP`%a2P95GX0+B~I>NNN5w4;+JE+SlrifW*&4ANU(1 z`I7`bD)oxg10C?l?~fVt!1iCidmj3J%)D`vheIg|IrkQ4yX$Ai+V)`KSB`=FTfKY) zOLCFcrbapjo*ZF)B^i{r3R4yx`uGq&;s-@;(E@dXKsm4^`2&ET22X*Ga{P!I<~DVjZ1>yRp==e6PwkD8zq z9}e|uaB8+sBM|QwdIF6IYWx}P-=pMD=;^u0<~(rom~Fp4FZAqhW2TJjSlh97v?Dts z)8Tfn7mM-{fLMuRApc&8%d}&lF%&j3D049kg{_x%=q`3W(-uIIM*{grG=gRVL$`aW z2R{G9dU<&q2+XbWV?Q(q{NkKdVl_ec1wF9T$y4Xr89)<$0Fd}S@yZiVJT3mxF*pDA z(eqNj+&pH^T#?2}NzQOXjt>S5IQY0D4v*-A`#J_UU^yXbE;O4lG8m$aVukcXrl!x4{Jq05_g6P&T68-%wLhgwha@88Uwd1zc!6vkZ=&7?6k8O^kgPE_2r#s!3L25R zq1LU&z19tBD{g4SwW3j6aG^m_Tv9cy20}$N0gHnEA@wi8dd|J)%{O=M@}@Gxd~ep* zQN;fE&OPUzd+%Y($JFkSvHYRF=hlG*>vJa?@&d`7#E!zAd}QhN7j1z6FfsAjYF7+a z{D}b~{IY z#u_C(40N>UJGbs-0Bskv27*Y2QyK8*b3pLMH2SL2%$}oBP|{(|VS2o`=fbuwI`pD# zc{ZP80*9vdW4*iLp3eT>?9dJZOnSxOAXTH6=d0CM=E@4I1?IfeY4rHTZGE5(7b$ux zlIBp}7`NgSSoFiQJ_mI9-x|aHA`PpgWO#VX2G&Q<4G%wh@71IeT3QIeu1k(qBbcw9 z?8yMIvJ>d4eQhD#n)* zLW;s&`_Ief@*IPEzt^L{hZiqdN{fWQJBxuKTs9p!M4kHMyJ$;<_K|0t4%A_nKZx~d zg6#1g!mSMOz?@(wC1>q{Vt>ej`5Qgy?lKT4RcO-yS?2Ih@S_U^Qe+<9{Ks>ezztir z{PhR@dgX#HQfvEHci(*T`Dgaj`T#r!6_^SR18ZHwKnTM?iNO3C-o?B)9j=SrCP9SF zdu;-RJ?pf8Wqc|X)}0EQdfWK;u9lX=J97hE3d{{=^pfCZ_uSH%&vx%swfk9jP%3-Y z`ify^397%ljpo)ekKrmG+Q!5e|G46ahTSPcUM(;*lJ`Uir* z&zcrNYm7ex3v%_q&xN1>Xj>o`W1V(^`dZ|x)t98a*F z>_`;e!3i$QhguPI(H+pI(UVe(KqX?79O1VLpS9H$91kGw2mr+b>7m6ujyzjLpiF|W z0U`ko^S2*ePtyH0f4%bJE7}5IByCtj*j?I20CsHMG&MQu&x3JLOPng-pue{jf&*vS z7^{`gpc>4BB_sn#i$KqY(!7b5)TWaON<+9UtEH=bQBG+z5n14Mw~#9iRmnCg3?We% zVW==e+cG}MpUMZ*^iaD)J+I($xG&7{6Bit?N+gjM0Lm~>-{B_>0|h@L0<%p1fS(fZ zqyH06((TQ>HeC_;%8T^zkKr}LEB9~fYW-wxEy>Ft65xst14C&DCWXTeXj>pZR5ixi zqOY(6fL;d-E*x%@HyvDBo(Xfde83aREq3VeqHmQ{vBj zZ^Mhq0>uNLV^+Ux?P~on*oDyicI6i_ph15pTi!$cT_)(68cxZqT>KCaRT&R|b(suS ziC}PaTIulC!?NUmeS@PZZGj)tF1@!b>>Ha*;;$EQB*ohwEJ$m z5uwp2kA-~I7|sb)NG|Ya@~84iGsw4;=SBFXkn!Tbs=UL4FkXNPA~N3_JANp!Z~l z6kjS7(_S45+S+J7m}780*V}Q;L}IHd1>Q-C@IZg|iV6G;)BCooxMBhxk;($l!fk-d zk3``f3=lPElR)zh1>w~$2LyoN53v{M!3KZ(>b+k#Y}r7MpTD~M*A1KL*2)$yfQmn? zf~&NinD5y#%Z$MWfeHgsuin+Kj7!( zZ&=WSc)RRPoA{6JefGC4Teuwf$LiI;ZTW+L`S?*UIrqcY3eUh*s6ju>V37hJi>pI@2oOVU%&*l~_9Ze=I zm5U(+C&OSaA-|Czt6%7DSd{7lKqrCl>cpJSoG};xGJcWkfj-VaK7M!8#7HBxyC?!z zK0Ir3^3Om2{AcpvZzp&7fGafZFZq{Z;xMkt3+puP033go9E1)n=e%6Tb7T@d@D-Km z*g)*2Fx}TB7C4V%@BkfFvoTq>ZXIpQ@$mgU1YqC7@nn?^AHBziUV|kg!R5>PIbZQg zPy;ytwEiJQ5FYc+!BwIKE8L(j zQ*xygmS`3%&|0jiUp)muhA{0~6Nl|mF%bs_;}En>F!NosR}9DCK5DzacDx}`_wKtN zth(>W-Yhl6^yl|#NVtXU#1O)+<0eRR6AUl|_-O*&Zd?Lo)3QP#8m)lDpH6;Fts(vR z!CGREFe};ubz~E1q_7iZ^-kN@j!u!9z5Im};L0`VlS}~B*^sQW zsxbO;J75?=p9AiBqr~9T52n+(s1r5YJwmf3n?7Fk?S0oA*^|q5_I>Y&!KTgF6nZei z$d6rA!mdqBRfWr;?JH=)qAP`~Ll)Y#jtLB72lR|tc-vNQ3 zrOd`C1|NRFK+j+YHJuL4PA2x#|3`c&z2(aM!8KY#68MINKqSlvjXH^|%ynhA_Qh=Ds0dFS zuH0CY8ADY0WbzW0l<+42!UHWE^@5KM*S%+n8XtI|F_L1?az3FC00yRhkpNGFURNht zqdOHUC2%ii6b5~PMpYvfvGFUPzzki2V6r3-XTA^%TyPAZ)PE~&%y9)Z3})7@UG?yN zcipq(fNWkyf$Tya{GfvW=}$@?0?Vb72Ywg5 zp8f>b9tN_k{i8#T33d<;?%qME2=D`eVF#@DL5R{cEO7mWEqbrsckmo2{Tda}U$C~xwX!q(Gq5DJH|NjoCj8AQSW5cH9lKUkp1nu-e^ z3MW?$hhGnYRs>e~33~9qPyh}tLa#i5Tn!8+A~u1r!1(Tjs6{XX3oI`L*SnJt4Io0X zC^Ip3ExYR^vcPP9y;??+#Mr+K6#xW(h8Y$GaI1bS z+@=aLn8O$F4Ha!v)D*qRAv6+x!QU*{pY4Cxh8G1u$U6f(eL9wAM39JpKK=9(s9%b$ zFHn(bPzc4C}`%y>2&?N%i3teiS~of?HlOn$!2pST&2G6z7@-8PlW3a zTa+aLqa(3d<>QC%a)B{h0qGkS*;S7&JiU50V4<#Q^8*VcH)EbfAa9pS;-(VR6-#02mFM`vcs>YPdT77Dc-uct!2T{ht9iX z-qKw6m7V>A>h9zU0&wv?w;s@+9VPHm%%AvR7{Rc-03!Q<`Pl$6r-8X(EZ>9&M&TFP zF9-}fAmCH^F94jL7W{YusT?>T%j8T7Jp|VHmdl|(!8JD`1jY>k8L)a+5GLs$>?q8n zf4IDD|3elWeayk_m*sLDiw3@@yVR?&Y}qyUFTSGZmW8?5Nl8M4!LlQ<+vmInh>@Qk zK>9(918M>xK~Tak;_s}$2v8dz9FX~&F$pkqF%#HqM`0wxR&!Q(Fxcb~QV|%SgFh#J zndoAfNr?`{nXyybxDa?~`@su)bJUWRyK7xNr6E7ufB#*JujuGJ^m`rxa{C}>11b!I zzzPfWIRTVaZzh1Qa$j)*qBCHGU)um}|5MDt0n6g=!|kR|XisGV0ieG_rdm7@vf#sI zi&w2f!FFieG#$`1z+x)XFQ3}Nd&JDU48-CeVu?1P??s zc6DcYT+nKd@z((B0~UyzOr-$0?{}y2;v|Z}2hc=FPj7E`?_ukykMNrzTB|<3F~QtP zFTeVzvJuoqFf6P9d~N{E78u!y!wX;&wyl)-8--uUNI=8_8;#@_1jh1Lix&K8*CQbh zG*=6QK+iA`ds*65L9Pg9{0N4X(JItqNJeHdbyX|Rg&f*`$;ng(9H6t34jUr?CpL~w z(9$JrvBd;x@UZhf(Wrq;5kvnK#8ixA2uhL!+h2xe_~bNxu|H(Lo&Zeb*#-#ys^k3+ zYtW}(Pd`>0BoLYf0@0h`5*idtaW-tu4VOh9u8h#9-w~XR7@ViMkiP!zrX;rZXaIgm z$wkz~cM@cfnn%QCC@@%Tfc3Gl1_ig($l+h=pYn(d(G`A;3>fBTB){l?3Gi1tZ3%FA z=fL^lVIbNAHG^0!=b0Y%#bTIXxdp%iWa?Q;rKo=V6z+`HF*u*=$YlreqsFERMFJ4) zf%KvrN`NGpaWyXP7S;zCvWkGr_lf(dla zh*%y705LNR1j|W^lh%sN_(@gm(Cihm z=ZM%w=x9*2!Y~wEH;0-aj%|fP>yktzv?2)aI~K*@gO}$QpM7O-e=hfp(Z~-L#RQo= z%@^|{l<-#;|B>vBPDrNcsd&`n32X+@rC#S51Zs)}K%p_=7z4^7gjKsavR_la2I%Ut z`8%iirD;uo!vwklF9f-g=R)-)MbhpV^E|@CU0o4-#8b(+k@QdSGsVC<2!_U2TAirp7&k~8jX$_r7 z;Rg&I0+uF#KoGDM;}|3m254KLZGecr+n9lf{n4Xb{-^*ZhG-Q2UA>@L$rYVUI0{=NnnftrkxJR4ym^w<%S7k z`H>p@?GYOMK^B19#9#BSZ)`WybPyIQd1N#Q0>K(U2n5|6VxG80Dl(2^<1}DK!Nu31 z1#Rs)iUxsPjJ`IX>sfL@-vkWLG{B$*GG(9*H|b8Qgf~|q8mO6;Xb@PuO|RjAa~MB* zmIq*?djhUHmcLr!@3AMgn|abpog>9vE}B=3XEW68b_JiDV@#yCMVsL$b?!g&mN8F!AQ8 z(S$HHMlU4;v?Fj?;TK7V0ia=cApl{44g!tD!*Iae_Ga|JpKXFBhIRo))LaICPt8pJ zgutD_UwBr827)L``)5bOLL9cd57j<<1SkT4m=Na1=%I-^#T`Phe1-%1@KgS0I3Vsg zW3Z>@=L*5V59`}?_x+wso&wN^Uu+9PwYuW@tDXDQBiqeXo%VuwAj=hogEi?IlS0T@ zpN1S$BbkUdUs=%Pk2F!#l$9GV&N)P}B|xyouSh}=gi~aMAk8qu&?pe%9N6Axp=r)u zM4rG$6uMfl_$|B$3GJt0HNSB2b~9HOgFt}iqi~0aKzCVYTf{92pbY%U7ds0!3FKQ* zQx;ob_ejQ`2#HTXO3Qv$VqyRX11vuV^PV{v1j^t~@Y5#fE5n30BGQ^$PW!M#4Yc2zEt zPWNBcySnQ*@_hVjueI0SXV_*OjlZ!2!lb-10Fje}0AK)R$KT^>eu@2=1EvM&y_)^C z!e7&@So>8?-$^P0f%$?A+B;DM_4W#bI*vvV0?Cf^4*2BJFK@pp|BU+Uue|=H(T4+K zf#viU+AQK8F$D}4f=!u*z_93-EXBzL5d2aEQZ}m ze<(bdL(0(m^BMdl5qO%(K-?It;a62dusyho?JvGwyI4*HcSj>gtsH=F#JmV&JDo|oQgM(YuP0QkX5)y zS@`6G0c!Z|(DMMitL;}+{d91qtqG=5O;83>kzy|Lc=XE{vs=YJcczQ9tm1)*0djoX z(Xaph{x&<0gWI4lSmGdrF^uED$G7XpCLXjKfMzz#~-|RCGXh53^XLD{Rs%5YIZHf%Py-``HN0yIY9H)?BrMtv|FfEVv_ z{26eG17dc{S2hV@|A!R<3+uiHV9@{{>oQ@jInMwvF+UlGZG!_p!mlytc;Mttw_nwD zez40#utZ_7ETZ^9Ab}Y0ut6OW5b`1@Pt%N_%mM2T2z;a)!K)*C*ZZH0%T9#sIWS6~ z-o@6xKK24&A@wi^B>blFHv~{R?*Y(J(>}QC?JxBc1epj@6HFDGARsD#PUotlWQio@ zFF!9N;>NBR09bH9=76*!^olGn49Zao@|YzaC|`S-X3KwG2ZD5-QhjkWIasaobE!Fj zi6FU1S`l0fTb2$Rg@oaeA~1&xM#{pM)}bvcM%MRdff)Y1l_4hpB3jh| zJdyaTcj|z_j4w36qsRbLBA4_9kyr0h?WG{E{iRj%>gAXA?BU(-{u|z$&|D(0+_1ke z!fB77lBWYrgJVT{Y>~TTkjR>+PsvN#p|jBzDSBp%MBt9TzE6;h0|Ybs(w{TygFhY2 z0)Y<)%*d(#J!QeSsfFLKS8esJ(wxu`gT2k`dhaZgk$vK^-dPhou6bZwvf0hcx{t8K573HT)Wae)#=yg{v>Qnl$Y( z?((ompO%b>0FR6$hLGP*#l|j`Mr30!51`0Q5M}#bLa75nph6c|;*ozGG}afYwDwVO zzBmC%T3<|l_)a9wBS9Lk;)E9mgg_QStoOxsW%43_9<-EO^47<~M<5qI8Bm)1f}j_E zSG)K;yg5NeM_W_yi>iZtb%r91y}pViLLprTkppHnWa5A@x$OPZhtd`%!SHvf-wNxC z=p*$gSD^j%4{Fb0^sV(?ov+5;kJ0*O(VRd%P@iS7JRCqs*PuOYTzmu3v62%O#L>Vl zc0fWfZ7-3IJD$&HbNS2)@i%XP{uB^xgWu$`_cfmV-PxmG zuCGtECe6VK{lg)UicD0VwkMbU88;eDVTQl{QW1pXB21ESSQ28E2z!d&iE=08sb+(c z5AbtNKu97mlRsO2^AK!Tf9-8Za{}@$&@sV%=7iL%ks0Kbq!Bp=5Gr9V`e`>gu=-$h zG}|*zfF9ssfN}lBLU3*O2LOf+$m9?IGY`DSYj1u{`hc4Bu=QaoGDzVv(}PL0;IX*R z={Q@-F=*|(dSN&*G$ca@oNNKZbIry^NpA zS$n%z#lNh+VZ)r0kFZEgMdkz$1K7%MLC~%~O@_7Z378?_nw|)UYF?BEA?+x$Ho{79 zC-px=PuNQ}b2QB&SnxiPJxM?G!Q{{2bNFkU`?B^{Z%$|<2eiB7gokta3k1L+5YkIy zH{=EBGbywUa-~}2%0gp++{^zI?h54DkZNk<(C&K@ z4-`%Q$bXwPS$p%t@DFFZ%_NbRjmYw_A~=S;aNJ*F2b%l|IW(o71i-51d~-D`2<|-l zoBG9P98fENz~`5r+ui%N_U7uX^RLeyZ3h8}_T~aXZY~HSHljYX@`>5xQ4sJQ+vP*3 zmxG|6d`w9|W`600zCL|;kc-1!({zs@257c) zj4j#^*A#u|_-fpNoCBXdJyCMb3-3dFasNyG^tGV!*WGRca-Y}Uh9n~?f|Y))?b*_Y zdh_FwchEUTz@~S^+`!47pZSVB2c+{pUHEV|A0-Uh35dRdGj6NzaHu+=zXIClt{rY^J zihm=L5nRwiP^vPX3@IbgzzR6SHwoQ?1$*|zRo47KklxcE`(+MTI(-)ew+Oz013J<7 z;rbO{2Cj@CKPAp12v_7KJ#IWmG>E$OSj%K@=+tk<A(u~)=r#y*w)`- z^ljE?>Fo+DBhV+KUzN!Vl0cBRK5uo`tEaxm&pZ|gf?vGve8}E&$vooN%a_vRZ{3zY z>d8p5GNR~p<5dw%@dG{vSH1(vJ;?dymCQGL;+i?2!ymCX5tZkB+x<_Q$zLTIArHX> zoxc(EW=0V$OQEhuaPQ1E#u`<%OlN3Vpzz)J&M-cduY(wta#194w{x|I( zpf|<=``An30evQC^7r#v_8p8&MjR7lC#O@DxjuiG9$~j80nHxOMK^ELr3|FMxxGik z-bDCcQz$&I!={hxCnH)EZ1u_l=pnd0RaCBlp5X_2raFIEpI>sG_UT970ryN(=irTO z;=!}j^|HGuJP%ApPBFpOFHP4OB%Br$fkvb?*Ry=L&L0jax1*_}Lf1VzKlUj46%Z^= z{(js{{zfJvdZDy&)nyQ<_#V&b029X}%iu9UGrJu0NPaoe0|mkfO(k#IxBS(Y1!Ii)dJ*t=3vm?_Sa~Uc-X^%Pd*r6dfLvt0)lo5|mJeR}?_AQ)vMPF9kjI8h&*g`nvpD(Mk)Peu}l`3;me`P;(*BmPvf z;K9t#bHH5J!KUzhMZGTxF(&9raI;6mpf+vIF+V^9Kp#B6*39IuHuGcJmx}NkvuD<~ zH@w4KYZq1@Ov^_Lg^43orTeUex;9#^BSjAhSIi zQM2hfxG{bAP*Y+&uqNz~bAV-h>Y5BHdIn(PfIfh#l)ZB48S5*!A3Hwm?u#$rx4S7k zqnoM9SW4_+NE^TSV@xgWaus!S`= z=irdc!ABO9T?#4jEk>U>A3@BozUHRZhngYq1>jlT<>0|))$PVbRmL*nQ7+gamn1Ak zVQhCnn$dmbJjvnk5qY!t#WO$M(>pQl>`@nS7xQbsJ>RT)+@fCl`Nvbmj};R6TuXRi zaCE(lhTnKiIpI00DC%x0d&Jz1y>jxwZu11cRuYe&Hw}NcCaN+LiD}n@#_qa`@Iw9I zb>zB3ZBIvk*^QdZp`xdDjzG|Xudw8+iN_}Iy>(HQvD&mD7!lcCjY1(~K_DS`gpyiI zsqj)j&nho9k@NWA{%=v0*_AF|s)F($AcN#pBl6bvp>EIc>@L~f;ietiTcDT9U+wT4PMh}Ii%rq* ze-!`P!$H$7HQ{UUlD5|Vv7AaW+JiJvLa*EXh^rXw|*RqM0f) t!E0N4GLq%mZ13_pL+P&HgH_&X^dHq2zE{~Z(-i;!002ovPDHLkV1mRBr;-2w literal 0 HcmV?d00001 diff --git a/pyproject.toml b/pyproject.toml index 9fc84d903..408e3b773 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ dora = [ dynamixel = ["dynamixel-sdk>=3.7.31"] feetech = ["feetech-servo-sdk>=1.0.0"] gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"] +hopejr = ["feetech-servo-sdk>=1.0.0", "pygame>=2.5.1"] kinematics = ["placo>=0.9.6"] intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", diff --git a/src/lerobot/calibrate.py b/src/lerobot/calibrate.py index 37a9d5bdf..1e8bf4751 100644 --- a/src/lerobot/calibrate.py +++ b/src/lerobot/calibrate.py @@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + hope_jr, koch_follower, lekiwi, make_robot_from_config, @@ -45,6 +46,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + homunculus, koch_leader, make_teleoperator_from_config, so100_leader, diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py new file mode 100644 index 000000000..9832a1636 --- /dev/null +++ b/src/lerobot/motors/calibration_gui.py @@ -0,0 +1,401 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from dataclasses import dataclass + +os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" + +from lerobot.motors import MotorCalibration, MotorsBus + +BAR_LEN, BAR_THICKNESS = 450, 8 +HANDLE_R = 10 +BRACKET_W, BRACKET_H = 6, 14 +TRI_W, TRI_H = 12, 14 + +BTN_W, BTN_H = 60, 22 +SAVE_W, SAVE_H = 80, 28 +LOAD_W = 80 +DD_W, DD_H = 160, 28 + +TOP_GAP = 50 +PADDING_Y, TOP_OFFSET = 70, 60 +FONT_SIZE, FPS = 20, 60 + +BG_COLOR = (30, 30, 30) +BAR_RED, BAR_GREEN = (200, 60, 60), (60, 200, 60) +HANDLE_COLOR, TEXT_COLOR = (240, 240, 240), (250, 250, 250) +TICK_COLOR = (250, 220, 40) +BTN_COLOR, BTN_COLOR_HL = (80, 80, 80), (110, 110, 110) +DD_COLOR, DD_COLOR_HL = (70, 70, 70), (100, 100, 100) + + +def dist(a, b): + return math.hypot(a[0] - b[0], a[1] - b[1]) + + +@dataclass +class RangeValues: + min_v: int + pos_v: int + max_v: int + + +class RangeSlider: + """One motor = one slider row""" + + def __init__(self, motor, idx, res, calibration, present, label_pad, base_y): + import pygame + + self.motor = motor + self.res = res + self.x0 = 40 + label_pad + self.x1 = self.x0 + BAR_LEN + self.y = base_y + idx * PADDING_Y + + self.min_v = calibration.range_min + self.max_v = calibration.range_max + self.pos_v = max(self.min_v, min(present, self.max_v)) + + self.min_x = self._pos_from_val(self.min_v) + self.max_x = self._pos_from_val(self.max_v) + self.pos_x = self._pos_from_val(self.pos_v) + + self.min_btn = pygame.Rect(self.x0 - BTN_W - 6, self.y - BTN_H // 2, BTN_W, BTN_H) + self.max_btn = pygame.Rect(self.x1 + 6, self.y - BTN_H // 2, BTN_W, BTN_H) + + self.drag_min = self.drag_max = self.drag_pos = False + self.tick_val = present + self.font = pygame.font.Font(None, FONT_SIZE) + + def _val_from_pos(self, x): + return round((x - self.x0) / BAR_LEN * self.res) + + def _pos_from_val(self, v): + return self.x0 + (v / self.res) * BAR_LEN + + def set_tick(self, v): + self.tick_val = max(0, min(v, self.res)) + + def _triangle_hit(self, pos): + import pygame + + tri_top = self.y - BAR_THICKNESS // 2 - 2 + return pygame.Rect(self.pos_x - TRI_W // 2, tri_top - TRI_H, TRI_W, TRI_H).collidepoint(pos) + + def handle_event(self, e): + import pygame + + if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1: + if self.min_btn.collidepoint(e.pos): + self.min_x, self.min_v = self.pos_x, self.pos_v + return + if self.max_btn.collidepoint(e.pos): + self.max_x, self.max_v = self.pos_x, self.pos_v + return + if dist(e.pos, (self.min_x, self.y)) <= HANDLE_R: + self.drag_min = True + elif dist(e.pos, (self.max_x, self.y)) <= HANDLE_R: + self.drag_max = True + elif self._triangle_hit(e.pos): + self.drag_pos = True + + elif e.type == pygame.MOUSEBUTTONUP and e.button == 1: + self.drag_min = self.drag_max = self.drag_pos = False + + elif e.type == pygame.MOUSEMOTION: + x = e.pos[0] + if self.drag_min: + self.min_x = max(self.x0, min(x, self.pos_x)) + elif self.drag_max: + self.max_x = min(self.x1, max(x, self.pos_x)) + elif self.drag_pos: + self.pos_x = max(self.min_x, min(x, self.max_x)) + + self.min_v = self._val_from_pos(self.min_x) + self.max_v = self._val_from_pos(self.max_x) + self.pos_v = self._val_from_pos(self.pos_x) + + def _draw_button(self, surf, rect, text): + import pygame + + clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR + pygame.draw.rect(surf, clr, rect, border_radius=4) + t = self.font.render(text, True, TEXT_COLOR) + surf.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2)) + + def draw(self, surf): + import pygame + + # motor name above set-min button (right-aligned) + name_surf = self.font.render(self.motor, True, TEXT_COLOR) + surf.blit( + name_surf, + (self.min_btn.right - name_surf.get_width(), self.min_btn.y - name_surf.get_height() - 4), + ) + + # bar + active section + pygame.draw.rect(surf, BAR_RED, (self.x0, self.y - BAR_THICKNESS // 2, BAR_LEN, BAR_THICKNESS)) + pygame.draw.rect( + surf, BAR_GREEN, (self.min_x, self.y - BAR_THICKNESS // 2, self.max_x - self.min_x, BAR_THICKNESS) + ) + + # tick + tick_x = self._pos_from_val(self.tick_val) + pygame.draw.line( + surf, + TICK_COLOR, + (tick_x, self.y - BAR_THICKNESS // 2 - 4), + (tick_x, self.y + BAR_THICKNESS // 2 + 4), + 2, + ) + + # brackets + for x, sign in ((self.min_x, +1), (self.max_x, -1)): + pygame.draw.line( + surf, HANDLE_COLOR, (x, self.y - BRACKET_H // 2), (x, self.y + BRACKET_H // 2), 2 + ) + pygame.draw.line( + surf, + HANDLE_COLOR, + (x, self.y - BRACKET_H // 2), + (x + sign * BRACKET_W, self.y - BRACKET_H // 2), + 2, + ) + pygame.draw.line( + surf, + HANDLE_COLOR, + (x, self.y + BRACKET_H // 2), + (x + sign * BRACKET_W, self.y + BRACKET_H // 2), + 2, + ) + + # triangle ▼ + tri_top = self.y - BAR_THICKNESS // 2 - 2 + pygame.draw.polygon( + surf, + HANDLE_COLOR, + [ + (self.pos_x, tri_top), + (self.pos_x - TRI_W // 2, tri_top - TRI_H), + (self.pos_x + TRI_W // 2, tri_top - TRI_H), + ], + ) + + # numeric labels + fh = self.font.get_height() + pos_y = tri_top - TRI_H - 4 - fh + txts = [ + (self.min_v, self.min_x, self.y - BRACKET_H // 2 - 4 - fh), + (self.max_v, self.max_x, self.y - BRACKET_H // 2 - 4 - fh), + (self.pos_v, self.pos_x, pos_y), + ] + for v, x, y in txts: + s = self.font.render(str(v), True, TEXT_COLOR) + surf.blit(s, (x - s.get_width() // 2, y)) + + # buttons + self._draw_button(surf, self.min_btn, "set min") + self._draw_button(surf, self.max_btn, "set max") + + # external + def values(self) -> RangeValues: + return RangeValues(self.min_v, self.pos_v, self.max_v) + + +class RangeFinderGUI: + def __init__(self, bus: MotorsBus, groups: dict[str, list[str]] | None = None): + import pygame + + self.bus = bus + self.groups = groups if groups is not None else {"all": list(bus.motors)} + self.group_names = list(groups) + self.current_group = self.group_names[0] + + if not bus.is_connected: + bus.connect() + + self.calibration = bus.read_calibration() + self.res_table = bus.model_resolution_table + self.present_cache = { + m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + } + + pygame.init() + self.font = pygame.font.Font(None, FONT_SIZE) + + label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + self.label_pad = label_pad + width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 + self.controls_bottom = 10 + SAVE_H + self.base_y = self.controls_bottom + TOP_GAP + height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 + + self.screen = pygame.display.set_mode((width, height)) + pygame.display.set_caption("Motors range finder") + + # ui rects + self.save_btn = pygame.Rect(width - SAVE_W - 10, 10, SAVE_W, SAVE_H) + self.load_btn = pygame.Rect(self.save_btn.left - LOAD_W - 10, 10, LOAD_W, SAVE_H) + self.dd_btn = pygame.Rect(width // 2 - DD_W // 2, 10, DD_W, DD_H) + self.dd_open = False # dropdown expanded? + + self.clock = pygame.time.Clock() + self._build_sliders() + self._adjust_height() + + def _adjust_height(self): + import pygame + + motors = self.groups[self.current_group] + new_h = self.base_y + PADDING_Y * len(motors) + 40 + if new_h != self.screen.get_height(): + w = self.screen.get_width() + self.screen = pygame.display.set_mode((w, new_h)) + + def _build_sliders(self): + self.sliders: list[RangeSlider] = [] + motors = self.groups[self.current_group] + for i, m in enumerate(motors): + self.sliders.append( + RangeSlider( + motor=m, + idx=i, + res=self.res_table[self.bus.motors[m].model] - 1, + calibration=self.calibration[m], + present=self.present_cache[m], + label_pad=self.label_pad, + base_y=self.base_y, + ) + ) + + def _draw_dropdown(self): + import pygame + + # collapsed box + hover = self.dd_btn.collidepoint(pygame.mouse.get_pos()) + pygame.draw.rect(self.screen, DD_COLOR_HL if hover else DD_COLOR, self.dd_btn, border_radius=6) + + txt = self.font.render(self.current_group, True, TEXT_COLOR) + self.screen.blit( + txt, (self.dd_btn.centerx - txt.get_width() // 2, self.dd_btn.centery - txt.get_height() // 2) + ) + + tri_w, tri_h = 12, 6 + cx = self.dd_btn.right - 14 + cy = self.dd_btn.centery + 1 + pygame.draw.polygon( + self.screen, + TEXT_COLOR, + [(cx - tri_w // 2, cy - tri_h // 2), (cx + tri_w // 2, cy - tri_h // 2), (cx, cy + tri_h // 2)], + ) + + if not self.dd_open: + return + + # expanded list + for i, name in enumerate(self.group_names): + item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H) + clr = DD_COLOR_HL if item_rect.collidepoint(pygame.mouse.get_pos()) else DD_COLOR + pygame.draw.rect(self.screen, clr, item_rect) + t = self.font.render(name, True, TEXT_COLOR) + self.screen.blit( + t, (item_rect.centerx - t.get_width() // 2, item_rect.centery - t.get_height() // 2) + ) + + def _handle_dropdown_event(self, e): + import pygame + + if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1: + if self.dd_btn.collidepoint(e.pos): + self.dd_open = not self.dd_open + return True + if self.dd_open: + for i, name in enumerate(self.group_names): + item_rect = pygame.Rect(self.dd_btn.left, self.dd_btn.bottom + i * DD_H, DD_W, DD_H) + if item_rect.collidepoint(e.pos): + if name != self.current_group: + self.current_group = name + self._build_sliders() + self._adjust_height() + self.dd_open = False + return True + self.dd_open = False + return False + + def _save_current(self): + for s in self.sliders: + self.calibration[s.motor].range_min = s.min_v + self.calibration[s.motor].range_max = s.max_v + + with self.bus.torque_disabled(): + self.bus.write_calibration(self.calibration) + + def _load_current(self): + self.calibration = self.bus.read_calibration() + for s in self.sliders: + s.min_v = self.calibration[s.motor].range_min + s.max_v = self.calibration[s.motor].range_max + s.min_x = s._pos_from_val(s.min_v) + s.max_x = s._pos_from_val(s.max_v) + + def run(self) -> dict[str, MotorCalibration]: + import pygame + + while True: + for e in pygame.event.get(): + if e.type == pygame.QUIT: + pygame.quit() + return self.calibration + + if self._handle_dropdown_event(e): + continue + + if e.type == pygame.MOUSEBUTTONDOWN and e.button == 1: + if self.save_btn.collidepoint(e.pos): + self._save_current() + elif self.load_btn.collidepoint(e.pos): + self._load_current() + + for s in self.sliders: + s.handle_event(e) + + # live goal write while dragging + for s in self.sliders: + if s.drag_pos: + self.bus.write("Goal_Position", s.motor, s.pos_v, normalize=False) + + # tick update + for s in self.sliders: + pos = self.bus.read("Present_Position", s.motor, normalize=False) + s.set_tick(pos) + self.present_cache[s.motor] = pos + + # ─ drawing + self.screen.fill(BG_COLOR) + for s in self.sliders: + s.draw(self.screen) + + self._draw_dropdown() + + # load / save buttons + for rect, text in ((self.load_btn, "LOAD"), (self.save_btn, "SAVE")): + clr = BTN_COLOR_HL if rect.collidepoint(pygame.mouse.get_pos()) else BTN_COLOR + pygame.draw.rect(self.screen, clr, rect, border_radius=6) + t = self.font.render(text, True, TEXT_COLOR) + self.screen.blit(t, (rect.centerx - t.get_width() // 2, rect.centery - t.get_height() // 2)) + + pygame.display.flip() + self.clock.tick(FPS) diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index d4f41643c..1113ec0f7 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -162,11 +162,11 @@ class DynamixelMotorsBus(MotorsBus): raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") - def configure_motors(self) -> None: + def configure_motors(self, return_delay_time=0) -> None: # By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). for motor in self.motors: - self.write("Return_Delay_Time", motor, 0) + self.write("Return_Delay_Time", motor, return_delay_time) @property def is_calibrated(self) -> bool: @@ -190,13 +190,14 @@ class DynamixelMotorsBus(MotorsBus): return calibration - def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: for motor, calibration in calibration_dict.items(): self.write("Homing_Offset", motor, calibration.homing_offset) self.write("Min_Position_Limit", motor, calibration.range_min) self.write("Max_Position_Limit", motor, calibration.range_max) - self.calibration = calibration_dict + if cache: + self.calibration = calibration_dict def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7edf869a4..88d45ba39 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -219,15 +219,15 @@ class FeetechMotorsBus(MotorsBus): raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") - def configure_motors(self) -> None: + def configure_motors(self, return_delay_time=0, maximum_acceleration=254, acceleration=254) -> None: for motor in self.motors: # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). - self.write("Return_Delay_Time", motor, 0) + self.write("Return_Delay_Time", motor, return_delay_time) # Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors. - # Note: this address is not in the official STS3215 Memory Table - self.write("Maximum_Acceleration", motor, 254) - self.write("Acceleration", motor, 254) + if self.protocol_version == 0: + self.write("Maximum_Acceleration", motor, maximum_acceleration) + self.write("Acceleration", motor, acceleration) @property def is_calibrated(self) -> bool: @@ -270,14 +270,15 @@ class FeetechMotorsBus(MotorsBus): return calibration - def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: for motor, calibration in calibration_dict.items(): if self.protocol_version == 0: self.write("Homing_Offset", motor, calibration.homing_offset) self.write("Min_Position_Limit", motor, calibration.range_min) self.write("Max_Position_Limit", motor, calibration.range_max) - self.calibration = calibration_dict + if cache: + self.calibration = calibration_dict def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: """ diff --git a/src/lerobot/motors/feetech/tables.py b/src/lerobot/motors/feetech/tables.py index 0a2f2659f..48814957f 100644 --- a/src/lerobot/motors/feetech/tables.py +++ b/src/lerobot/motors/feetech/tables.py @@ -189,7 +189,7 @@ MODEL_RESOLUTION = { "scs_series": 1024, "sts3215": 4096, "sts3250": 4096, - "sm8512bl": 65536, + "sm8512bl": 4096, "scs0009": 1024, } diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 7386bfb1c..26522c7c9 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -586,7 +586,7 @@ class MotorsBus(abc.ABC): pass @contextmanager - def torque_disabled(self): + def torque_disabled(self, motors: int | str | list[str] | None = None): """Context-manager that guarantees torque is re-enabled. This helper is useful to temporarily disable torque when configuring motors. @@ -596,11 +596,11 @@ class MotorsBus(abc.ABC): ... # Safe operations here ... pass """ - self.disable_torque() + self.disable_torque(motors) try: yield finally: - self.enable_torque() + self.enable_torque(motors) def set_timeout(self, timeout_ms: int | None = None): """Change the packet timeout used by the SDK. @@ -653,12 +653,13 @@ class MotorsBus(abc.ABC): pass @abc.abstractmethod - def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: - """Write calibration parameters to the motors and cache them. + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration parameters to the motors and optionally cache them. Args: calibration_dict (dict[str, MotorCalibration]): Calibration obtained from :pymeth:`read_calibration` or crafted by the user. + cache (bool, optional): Save the calibration to :pyattr:`calibration`. Defaults to True. """ pass diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 635bdf1e4..9fc0dc7ed 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -57,6 +57,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + hope_jr, koch_follower, make_robot_from_config, so100_follower, @@ -65,6 +66,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + homunculus, koch_leader, make_teleoperator_from_config, so100_leader, diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index ef20c28ef..c51c55cee 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -39,6 +39,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + hope_jr, koch_follower, make_robot_from_config, so100_follower, diff --git a/src/lerobot/robots/hope_jr/__init__.py b/src/lerobot/robots/hope_jr/__init__.py new file mode 100644 index 000000000..324e3c8e8 --- /dev/null +++ b/src/lerobot/robots/hope_jr/__init__.py @@ -0,0 +1,3 @@ +from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig +from .hope_jr_arm import HopeJrArm +from .hope_jr_hand import HopeJrHand diff --git a/src/lerobot/robots/hope_jr/config_hope_jr.py b/src/lerobot/robots/hope_jr/config_hope_jr.py new file mode 100644 index 000000000..747e98e01 --- /dev/null +++ b/src/lerobot/robots/hope_jr/config_hope_jr.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("hope_jr_hand") +@dataclass +class HopeJrHandConfig(RobotConfig): + port: str # Port to connect to the hand + side: str # "left" / "right" + + disable_torque_on_disconnect: bool = True + + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + def __post_init__(self): + super().__post_init__() + if self.side not in ["right", "left"]: + raise ValueError(self.side) + + +@RobotConfig.register_subclass("hope_jr_arm") +@dataclass +class HopeJrArmConfig(RobotConfig): + port: str # Port to connect to the hand + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/hope_jr/hope_jr.mdx b/src/lerobot/robots/hope_jr/hope_jr.mdx new file mode 100644 index 000000000..2f9ec9d89 --- /dev/null +++ b/src/lerobot/robots/hope_jr/hope_jr.mdx @@ -0,0 +1,268 @@ +# HopeJR + +## Prerequisites + +- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr) + +## Install LeRobot + +Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot. + +Install LeRobot with HopeJR dependencies: +```bash +pip install -e ".[hopejr]" +``` + +## Device Configuration + +Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton: + +```bash +python -m lerobot.find_port +``` + +This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts. + +## Step 1: Calibration + +Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration` + +### 1.1 Calibrate Robot Hand + +```bash +python -m lerobot.calibrate \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=blue \ + --robot.side=right +``` + +When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows: + +**Thumb**: +- **CMC**: base joint connecting thumb to hand +- **MCP**: knuckle joint +- **PIP**: first finger joint +- **DIP** : fingertip joint + +**Index, Middle, Ring, and Pinky fingers**: +- **Radial flexor**: Moves base of finger towards the thumb +- **Ulnar flexor**: Moves base of finger towards the pinky +- **PIP/DIP**: Flexes the distal and proximal phalanx of the finger + +Each one of these will need to be calibrated individually via the GUI. + Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement. + +

+ Setting boundaries in the hand calibration GUI + +

+ +Use the calibration interface to set the range boundaries for each joint as shown above. + +

+ Saving calibration values + +

+ +Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors. + +### 1.2 Calibrate Teleoperator Glove + +```bash +python -m lerobot.calibrate \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=red \ + --teleop.side=right +``` + +Move each finger through its full range of motion, starting from the thumb. + +``` +Move thumb through its entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +NAME | MIN | POS | MAX +thumb_cmc | 1790 | 1831 | 1853 +thumb_mcp | 1497 | 1514 | 1528 +thumb_pip | 1466 | 1496 | 1515 +thumb_dip | 1463 | 1484 | 1514 +``` + +Continue with each finger: + +``` +Move middle through its entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +NAME | MIN | POS | MAX +middle_mcp_abduction | 1598 | 1718 | 1820 +middle_mcp_flexion | 1512 | 1658 | 2136 +middle_dip | 1484 | 1500 | 1547 +``` + +Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json` + +### 1.3 Calibrate Robot Arm + +```bash +python -m lerobot.calibrate \ + --robot.type=hope_jr_arm \ + --robot.port=/dev/tty.usbserial-1110 \ + --robot.id=white +``` + +This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows: +- **Shoulder**: pitch, yaw, and roll +- **Elbow**: flex +- **Wrist**: pitch, yaw, and roll + +

+ Setting boundaries in the arm calibration GUI + +

+ +Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration. + +### 1.4 Calibrate Teleoperator Exoskeleton + +```bash +python -m lerobot.calibrate \ + --teleop.type=homunculus_arm \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=black +``` + +The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion: + +``` +Move all joints through their entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +------------------------------------------- +NAME | MIN | POS | MAX +shoulder_pitch | 586 | 736 | 895 +shoulder_yaw | 1257 | 1374 | 1390 +shoulder_roll | 449 | 1034 | 2564 +elbow_flex | 3023 | 3117 | 3134 +wrist_roll | 3073 | 3096 | 3147 +wrist_yaw | 2143 | 2171 | 2185 +wrist_pitch | 1975 | 1993 | 2074 +Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json +``` + +## Step 2: Teleoperation + +Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions: + +### Hand +```bash +python -m lerobot.teleoperate \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=blue \ + --robot.side=right \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=red \ + --teleop.side=right \ + --display_data=true \ + --fps=30 +``` + +### Arm +```bash +python -m lerobot.teleoperate \ + --robot.type=hope_jr_arm \ + --robot.port=/dev/tty.usbserial-1110 \ + --robot.id=white \ + --teleop.type=homunculus_arm \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=black \ + --display_data=true \ + --fps=30 +``` + +## Step 3: Record, Replay, Train + +Record, Replay and Train with Hope-JR is still experimental. + +### Record + +This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings). + +```bash +python -m lerobot.record \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem1201 \ + --teleop.id=right \ + --teleop.side=right \ + --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.single_task="Hand recording test with video data" \ + --dataset.num_episodes=1 \ + --dataset.episode_time_s=5 \ + --dataset.push_to_hub=true \ + --dataset.private=true \ + --display_data=true +``` + +### Replay + +```bash +python -m lerobot.replay \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --dataset.repo_id=nepyope/hand_record_test_with_camera \ + --dataset.episode=0 +``` + +### Train + +```bash +python -m lerobot.scripts.train \ + --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --policy.type=act \ + --output_dir=outputs/train/hopejr_hand \ + --job_name=hopejr \ + --policy.device=mps \ + --wandb.enable=true \ + --policy.repo_id=nepyope/hand_test_policy +``` + +### Evaluate + +This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino). + +```bash +python -m lerobot.record \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ + --display_data=false \ + --dataset.repo_id=nepyope/eval_hopejr \ + --dataset.single_task="Evaluate hopejr hand policy" \ + --dataset.num_episodes=10 \ + --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model +``` diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py new file mode 100644 index 000000000..0e3a615a9 --- /dev/null +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.motors import Motor, MotorNormMode +from lerobot.motors.calibration_gui import RangeFinderGUI +from lerobot.motors.feetech import ( + FeetechMotorsBus, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_hope_jr import HopeJrArmConfig + +logger = logging.getLogger(__name__) + + +class HopeJrArm(Robot): + config_class = HopeJrArmConfig + name = "hope_jr_arm" + + def __init__(self, config: HopeJrArmConfig): + super().__init__(config) + self.config = config + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pitch": Motor(1, "sm8512bl", MotorNormMode.RANGE_M100_100), + "shoulder_yaw": Motor(2, "sts3250", MotorNormMode.RANGE_M100_100), + "shoulder_roll": Motor(3, "sts3250", MotorNormMode.RANGE_M100_100), + "elbow_flex": Motor(4, "sts3250", MotorNormMode.RANGE_M100_100), + "wrist_roll": Motor(5, "sts3250", MotorNormMode.RANGE_M100_100), + "wrist_yaw": Motor(6, "sts3250", MotorNormMode.RANGE_M100_100), + "wrist_pitch": Motor(7, "sts3250", MotorNormMode.RANGE_M100_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + # HACK + self.shoulder_pitch = "shoulder_pitch" + self.other_motors = [m for m in self.bus.motors if m != "shoulder_pitch"] + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + We assume that at connection time, arm is in a rest position, + and torque can be safely disabled to run calibration. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect(handshake=False) + if not self.is_calibrated and calibrate: + self.calibrate() + + # Connect the cameras + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self, limb_name: str = None) -> None: + groups = { + "all": list(self.bus.motors.keys()), + "shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"], + "elbow": ["elbow_flex"], + "wrist": ["wrist_roll", "wrist_yaw", "wrist_pitch"], + } + + self.calibration = RangeFinderGUI(self.bus, groups).run() + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors(maximum_acceleration=30, acceleration=30) + + def setup_motors(self) -> None: + # TODO: add docstring + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position", self.other_motors) + obs_dict[self.shoulder_pitch] = self.bus.read("Present_Position", self.shoulder_pitch) + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py new file mode 100644 index 000000000..8dc100e06 --- /dev/null +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.motors import Motor, MotorNormMode +from lerobot.motors.calibration_gui import RangeFinderGUI +from lerobot.motors.feetech import ( + FeetechMotorsBus, +) + +from ..robot import Robot +from .config_hope_jr import HopeJrHandConfig + +logger = logging.getLogger(__name__) + +RIGHT_HAND_INVERSIONS = [ + "thumb_mcp", + "thumb_dip", + "index_ulnar_flexor", + "middle_ulnar_flexor", + "ring_ulnar_flexor", + "ring_pip_dip", + "pinky_ulnar_flexor", + "pinky_pip_dip", +] + +LEFT_HAND_INVERSIONS = [ + "thumb_cmc", + "thumb_mcp", + "thumb_dip", + "index_radial_flexor", + "index_pip_dip", + "middle_radial_flexor", + "middle_pip_dip", + "ring_radial_flexor", + "ring_pip_dip", + "pinky_radial_flexor", + # "pinky_pip_dip", +] + + +class HopeJrHand(Robot): + config_class = HopeJrHandConfig + name = "hope_jr_hand" + + def __init__(self, config: HopeJrHandConfig): + super().__init__(config) + self.config = config + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + # Thumb + "thumb_cmc": Motor(1, "scs0009", MotorNormMode.RANGE_0_100), + "thumb_mcp": Motor(2, "scs0009", MotorNormMode.RANGE_0_100), + "thumb_pip": Motor(3, "scs0009", MotorNormMode.RANGE_0_100), + "thumb_dip": Motor(4, "scs0009", MotorNormMode.RANGE_0_100), + # Index + "index_radial_flexor": Motor(5, "scs0009", MotorNormMode.RANGE_0_100), + "index_ulnar_flexor": Motor(6, "scs0009", MotorNormMode.RANGE_0_100), + "index_pip_dip": Motor(7, "scs0009", MotorNormMode.RANGE_0_100), + # Middle + "middle_radial_flexor": Motor(8, "scs0009", MotorNormMode.RANGE_0_100), + "middle_ulnar_flexor": Motor(9, "scs0009", MotorNormMode.RANGE_0_100), + "middle_pip_dip": Motor(10, "scs0009", MotorNormMode.RANGE_0_100), + # Ring + "ring_radial_flexor": Motor(11, "scs0009", MotorNormMode.RANGE_0_100), + "ring_ulnar_flexor": Motor(12, "scs0009", MotorNormMode.RANGE_0_100), + "ring_pip_dip": Motor(13, "scs0009", MotorNormMode.RANGE_0_100), + # Pinky + "pinky_radial_flexor": Motor(14, "scs0009", MotorNormMode.RANGE_0_100), + "pinky_ulnar_flexor": Motor(15, "scs0009", MotorNormMode.RANGE_0_100), + "pinky_pip_dip": Motor(16, "scs0009", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + protocol_version=1, + ) + self.cameras = make_cameras_from_configs(config.cameras) + self.inverted_motors = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + # Connect the cameras + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + fingers = {} + for finger in ["thumb", "index", "middle", "ring", "pinky"]: + fingers[finger] = [motor for motor in self.bus.motors if motor.startswith(finger)] + + self.calibration = RangeFinderGUI(self.bus, fingers).run() + for motor in self.inverted_motors: + self.calibration[motor].drive_mode = 1 + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + + def setup_motors(self) -> None: + # TODO: add docstring + for motor in self.bus.motors: + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + obs_dict = {} + + # Read hand position + start = time.perf_counter() + for motor in self.bus.motors: + obs_dict[f"{motor}.pos"] = self.bus.read("Present_Position", motor) + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + self.bus.sync_write("Goal_Position", goal_pos) + return action + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 435303c6e..911d40465 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -49,6 +49,14 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .viperx import ViperX return ViperX(config) + elif config.type == "hope_jr_hand": + from .hope_jr import HopeJrHand + + return HopeJrHand(config) + elif config.type == "hope_jr_arm": + from .hope_jr import HopeJrArm + + return HopeJrArm(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index e2819345b..168f898c4 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -43,6 +43,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + hope_jr, koch_follower, make_robot_from_config, so100_follower, @@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, gamepad, + homunculus, koch_leader, make_teleoperator_from_config, so100_leader, diff --git a/src/lerobot/teleoperators/homunculus/__init__.py b/src/lerobot/teleoperators/homunculus/__init__.py new file mode 100644 index 000000000..04b5c0f2b --- /dev/null +++ b/src/lerobot/teleoperators/homunculus/__init__.py @@ -0,0 +1,4 @@ +from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig +from .homunculus_arm import HomunculusArm +from .homunculus_glove import HomunculusGlove +from .joints_translation import homunculus_glove_to_hope_jr_hand diff --git a/src/lerobot/teleoperators/homunculus/config_homunculus.py b/src/lerobot/teleoperators/homunculus/config_homunculus.py new file mode 100644 index 000000000..da465215a --- /dev/null +++ b/src/lerobot/teleoperators/homunculus/config_homunculus.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("homunculus_glove") +@dataclass +class HomunculusGloveConfig(TeleoperatorConfig): + port: str # Port to connect to the glove + side: str # "left" / "right" + baud_rate: int = 115_200 + + def __post_init__(self): + if self.side not in ["right", "left"]: + raise ValueError(self.side) + + +@TeleoperatorConfig.register_subclass("homunculus_arm") +@dataclass +class HomunculusArmConfig(TeleoperatorConfig): + port: str # Port to connect to the arm + baud_rate: int = 115_200 diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py new file mode 100644 index 000000000..dfce0c88e --- /dev/null +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +from collections import deque +from pprint import pformat +from typing import Deque, Dict, Optional + +import serial + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..teleoperator import Teleoperator +from .config_homunculus import HomunculusArmConfig + +logger = logging.getLogger(__name__) + + +class HomunculusArm(Teleoperator): + """ + Homunculus Arm designed by Hugging Face. + """ + + config_class = HomunculusArmConfig + name = "homunculus_arm" + + def __init__(self, config: HomunculusArmConfig): + super().__init__(config) + self.config = config + self.serial = serial.Serial(config.port, config.baud_rate, timeout=1) + self.serial_lock = threading.Lock() + + self.joints = { + "shoulder_pitch": MotorNormMode.RANGE_M100_100, + "shoulder_yaw": MotorNormMode.RANGE_M100_100, + "shoulder_roll": MotorNormMode.RANGE_M100_100, + "elbow_flex": MotorNormMode.RANGE_M100_100, + "wrist_roll": MotorNormMode.RANGE_M100_100, + "wrist_yaw": MotorNormMode.RANGE_M100_100, + "wrist_pitch": MotorNormMode.RANGE_M100_100, + } + n = 50 + # EMA parameters --------------------------------------------------- + self.n: int = n + self.alpha: float = 2 / (n + 1) + # one deque *per joint* so we can inspect raw history if needed + self._buffers: Dict[str, Deque[int]] = { + joint: deque(maxlen=n) + for joint in ( + "shoulder_pitch", + "shoulder_yaw", + "shoulder_roll", + "elbow_flex", + "wrist_roll", + "wrist_yaw", + "wrist_pitch", + ) + } + # running EMA value per joint – lazily initialised on first read + self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers) + + self._state: dict[str, float] | None = None + self.new_state_event = threading.Event() + self.stop_event = threading.Event() + self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop") + self.state_lock = threading.Lock() + + @property + def action_features(self) -> dict: + return {f"{joint}.pos": float for joint in self.joints} + + @property + def feedback_features(self) -> dict: + return {} + + @property + def is_connected(self) -> bool: + with self.serial_lock: + return self.serial.is_open and self.thread.is_alive() + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + if not self.serial.is_open: + self.serial.open() + self.thread.start() + + # wait for the thread to ramp up & 1st state to be ready + if not self.new_state_event.wait(timeout=2): + raise TimeoutError(f"{self}: Timed out waiting for state after 2s.") + + if not self.is_calibrated and calibrate: + self.calibrate() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.calibration_fpath.is_file() + + def calibrate(self) -> None: + print( + "\nMove all joints through their entire range of motion." + "\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self._record_ranges_of_motion() + + self.calibration = {} + for id_, joint in enumerate(self.joints): + self.calibration[joint] = MotorCalibration( + id=id_, + drive_mode=0, + homing_offset=0, + range_min=range_mins[joint], + range_max=range_maxes[joint], + ) + + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + # TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code. + def _record_ranges_of_motion( + self, joints: list[str] | None = None, display_values: bool = True + ) -> tuple[dict[str, int], dict[str, int]]: + """Interactively record the min/max encoder values of each joint. + + Move the joints while the method streams live positions. Press :kbd:`Enter` to finish. + + Args: + joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`). + display_values (bool, optional): When `True` (default) a live table is printed to the console. + + Raises: + TypeError: `joints` is not `None` or a list. + ValueError: any joint's recorded min and max are the same. + + Returns: + tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values + observed for each joint. + """ + if joints is None: + joints = list(self.joints) + elif not isinstance(joints, list): + raise TypeError(joints) + + display_len = max(len(key) for key in joints) + + start_positions = self._read(joints, normalize=False) + mins = start_positions.copy() + maxes = start_positions.copy() + + user_pressed_enter = False + while not user_pressed_enter: + positions = self._read(joints, normalize=False) + mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()} + maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()} + + if display_values: + print("\n-------------------------------------------") + print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") + for joint in joints: + print( + f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + # Move cursor up to overwrite the previous output + move_cursor_up(len(joints) + 3) + + same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]] + if same_min_max: + raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}") + + return mins, maxes + + def configure(self) -> None: + pass + + # TODO(Steven): This function is copy/paste from the `HomunculusGlove` class. Consider moving it to an utility to reduce duplicated code. + def _normalize(self, values: dict[str, int]) -> dict[str, float]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") + + normalized_values = {} + for joint, val in values.items(): + min_ = self.calibration[joint].range_min + max_ = self.calibration[joint].range_max + drive_mode = self.calibration[joint].drive_mode + bounded_val = min(max_, max(min_, val)) + + if self.joints[joint] is MotorNormMode.RANGE_M100_100: + norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100 + normalized_values[joint] = -norm if drive_mode else norm + elif self.joints[joint] is MotorNormMode.RANGE_0_100: + norm = ((bounded_val - min_) / (max_ - min_)) * 100 + normalized_values[joint] = 100 - norm if drive_mode else norm + + return normalized_values + + def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, float]: + """Update buffers & running EMA values; return smoothed dict.""" + smoothed: Dict[str, float] = {} + for joint, value in raw.items(): + # maintain raw history + self._buffers[joint].append(value) + + # initialise on first run + if self._ema[joint] is None: + self._ema[joint] = float(value) + else: + self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint] + + smoothed[joint] = self._ema[joint] + return smoothed + + def _read( + self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1 + ) -> dict[str, int | float]: + """ + Return the most recent (single) values from self.last_d, + optionally applying calibration. + """ + if not self.new_state_event.wait(timeout=timeout): + raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.") + + with self.state_lock: + state = self._state + + self.new_state_event.clear() + + if state is None: + raise RuntimeError(f"{self} Internal error: Event set but no state available.") + + if joints is not None: + state = {k: v for k, v in state.items() if k in joints} + + if normalize: + state = self._normalize(state) + + state = self._apply_ema(state) + + return state + + def _read_loop(self): + """ + Continuously read from the serial buffer in its own thread and sends values to the main thread through + a queue. + """ + while not self.stop_event.is_set(): + try: + raw_values = None + with self.serial_lock: + if self.serial.in_waiting > 0: + self.serial.flush() + raw_values = self.serial.readline().decode("utf-8").strip().split(" ") + if raw_values is None or len(raw_values) != 21: # 16 raw + 5 angle values + continue + + joint_angles = { + "shoulder_pitch": int(raw_values[19]), + "shoulder_yaw": int(raw_values[18]), + "shoulder_roll": int(raw_values[20]), + "elbow_flex": int(raw_values[17]), + "wrist_roll": int(raw_values[16]), + "wrist_yaw": int(raw_values[1]), + "wrist_pitch": int(raw_values[0]), + } + + with self.state_lock: + self._state = joint_angles + self.new_state_event.set() + + except Exception as e: + logger.debug(f"Error reading frame in background thread for {self}: {e}") + + def get_action(self) -> dict[str, float]: + joint_positions = self._read() + return {f"{joint}.pos": pos for joint, pos in joint_positions.items()} + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + DeviceNotConnectedError(f"{self} is not connected.") + + self.stop_event.set() + self.thread.join(timeout=1) + self.serial.close() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py new file mode 100644 index 000000000..d367a2a7c --- /dev/null +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -0,0 +1,338 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +from collections import deque +from pprint import pformat +from typing import Deque, Dict, Optional + +import serial + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.motors import MotorCalibration +from lerobot.motors.motors_bus import MotorNormMode +from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..teleoperator import Teleoperator +from .config_homunculus import HomunculusGloveConfig + +logger = logging.getLogger(__name__) + +LEFT_HAND_INVERSIONS = [ + "thumb_cmc", + "index_dip", + "middle_mcp_abduction", + "middle_dip", + "pinky_mcp_abduction", + "pinky_dip", +] + +RIGHT_HAND_INVERSIONS = [ + "thumb_mcp", + "thumb_cmc", + "thumb_pip", + "thumb_dip", + "index_mcp_abduction", + # "index_dip", + "middle_mcp_abduction", + # "middle_dip", + "ring_mcp_abduction", + "ring_mcp_flexion", + # "ring_dip", + "pinky_mcp_abduction", +] + + +class HomunculusGlove(Teleoperator): + """ + Homunculus Glove designed by NepYope & Hugging Face. + """ + + config_class = HomunculusGloveConfig + name = "homunculus_glove" + + def __init__(self, config: HomunculusGloveConfig): + super().__init__(config) + self.config = config + self.serial = serial.Serial(config.port, config.baud_rate, timeout=1) + self.serial_lock = threading.Lock() + + self.joints = { + "thumb_cmc": MotorNormMode.RANGE_0_100, + "thumb_mcp": MotorNormMode.RANGE_0_100, + "thumb_pip": MotorNormMode.RANGE_0_100, + "thumb_dip": MotorNormMode.RANGE_0_100, + "index_mcp_abduction": MotorNormMode.RANGE_M100_100, + "index_mcp_flexion": MotorNormMode.RANGE_0_100, + "index_dip": MotorNormMode.RANGE_0_100, + "middle_mcp_abduction": MotorNormMode.RANGE_M100_100, + "middle_mcp_flexion": MotorNormMode.RANGE_0_100, + "middle_dip": MotorNormMode.RANGE_0_100, + "ring_mcp_abduction": MotorNormMode.RANGE_M100_100, + "ring_mcp_flexion": MotorNormMode.RANGE_0_100, + "ring_dip": MotorNormMode.RANGE_0_100, + "pinky_mcp_abduction": MotorNormMode.RANGE_M100_100, + "pinky_mcp_flexion": MotorNormMode.RANGE_0_100, + "pinky_dip": MotorNormMode.RANGE_0_100, + } + self.inverted_joints = RIGHT_HAND_INVERSIONS if config.side == "right" else LEFT_HAND_INVERSIONS + + n = 10 + # EMA parameters --------------------------------------------------- + self.n: int = n + self.alpha: float = 2 / (n + 1) + # one deque *per joint* so we can inspect raw history if needed + self._buffers: Dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} + # running EMA value per joint – lazily initialised on first read + self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers) + + self._state: dict[str, float] | None = None + self.new_state_event = threading.Event() + self.stop_event = threading.Event() + self.thread = threading.Thread(target=self._read_loop, daemon=True, name=f"{self} _read_loop") + self.state_lock = threading.Lock() + + @property + def action_features(self) -> dict: + return {f"{joint}.pos": float for joint in self.joints} + + @property + def feedback_features(self) -> dict: + return {} + + @property + def is_connected(self) -> bool: + with self.serial_lock: + return self.serial.is_open and self.thread.is_alive() + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + if not self.serial.is_open: + self.serial.open() + self.thread.start() + + # wait for the thread to ramp up & 1st state to be ready + if not self.new_state_event.wait(timeout=2): + raise TimeoutError(f"{self}: Timed out waiting for state after 2s.") + + if not self.is_calibrated and calibrate: + self.calibrate() + + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.calibration_fpath.is_file() + + def calibrate(self) -> None: + range_mins, range_maxes = {}, {} + for finger in ["thumb", "index", "middle", "ring", "pinky"]: + print( + f"\nMove {finger} through its entire range of motion." + "\nRecording positions. Press ENTER to stop..." + ) + finger_joints = [joint for joint in self.joints if joint.startswith(finger)] + finger_mins, finger_maxes = self._record_ranges_of_motion(finger_joints) + range_mins.update(finger_mins) + range_maxes.update(finger_maxes) + + self.calibration = {} + for id_, joint in enumerate(self.joints): + self.calibration[joint] = MotorCalibration( + id=id_, + drive_mode=1 if joint in self.inverted_joints else 0, + homing_offset=0, + range_min=range_mins[joint], + range_max=range_maxes[joint], + ) + + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code. + def _record_ranges_of_motion( + self, joints: list[str] | None = None, display_values: bool = True + ) -> tuple[dict[str, int], dict[str, int]]: + """Interactively record the min/max encoder values of each joint. + + Move the joints while the method streams live positions. Press :kbd:`Enter` to finish. + + Args: + joints (list[str] | None, optional): Joints to record. Defaults to every joint (`None`). + display_values (bool, optional): When `True` (default) a live table is printed to the console. + + Raises: + TypeError: `joints` is not `None` or a list. + ValueError: any joint's recorded min and max are the same. + + Returns: + tuple[dict[str, int], dict[str, int]]: Two dictionaries *mins* and *maxes* with the extreme values + observed for each joint. + """ + if joints is None: + joints = list(self.joints) + elif not isinstance(joints, list): + raise TypeError(joints) + + display_len = max(len(key) for key in joints) + + start_positions = self._read(joints, normalize=False) + mins = start_positions.copy() + maxes = start_positions.copy() + + user_pressed_enter = False + while not user_pressed_enter: + positions = self._read(joints, normalize=False) + mins = {joint: int(min(positions[joint], min_)) for joint, min_ in mins.items()} + maxes = {joint: int(max(positions[joint], max_)) for joint, max_ in maxes.items()} + + if display_values: + print("\n-------------------------------------------") + print(f"{'NAME':<{display_len}} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") + for joint in joints: + print( + f"{joint:<{display_len}} | {mins[joint]:>6} | {positions[joint]:>6} | {maxes[joint]:>6}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + # Move cursor up to overwrite the previous output + move_cursor_up(len(joints) + 3) + + same_min_max = [joint for joint in joints if mins[joint] == maxes[joint]] + if same_min_max: + raise ValueError(f"Some joints have the same min and max values:\n{pformat(same_min_max)}") + + return mins, maxes + + def configure(self) -> None: + pass + + # TODO(Steven): This function is copy/paste from the `HomunculusArm` class. Consider moving it to an utility to reduce duplicated code. + def _normalize(self, values: dict[str, int]) -> dict[str, float]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") + + normalized_values = {} + for joint, val in values.items(): + min_ = self.calibration[joint].range_min + max_ = self.calibration[joint].range_max + drive_mode = self.calibration[joint].drive_mode + bounded_val = min(max_, max(min_, val)) + + if self.joints[joint] is MotorNormMode.RANGE_M100_100: + norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100 + normalized_values[joint] = -norm if drive_mode else norm + elif self.joints[joint] is MotorNormMode.RANGE_0_100: + norm = ((bounded_val - min_) / (max_ - min_)) * 100 + normalized_values[joint] = 100 - norm if drive_mode else norm + + return normalized_values + + def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, int]: + """Update buffers & running EMA values; return smoothed dict as integers.""" + smoothed: Dict[str, int] = {} + for joint, value in raw.items(): + # maintain raw history + self._buffers[joint].append(value) + + # initialise on first run + if self._ema[joint] is None: + self._ema[joint] = float(value) + else: + self._ema[joint] = self.alpha * value + (1 - self.alpha) * self._ema[joint] + + # Convert back to int for compatibility with normalization + smoothed[joint] = int(round(self._ema[joint])) + return smoothed + + def _read( + self, joints: list[str] | None = None, normalize: bool = True, timeout: float = 1 + ) -> dict[str, int | float]: + """ + Return the most recent (single) values from self.last_d, + optionally applying calibration. + """ + if not self.new_state_event.wait(timeout=timeout): + raise TimeoutError(f"{self}: Timed out waiting for state after {timeout}s.") + + with self.state_lock: + state = self._state + + self.new_state_event.clear() + + if state is None: + raise RuntimeError(f"{self} Internal error: Event set but no state available.") + + if joints is not None: + state = {k: v for k, v in state.items() if k in joints} + + # Apply EMA smoothing to raw values first + state = self._apply_ema(state) + + # Then normalize if requested + if normalize: + state = self._normalize(state) + + return state + + def _read_loop(self): + """ + Continuously read from the serial buffer in its own thread and sends values to the main thread through + a queue. + """ + while not self.stop_event.is_set(): + try: + positions = None + with self.serial_lock: + if self.serial.in_waiting > 0: + self.serial.flush() + positions = self.serial.readline().decode("utf-8").strip().split(" ") + if positions is None or len(positions) != len(self.joints): + continue + + joint_positions = {joint: int(pos) for joint, pos in zip(self.joints, positions, strict=True)} + + with self.state_lock: + self._state = joint_positions + self.new_state_event.set() + + except Exception as e: + logger.debug(f"Error reading frame in background thread for {self}: {e}") + + def get_action(self) -> dict[str, float]: + joint_positions = self._read() + return homunculus_glove_to_hope_jr_hand( + {f"{joint}.pos": pos for joint, pos in joint_positions.items()} + ) + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + DeviceNotConnectedError(f"{self} is not connected.") + + self.stop_event.set() + self.thread.join(timeout=1) + self.serial.close() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/homunculus/joints_translation.py b/src/lerobot/teleoperators/homunculus/joints_translation.py new file mode 100644 index 000000000..f14f7b3ef --- /dev/null +++ b/src/lerobot/teleoperators/homunculus/joints_translation.py @@ -0,0 +1,63 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INDEX_SPLAY = 0.3 +MIDDLE_SPLAY = 0.3 +RING_SPLAY = 0.3 +PINKY_SPLAY = 0.5 + + +def get_ulnar_flexion(flexion: float, abduction: float, splay: float): + return -abduction * splay + flexion * (1 - splay) + + +def get_radial_flexion(flexion: float, abduction: float, splay: float): + return abduction * splay + flexion * (1 - splay) + + +def homunculus_glove_to_hope_jr_hand(glove_action: dict[str, float]) -> dict[str, float]: + return { + "thumb_cmc.pos": glove_action["thumb_cmc.pos"], + "thumb_mcp.pos": glove_action["thumb_mcp.pos"], + "thumb_pip.pos": glove_action["thumb_pip.pos"], + "thumb_dip.pos": glove_action["thumb_dip.pos"], + "index_radial_flexor.pos": get_radial_flexion( + glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY + ), + "index_ulnar_flexor.pos": get_ulnar_flexion( + glove_action["index_mcp_flexion.pos"], glove_action["index_mcp_abduction.pos"], INDEX_SPLAY + ), + "index_pip_dip.pos": glove_action["index_dip.pos"], + "middle_radial_flexor.pos": get_radial_flexion( + glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY + ), + "middle_ulnar_flexor.pos": get_ulnar_flexion( + glove_action["middle_mcp_flexion.pos"], glove_action["middle_mcp_abduction.pos"], MIDDLE_SPLAY + ), + "middle_pip_dip.pos": glove_action["middle_dip.pos"], + "ring_radial_flexor.pos": get_radial_flexion( + glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY + ), + "ring_ulnar_flexor.pos": get_ulnar_flexion( + glove_action["ring_mcp_flexion.pos"], glove_action["ring_mcp_abduction.pos"], RING_SPLAY + ), + "ring_pip_dip.pos": glove_action["ring_dip.pos"], + "pinky_radial_flexor.pos": get_radial_flexion( + glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY + ), + "pinky_ulnar_flexor.pos": get_ulnar_flexion( + glove_action["pinky_mcp_flexion.pos"], glove_action["pinky_mcp_abduction.pos"], PINKY_SPLAY + ), + "pinky_pip_dip.pos": glove_action["pinky_dip.pos"], + } diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index b49addc15..8a667fd41 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -53,5 +53,13 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .keyboard.teleop_keyboard import KeyboardEndEffectorTeleop return KeyboardEndEffectorTeleop(config) + elif config.type == "homunculus_glove": + from .homunculus import HomunculusGlove + + return HomunculusGlove(config) + elif config.type == "homunculus_arm": + from .homunculus import HomunculusArm + + return HomunculusArm(config) else: raise ValueError(config.type) From cf86b9300dc83fdad408cfe4787b7b09b55f12cf Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 8 Jul 2025 18:59:13 +0200 Subject: [PATCH 003/158] fix(logging): Fixing logging levels (#1466) * fix(logging): Fixing logging levels, adding custom logging levels for console and file logging * clean(typing): Adding typing in logging formatter, use proper getter for logging message --- src/lerobot/utils/utils.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 6d9c0338b..2e94a9c93 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -111,35 +111,46 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(log_file: Path | None = None, display_pid: bool = False): - def custom_format(record): +def init_logging( + log_file: Path | None = None, + display_pid: bool = False, + console_level: str = "INFO", + file_level: str = "DEBUG", +): + def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" # NOTE: Display PID is useful for multi-process logging. if display_pid: pid_str = f"[PID: {os.getpid()}]" - message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}" else: - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}" return message - logging.basicConfig(level=logging.INFO) - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - formatter = logging.Formatter() formatter.format = custom_format + + logger = logging.getLogger() + logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages + + # Remove unused default handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Write logs to console console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) - logging.getLogger().addHandler(console_handler) + console_handler.setLevel(console_level.upper()) + logger.addHandler(console_handler) + # Additionally write logs to file if log_file is not None: - # Additionally write logs to file file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) - logging.getLogger().addHandler(file_handler) + file_handler.setLevel(file_level.upper()) + logger.addHandler(file_handler) def format_big_number(num, precision=0): From ce2b9724bfe1b5a4c45e61b1890eef3f5ab0909c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 9 Jul 2025 16:22:40 +0200 Subject: [PATCH 004/158] fix(hil-serl): discrete critic send through network (#1468) Co-authored-by: Khalil Meftah Co-authored-by: jpizarrom --- pyproject.toml | 2 +- src/lerobot/scripts/rl/actor.py | 28 ++++++++++-- src/lerobot/scripts/rl/learner.py | 14 +++++- src/lerobot/transport/services.proto | 4 +- src/lerobot/transport/services_pb2.py | 32 ++++++------- src/lerobot/transport/services_pb2_grpc.py | 52 +++++++++++----------- 6 files changed, 81 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 408e3b773..e13a9af01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ [project.optional-dependencies] aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] dora = [ "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'", ] diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 0e96d3354..cd5e286c0 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -317,7 +317,7 @@ def act_with_policy( if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") - update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) + update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device) if len(list_transition_to_send_to_learner) > 0: push_transitions_to_transport_queue( @@ -642,9 +642,29 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) if bytes_state_dict is not None: logging.info("[ACTOR] Load new parameters from Learner.") - state_dict = bytes_to_state_dict(bytes_state_dict) - state_dict = move_state_dict_to_device(state_dict, device=device) - policy.load_state_dict(state_dict) + state_dicts = bytes_to_state_dict(bytes_state_dict) + + # TODO: check encoder parameter synchronization possible issues: + # 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict + # instead of the updated encoder params from critic (which is optimized separately) + # 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params + # 3. Need to handle encoder params correctly for both actor and discrete_critic + # Potential fixes: + # - Send critic's encoder state when shared_encoder=True + # - Skip encoder params entirely when freeze_vision_encoder=True + # - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic) + + # Load actor state dict + actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device) + policy.actor.load_state_dict(actor_state_dict) + + # Load discrete critic if present + if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts: + discrete_critic_state_dict = move_state_dict_to_device( + state_dicts["discrete_critic"], device=device + ) + policy.discrete_critic.load_state_dict(discrete_critic_state_dict) + logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") ################################################# diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index d8830d83e..edd2363b1 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -1109,8 +1109,18 @@ def check_nan_in_transition( def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): logging.debug("[LEARNER] Pushing actor policy to the queue") - state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") - state_bytes = state_to_bytes(state_dict) + + # Create a dictionary to hold all the state dicts + state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")} + + # Add discrete critic if it exists + if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None: + state_dicts["discrete_critic"] = move_state_dict_to_device( + policy.discrete_critic.state_dict(), device="cpu" + ) + logging.debug("[LEARNER] Including discrete critic in state dict push") + + state_bytes = state_to_bytes(state_dicts) parameters_queue.put(state_bytes) diff --git a/src/lerobot/transport/services.proto b/src/lerobot/transport/services.proto index 89bfc107a..70f39741f 100644 --- a/src/lerobot/transport/services.proto +++ b/src/lerobot/transport/services.proto @@ -11,11 +11,11 @@ // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and -// limitations under the License. +// limitations under the License.python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto // To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: // -// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. src/lerobot/transport/services.proto +// python -m grpc_tools.protoc -I src --python_out=src --grpc_python_out=src src/lerobot/transport/services.proto // // The command should be launched from the root of the project. diff --git a/src/lerobot/transport/services_pb2.py b/src/lerobot/transport/services_pb2.py index 8a2137687..9e66ae1e3 100644 --- a/src/lerobot/transport/services_pb2.py +++ b/src/lerobot/transport/services_pb2.py @@ -1,6 +1,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE -# source: src/lerobot/transport/services.proto +# source: lerobot/transport/services.proto # Protobuf Python Version: 5.29.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -14,7 +14,7 @@ _runtime_version.ValidateProtobufRuntimeVersion( 29, 0, '', - 'src/lerobot/transport/services.proto' + 'lerobot/transport/services.proto' ) # @@protoc_insertion_point(imports) @@ -23,23 +23,23 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$src/lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'src.lerobot.transport.services_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=302 - _globals['_TRANSFERSTATE']._serialized_end=398 - _globals['_TRANSITION']._serialized_start=51 - _globals['_TRANSITION']._serialized_end=127 - _globals['_PARAMETERS']._serialized_start=129 - _globals['_PARAMETERS']._serialized_end=205 - _globals['_INTERACTIONMESSAGE']._serialized_start=207 - _globals['_INTERACTIONMESSAGE']._serialized_end=291 - _globals['_EMPTY']._serialized_start=293 - _globals['_EMPTY']._serialized_end=300 - _globals['_LEARNERSERVICE']._serialized_start=401 - _globals['_LEARNERSERVICE']._serialized_end=658 + _globals['_TRANSFERSTATE']._serialized_start=298 + _globals['_TRANSFERSTATE']._serialized_end=394 + _globals['_TRANSITION']._serialized_start=47 + _globals['_TRANSITION']._serialized_end=123 + _globals['_PARAMETERS']._serialized_start=125 + _globals['_PARAMETERS']._serialized_end=201 + _globals['_INTERACTIONMESSAGE']._serialized_start=203 + _globals['_INTERACTIONMESSAGE']._serialized_end=287 + _globals['_EMPTY']._serialized_start=289 + _globals['_EMPTY']._serialized_end=296 + _globals['_LEARNERSERVICE']._serialized_start=397 + _globals['_LEARNERSERVICE']._serialized_end=654 # @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/services_pb2_grpc.py b/src/lerobot/transport/services_pb2_grpc.py index a4fe8c576..77801a340 100644 --- a/src/lerobot/transport/services_pb2_grpc.py +++ b/src/lerobot/transport/services_pb2_grpc.py @@ -3,7 +3,7 @@ import grpc import warnings -from src.lerobot.transport import services_pb2 as src_dot_lerobot_dot_transport_dot_services__pb2 +from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2 GRPC_GENERATED_VERSION = '1.71.0' GRPC_VERSION = grpc.__version__ @@ -18,7 +18,7 @@ except ImportError: if _version_not_supported: raise RuntimeError( f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in src/lerobot/transport/services_pb2_grpc.py depends on' + + f' but the generated code in lerobot/transport/services_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' @@ -38,23 +38,23 @@ class LearnerServiceStub: """ self.StreamParameters = channel.unary_stream( '/transport.LearnerService/StreamParameters', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Parameters.FromString, _registered_method=True) self.SendTransitions = channel.stream_unary( '/transport.LearnerService/SendTransitions', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.SendInteractions = channel.stream_unary( '/transport.LearnerService/SendInteractions', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) self.Ready = channel.unary_unary( '/transport.LearnerService/Ready', - request_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - response_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, _registered_method=True) @@ -93,23 +93,23 @@ def add_LearnerServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'StreamParameters': grpc.unary_stream_rpc_method_handler( servicer.StreamParameters, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Parameters.SerializeToString, ), 'SendTransitions': grpc.stream_unary_rpc_method_handler( servicer.SendTransitions, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Transition.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'SendInteractions': grpc.stream_unary_rpc_method_handler( servicer.SendInteractions, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), 'Ready': grpc.unary_unary_rpc_method_handler( servicer.Ready, - request_deserializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, - response_serializer=src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -139,8 +139,8 @@ class LearnerService: request, target, '/transport.LearnerService/StreamParameters', - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Parameters.FromString, + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Parameters.FromString, options, channel_credentials, insecure, @@ -166,8 +166,8 @@ class LearnerService: request_iterator, target, '/transport.LearnerService/SendTransitions', - src_dot_lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -193,8 +193,8 @@ class LearnerService: request_iterator, target, '/transport.LearnerService/SendInteractions', - src_dot_lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, @@ -220,8 +220,8 @@ class LearnerService: request, target, '/transport.LearnerService/Ready', - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, - src_dot_lerobot_dot_transport_dot_services__pb2.Empty.FromString, + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, options, channel_credentials, insecure, From 30c161006d606700769fe8401b4f1bac3809ce9c Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Thu, 10 Jul 2025 10:39:11 +0200 Subject: [PATCH 005/158] Add Async Inference (#1196) Co-authored-by: Steven Palma Co-authored-by: Michel Aractingi --- docs/source/_toctree.yml | 2 + docs/source/async.mdx | 272 ++++++++++ pyproject.toml | 3 +- src/lerobot/scripts/server/configs.py | 197 +++++++ src/lerobot/scripts/server/constants.py | 29 + src/lerobot/scripts/server/helpers.py | 386 +++++++++++++ src/lerobot/scripts/server/policy_server.py | 403 ++++++++++++++ src/lerobot/scripts/server/robot_client.py | 509 ++++++++++++++++++ src/lerobot/transport/async_inference.proto | 59 ++ src/lerobot/transport/async_inference_pb2.py | 45 ++ .../transport/async_inference_pb2_grpc.py | 277 ++++++++++ tests/async_inference/test_e2e.py | 177 ++++++ tests/async_inference/test_helpers.py | 459 ++++++++++++++++ tests/async_inference/test_policy_server.py | 215 ++++++++ tests/async_inference/test_robot_client.py | 234 ++++++++ 15 files changed, 3266 insertions(+), 1 deletion(-) create mode 100644 docs/source/async.mdx create mode 100644 src/lerobot/scripts/server/configs.py create mode 100644 src/lerobot/scripts/server/constants.py create mode 100644 src/lerobot/scripts/server/helpers.py create mode 100644 src/lerobot/scripts/server/policy_server.py create mode 100644 src/lerobot/scripts/server/robot_client.py create mode 100644 src/lerobot/transport/async_inference.proto create mode 100644 src/lerobot/transport/async_inference_pb2.py create mode 100644 src/lerobot/transport/async_inference_pb2_grpc.py create mode 100644 tests/async_inference/test_e2e.py create mode 100644 tests/async_inference/test_helpers.py create mode 100644 tests/async_inference/test_policy_server.py create mode 100644 tests/async_inference/test_robot_client.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 83777a3c8..1af96d79d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -17,6 +17,8 @@ title: Train a Robot with RL - local: hilserl_sim title: Train RL in Simulation + - local: async + title: Use Async Inference title: "Tutorials" - sections: - local: smolvla diff --git a/docs/source/async.mdx b/docs/source/async.mdx new file mode 100644 index 000000000..0a0823cf0 --- /dev/null +++ b/docs/source/async.mdx @@ -0,0 +1,272 @@ +# Asynchronous Inference + +With our [SmolVLA](https://huggingface.co/papers/2506.01844) we introduced a new way to run inference on real-world robots, **decoupling action prediction from action execution**. +In this tutorial, we'll show how to use asynchronous inference (_async inference_) using a finetuned version of SmolVLA, and all the policies supported by LeRobot. +**Try async inference with all the policies** supported by LeRobot! + +**What you'll learn:** +1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference. +2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network. +3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy. + +If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)! + + +In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours. +This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions. + +--- +## Getting started with async inference + +You can read more information on asynchronous inference in our [blogpost](NOTE:blogpost). Here, we report a getting started guide meant to help you setup and run asynchronous inference in your setup. + +First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference. + +```shell +pip install -e ".[async]" +``` + +Then, spin up a policy server (in one terminal, or in a separate machine) specifying the host address and port for the client to connect to. +You can spin up a policy server running: + +```shell +python src/lerobot/scripts/server/policy_server.py \ + --host=127.0.0.1 \ + --port=8080 \ +``` + +This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with: + +```shell +python src/lerobot/scripts/server/robot_client.py \ + --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server + --robot.type=so100_follower \ # ROBOT: your robot type + --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port + --robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file + --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy + --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act` + --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc) + --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base) + --policy_device=mps \ # POLICY: the device to run the policy on, on the server + --actions_per_chunk=50 \ # POLICY: the number of actions to output at once + --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server + --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions + --debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime +``` +In summary, you need to specify instructions for: +- `SERVER`: the address and port of the policy server +- `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot +- `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value). +- `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters. + +Importantly, +- `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup. +- `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC)) +- `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters. + +Done! You should see your robot moving around by now 😉 +--- + +## Async vs. synchronous inference + +Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in *idle frames*, frames where the robot awaits idle the policy's output: a new action chunk. +In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions. +With robotics models increasing in size, this problem risks becoming only more severe. + +

+ +

+

Synchronous inference makes the robot idle while the policy is computing the next chunk of actions.

+ +To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames. +Crucially, with async inference, the next action chunk is computed *before* the current one is exhausted, resulting in no idleness. +Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop. + +

+ +

+

Asynchronous inference results in no idleness because the next chunk is computed before the current chunk is exhausted.

+ + +--- + +## Start the Policy Server + +Policy servers are wrappers around a `PreTrainedPolicy` interfacing them with observations coming from a robot client. +Policy servers are initialized as empty containers which are populated with the requested policy specified in the initial handshake between the robot client and the policy server. +As such, spinning up a policy server is as easy as specifying the host address and port. If you're running the policy server on the same machine as the robot client, you can use `localhost` as the host address. + + + +```bash +python -m lerobot.scripts.server.policy_server \ + --host="localhost" \ + --port=8080 +``` + + +```python +from lerobot.scripts.server.configs import PolicyServerConfig +from lerobot.scripts.server.policy_server import serve + +config = PolicyServerConfig( + host="localhost", + port=8080, +) +serve(config) +``` + + + +This listens on `localhost:8080` for an incoming connection from the associated`RobotClient`, which will communicate which policy to run during the first client-server handshake. + +--- + +## Launch the Robot Client + +`RobotClient` is a wrapper around a `Robot` instance, which `RobotClient` connects to the (possibly remote) `PolicyServer`. +The `RobotClient` streams observations to the `PolicyServer`, and receives action chunks obtained running inference on the server (which we assume to have better computational resources than the robot controller). + + + +```bash +python src/lerobot/scripts/server/robot_client.py \ + --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server + --robot.type=so100_follower \ # ROBOT: your robot type + --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port + --robot.id=follower_so100 \ # ROBOT: your robot id, to load calibration file + --robot.cameras="{ laptop: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}, phone: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ # POLICY: the cameras used to acquire frames, with keys matching the keys expected by the policy + --task="dummy" \ # POLICY: The task to run the policy on (`Fold my t-shirt`). Not necessarily defined for all policies, such as `act` + --policy_type=your_policy_type \ # POLICY: the type of policy to run (smolvla, act, etc) + --pretrained_name_or_path=user/model \ # POLICY: the model name/path on server to the checkpoint to run (e.g., lerobot/smolvla_base) + --policy_device=mps \ # POLICY: the device to run the policy on, on the server + --actions_per_chunk=50 \ # POLICY: the number of actions to output at once + --chunk_size_threshold=0.5 \ # CLIENT: the threshold for the chunk size before sending a new observation to the server + --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions + --debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime +``` + + +```python +import threading +from lerobot.robots.so100_follower import SO100FollowerConfig +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.scripts.server.configs import RobotClientConfig +from lerobot.scripts.server.robot_client import RobotClient +from lerobot.scripts.server.helpers import visualize_action_queue_size + +# 1. Create the robot instance +"""Check out the cameras available in your setup by running `python lerobot/find_cameras.py`""" +# these cameras must match the ones expected by the policy +# check the config.json on the Hub for the policy you are using +camera_cfg = { + "top": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30), + "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30) +} + +robot_cfg = SO100FollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="follower_so100", + cameras=camera_cfg +) + +# 3. Create client configuration +client_cfg = RobotClientConfig( + robot=robot_cfg, + server_address="localhost:8080", + policy_device="mps", + policy_type="smolvla", + pretrained_name_or_path="fracapuano/smolvla_async", + chunk_size_threshold=0.5, + actions_per_chunk=50, # make sure this is less than the max actions of the policy +) + +# 4. Create and start client +client = RobotClient(client_cfg) + +# 5. Specify the task +task = "Don't do anything, stay still" + +if client.start(): + # Start action receiver thread + action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True) + action_receiver_thread.start() + + try: + # Run the control loop + client.control_loop(task) + except KeyboardInterrupt: + client.stop() + action_receiver_thread.join() + # (Optionally) plot the action queue size + visualize_action_queue_size(client.action_queue_size) +``` + + + +The following two parameters are key in every setup: + + + + + + + + + + + + + + + + + + + + + +
HyperparameterDefaultWhat it does
actions_per_chunk50How many actions the policy outputs at once. Typical values: 10-50.
chunk_size_threshold0.7When the queue is ≤ 50% full, the client sends a fresh observation. Value in [0, 1].
+ + +Different values of `actions_per_chunk` and `chunk_size_threshold` do result in different behaviours. + + +On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed. +However, larger values of `actions_per_chunk` might also result in less precise actions, due to the compounding errors consequent to predicting actions over longer timespans. + +On the other hand, increasing the value of `chunk_size_threshold` will result in sending out to the `PolicyServer` observations for inference more often, resulting in a larger number of updates action chunks, overlapping on significant portions. This results in high adaptability, in the limit predicting one action chunk for each observation, which is in turn only marginally consumed while a new one is produced. +This option does also put more pressure on the inference pipeline, as a consequence of the many requests. Conversely, values of `chunk_size_threshold` close to 0.0 collapse to the synchronous edge case, whereby new observations are only sent out whenever the current chunk is exhausted. + +We found the default values of `actions_per_chunk` and `chunk_size_threshold` to work well in the experiments we developed for the [SmolVLA paper](https://huggingface.co/papers/2506.01844), but recommend experimenting with different values to find the best fit for your setup. + +### Tuning async inference for your setup + +1. **Choose your computational resources carefully.** [PI0](https://huggingface.co/lerobot/pi0) occupies 14GB of memory at inference time, while [SmolVLA](https://huggingface.co/lerobot/smolvla_base) requires only ~2GB. You should identify the best computational resource for your use case keeping in mind smaller policies require less computational resources. The combination of policy and device used (CPU-intensive, using MPS, or the number of CUDA cores on a given NVIDIA GPU) directly impacts the average inference latency you should expect. +2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue. +3. **Adjust `chunk_size_threshold`**. + - Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model). + - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup. + +

+ +

+

The action queue size is plotted at runtime when the `--debug-visualize-queue-size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).

+ + +--- + +## Conclusion + +Asynchronous inference represents a significant advancement in real-time robotics control, addressing the fundamental challenge of inference latency that has long plagued robotics applications. Through this tutorial, you've learned how to implement a complete async inference pipeline that eliminates idle frames and enables smoother, more reactive robot behaviors. + +**Key Takeaways:** + +- **Paradigm Shift**: Async inference decouples action prediction from execution, allowing robots to continue acting while new action chunks are computed in parallel +- **Performance Benefits**: Eliminates "wait-for-inference" lags that are inherent in synchronous approaches, becoming increasingly important as policy models grow larger +- **Flexible Architecture**: The server-client design enables distributed computing, where inference can run on powerful remote hardware while maintaining real-time robot control +- **Tunable Parameters**: Success depends on properly configuring `actions_per_chunk` and `chunk_size_threshold` for your specific hardware, policy, and task requirements +- **Universal Compatibility**: Works with all LeRobot-supported policies, from lightweight ACT models to vision-language models like SmolVLA + +Start experimenting with the default parameters, monitor your action queue sizes, and iteratively refine your setup to achieve optimal performance for your specific use case. +If you want to discuss this further, hop into our [Discord community](https://discord.gg/s3KuuzsPFb), or open an issue on our [GitHub repository](https://github.com/lerobot/lerobot/issues). diff --git a/pyproject.toml b/pyproject.toml index e13a9af01..81cb22a21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ classifiers = [ ] dependencies = [ "cmake>=3.29.0.1", - "datasets>=2.19.0", + "datasets>=2.19.0,<=3.6.0", "deepdiff>=7.0.1", "diffusers>=0.27.2", "draccus==0.10.0", @@ -105,6 +105,7 @@ hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio umi = ["imagecodecs>=2024.1.1"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] +async = ["grpcio==1.71.0", "matplotlib>=3.10.3"] [tool.poetry] requires-poetry = ">=2.1" diff --git a/src/lerobot/scripts/server/configs.py b/src/lerobot/scripts/server/configs.py new file mode 100644 index 000000000..7058842ae --- /dev/null +++ b/src/lerobot/scripts/server/configs.py @@ -0,0 +1,197 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Callable + +import torch + +from lerobot.robots.config import RobotConfig +from lerobot.scripts.server.constants import ( + DEFAULT_FPS, + DEFAULT_INFERENCE_LATENCY, + DEFAULT_OBS_QUEUE_TIMEOUT, +) + +# Aggregate function registry for CLI usage +AGGREGATE_FUNCTIONS = { + "weighted_average": lambda old, new: 0.3 * old + 0.7 * new, + "latest_only": lambda old, new: new, + "average": lambda old, new: 0.5 * old + 0.5 * new, + "conservative": lambda old, new: 0.7 * old + 0.3 * new, +} + + +def get_aggregate_function(name: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + """Get aggregate function by name from registry.""" + if name not in AGGREGATE_FUNCTIONS: + available = list(AGGREGATE_FUNCTIONS.keys()) + raise ValueError(f"Unknown aggregate function '{name}'. Available: {available}") + return AGGREGATE_FUNCTIONS[name] + + +@dataclass +class PolicyServerConfig: + """Configuration for PolicyServer. + + This class defines all configurable parameters for the PolicyServer, + including networking settings and action chunking specifications. + """ + + # Networking configuration + host: str = field(default="localhost", metadata={"help": "Host address to bind the server to"}) + port: int = field(default=8080, metadata={"help": "Port number to bind the server to"}) + + # Timing configuration + fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"}) + inference_latency: float = field( + default=DEFAULT_INFERENCE_LATENCY, metadata={"help": "Target inference latency in seconds"} + ) + + obs_queue_timeout: float = field( + default=DEFAULT_OBS_QUEUE_TIMEOUT, metadata={"help": "Timeout for observation queue in seconds"} + ) + + def __post_init__(self): + """Validate configuration after initialization.""" + if self.port < 1 or self.port > 65535: + raise ValueError(f"Port must be between 1 and 65535, got {self.port}") + + if self.environment_dt <= 0: + raise ValueError(f"environment_dt must be positive, got {self.environment_dt}") + + if self.inference_latency < 0: + raise ValueError(f"inference_latency must be non-negative, got {self.inference_latency}") + + if self.obs_queue_timeout < 0: + raise ValueError(f"obs_queue_timeout must be non-negative, got {self.obs_queue_timeout}") + + @classmethod + def from_dict(cls, config_dict: dict) -> "PolicyServerConfig": + """Create a PolicyServerConfig from a dictionary.""" + return cls(**config_dict) + + @property + def environment_dt(self) -> float: + """Environment time step, in seconds""" + return 1 / self.fps + + def to_dict(self) -> dict: + """Convert the configuration to a dictionary.""" + return { + "host": self.host, + "port": self.port, + "fps": self.fps, + "environment_dt": self.environment_dt, + "inference_latency": self.inference_latency, + } + + +@dataclass +class RobotClientConfig: + """Configuration for RobotClient. + + This class defines all configurable parameters for the RobotClient, + including network connection, policy settings, and control behavior. + """ + + # Policy configuration + policy_type: str = field(metadata={"help": "Type of policy to use"}) + pretrained_name_or_path: str = field(metadata={"help": "Pretrained model name or path"}) + + # Robot configuration (for CLI usage - robot instance will be created from this) + robot: RobotConfig = field(metadata={"help": "Robot configuration"}) + + # Policies typically output K actions at max, but we can use less to avoid wasting bandwidth (as actions + # would be aggregated on the client side anyway, depending on the value of `chunk_size_threshold`) + actions_per_chunk: int = field(metadata={"help": "Number of actions per chunk"}) + + # Task instruction for the robot to execute (e.g., 'fold my tshirt') + task: str = field(default="", metadata={"help": "Task instruction for the robot to execute"}) + + # Network configuration + server_address: str = field(default="localhost:8080", metadata={"help": "Server address to connect to"}) + + # Device configuration + policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) + + # Control behavior configuration + chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) + fps: int = field(default=DEFAULT_FPS, metadata={"help": "Frames per second"}) + + # Aggregate function configuration (CLI-compatible) + aggregate_fn_name: str = field( + default="weighted_average", + metadata={"help": f"Name of aggregate function to use. Options: {list(AGGREGATE_FUNCTIONS.keys())}"}, + ) + + # Debug configuration + debug_visualize_queue_size: bool = field( + default=False, metadata={"help": "Visualize the action queue size"} + ) + + # Verification configuration + verify_robot_cameras: bool = field( + default=True, metadata={"help": "Verify that the robot cameras match the policy cameras"} + ) + + @property + def environment_dt(self) -> float: + """Environment time step, in seconds""" + return 1 / self.fps + + def __post_init__(self): + """Validate configuration after initialization.""" + if not self.server_address: + raise ValueError("server_address cannot be empty") + + if not self.policy_type: + raise ValueError("policy_type cannot be empty") + + if not self.pretrained_name_or_path: + raise ValueError("pretrained_name_or_path cannot be empty") + + if not self.policy_device: + raise ValueError("policy_device cannot be empty") + + if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: + raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") + + if self.fps <= 0: + raise ValueError(f"fps must be positive, got {self.fps}") + + if self.actions_per_chunk <= 0: + raise ValueError(f"actions_per_chunk must be positive, got {self.actions_per_chunk}") + + self.aggregate_fn = get_aggregate_function(self.aggregate_fn_name) + + @classmethod + def from_dict(cls, config_dict: dict) -> "RobotClientConfig": + """Create a RobotClientConfig from a dictionary.""" + return cls(**config_dict) + + def to_dict(self) -> dict: + """Convert the configuration to a dictionary.""" + return { + "server_address": self.server_address, + "policy_type": self.policy_type, + "pretrained_name_or_path": self.pretrained_name_or_path, + "policy_device": self.policy_device, + "chunk_size_threshold": self.chunk_size_threshold, + "fps": self.fps, + "actions_per_chunk": self.actions_per_chunk, + "task": self.task, + "debug_visualize_queue_size": self.debug_visualize_queue_size, + "aggregate_fn_name": self.aggregate_fn_name, + } diff --git a/src/lerobot/scripts/server/constants.py b/src/lerobot/scripts/server/constants.py new file mode 100644 index 000000000..af983a800 --- /dev/null +++ b/src/lerobot/scripts/server/constants.py @@ -0,0 +1,29 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Client side: The environment evolves with a time resolution equal to 1/fps""" + +DEFAULT_FPS = 30 + +"""Server side: Running inference on (at most) 1/fps""" +DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS + +"""Server side: Timeout for observation queue in seconds""" +DEFAULT_OBS_QUEUE_TIMEOUT = 2 + +# All action chunking policies +SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"] + +# TODO: Add all other robots +SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"] diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/scripts/server/helpers.py new file mode 100644 index 000000000..7fd56e693 --- /dev/null +++ b/src/lerobot/scripts/server/helpers.py @@ -0,0 +1,386 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import logging +import logging.handlers +import os +import time +from dataclasses import dataclass +from pathlib import Path +from threading import Event +from typing import Any + +import torch + +from lerobot.configs.types import PolicyFeature +from lerobot.constants import OBS_IMAGES, OBS_STATE +from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features + +# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config +from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 +from lerobot.robots.robot import Robot +from lerobot.transport import async_inference_pb2 +from lerobot.transport.utils import bytes_buffer_size +from lerobot.utils.utils import init_logging + +Action = torch.Tensor +ActionChunk = torch.Tensor + +# observation as received from the robot +RawObservation = dict[str, torch.Tensor] + +# observation as those recorded in LeRobot dataset (keys are different) +LeRobotObservation = dict[str, torch.Tensor] + +# observation, ready for policy inference (image keys resized) +Observation = dict[str, torch.Tensor] + + +def visualize_action_queue_size(action_queue_size: list[int]) -> None: + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.set_title("Action Queue Size Over Time") + ax.set_xlabel("Environment steps") + ax.set_ylabel("Action Queue Size") + ax.set_ylim(0, max(action_queue_size) * 1.1) + ax.grid(True, alpha=0.3) + ax.plot(range(len(action_queue_size)), action_queue_size) + plt.show() + + +def validate_robot_cameras_for_policy( + lerobot_observation_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature] +) -> None: + image_keys = list(filter(is_image_key, lerobot_observation_features)) + assert set(image_keys) == set(policy_image_features.keys()), ( + f"Policy image features must match robot cameras! Received {list(policy_image_features.keys())} != {image_keys}" + ) + + +def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: + return hw_to_dataset_features(robot.observation_features, "observation", use_video=False) + + +def is_image_key(k: str) -> bool: + return k.startswith(OBS_IMAGES) + + +def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, int, int]) -> torch.tensor: + assert image.ndim == 3, f"Image must be (C, H, W)! Received {image.shape}" + # (H, W, C) -> (C, H, W) for resizing from robot obsevation resolution to policy image resolution + image = image.permute(2, 0, 1) + dims = (resize_dims[1], resize_dims[2]) + # Add batch dimension for interpolate: (C, H, W) -> (1, C, H, W) + image_batched = image.unsqueeze(0) + # Interpolate and remove batch dimension: (1, C, H, W) -> (C, H, W) + resized = torch.nn.functional.interpolate(image_batched, size=dims, mode="bilinear", align_corners=False) + + return resized.squeeze(0) + + +def raw_observation_to_observation( + raw_observation: RawObservation, + lerobot_features: dict[str, dict], + policy_image_features: dict[str, PolicyFeature], + device: str, +) -> Observation: + observation = {} + + observation = prepare_raw_observation(raw_observation, lerobot_features, policy_image_features) + for k, v in observation.items(): + if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations + if "image" in k: + # Policy expects images in shape (B, C, H, W) + observation[k] = prepare_image(v).unsqueeze(0).to(device) + else: + observation[k] = v.to(device) + else: + observation[k] = v + + return observation + + +def prepare_image(image: torch.Tensor) -> torch.Tensor: + """Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor""" + image = image.type(torch.float32) / 255 + image = image.contiguous() + + return image + + +def extract_state_from_raw_observation( + lerobot_obs: RawObservation, +) -> torch.Tensor: + """Extract the state from a raw observation.""" + state = torch.tensor(lerobot_obs[OBS_STATE]) + + if state.ndim == 1: + state = state.unsqueeze(0) + + return state + + +def extract_images_from_raw_observation( + lerobot_obs: RawObservation, + camera_key: str, +) -> dict[str, torch.Tensor]: + """Extract the images from a raw observation.""" + return torch.tensor(lerobot_obs[camera_key]) + + +def make_lerobot_observation( + robot_obs: RawObservation, + lerobot_features: dict[str, dict], +) -> LeRobotObservation: + """Make a lerobot observation from a raw observation.""" + return build_dataset_frame(lerobot_features, robot_obs, prefix="observation") + + +def prepare_raw_observation( + robot_obs: RawObservation, + lerobot_features: dict[str, dict], + policy_image_features: dict[str, PolicyFeature], +) -> Observation: + """Matches keys from the raw robot_obs dict to the keys expected by a given policy (passed as + policy_image_features).""" + # 1. {motor.pos1:value1, motor.pos2:value2, ..., laptop:np.ndarray} -> + # -> {observation.state:[value1,value2,...], observation.images.laptop:np.ndarray} + lerobot_obs = make_lerobot_observation(robot_obs, lerobot_features) + + # 2. Greps all observation.images.<> keys + image_keys = list(filter(is_image_key, lerobot_obs)) + # state's shape is expected as (B, state_dim) + state_dict = {OBS_STATE: extract_state_from_raw_observation(lerobot_obs)} + image_dict = { + image_k: extract_images_from_raw_observation(lerobot_obs, image_k) for image_k in image_keys + } + + # Turns the image features to (C, H, W) with H, W matching the policy image features. + # This reduces the resolution of the images + image_dict = { + key: resize_robot_observation_image(torch.tensor(lerobot_obs[key]), policy_image_features[key].shape) + for key in image_keys + } + + if "task" in robot_obs: + state_dict["task"] = robot_obs["task"] + + return {**state_dict, **image_dict} + + +def get_logger(name: str, log_to_file: bool = True) -> logging.Logger: + """ + Get a logger using the standardized logging setup from utils.py. + + Args: + name: Logger name (e.g., 'policy_server', 'robot_client') + log_to_file: Whether to also log to a file + + Returns: + Configured logger instance + """ + # Create logs directory if logging to file + if log_to_file: + os.makedirs("logs", exist_ok=True) + log_file = Path(f"logs/{name}_{int(time.time())}.log") + else: + log_file = None + + # Initialize the standardized logging + init_logging(log_file=log_file, display_pid=False) + + # Return a named logger + return logging.getLogger(name) + + +@dataclass +class TimedData: + """A data object with timestamp and timestep information. + + Args: + timestamp: Unix timestamp relative to data's creation. + data: The actual data to wrap a timestamp around. + timestep: The timestep of the data. + """ + + timestamp: float + timestep: int + + def get_timestamp(self): + return self.timestamp + + def get_timestep(self): + return self.timestep + + +@dataclass +class TimedAction(TimedData): + action: Action + + def get_action(self): + return self.action + + +@dataclass +class TimedObservation(TimedData): + observation: RawObservation + must_go: bool = False + + def get_observation(self): + return self.observation + + +@dataclass +class FPSTracker: + """Utility class to track FPS metrics over time.""" + + target_fps: float + first_timestamp: float = None + total_obs_count: int = 0 + + def calculate_fps_metrics(self, current_timestamp: float) -> dict[str, float]: + """Calculate average FPS vs target""" + self.total_obs_count += 1 + + # Initialize first observation time + if self.first_timestamp is None: + self.first_timestamp = current_timestamp + + # Calculate overall average FPS (since start) + total_duration = current_timestamp - self.first_timestamp + avg_fps = (self.total_obs_count - 1) / total_duration if total_duration > 1e-6 else 0.0 + + return {"avg_fps": avg_fps, "target_fps": self.target_fps} + + def reset(self): + """Reset the FPS tracker state""" + self.first_timestamp = None + self.total_obs_count = 0 + + +@dataclass +class RemotePolicyConfig: + policy_type: str + pretrained_name_or_path: str + lerobot_features: dict[str, PolicyFeature] + actions_per_chunk: int + device: str = "cpu" + + +def _compare_observation_states(obs1_state: torch.Tensor, obs2_state: torch.Tensor, atol: float) -> bool: + """Check if two observation states are similar, under a tolerance threshold""" + return bool(torch.linalg.norm(obs1_state - obs2_state) < atol) + + +def observations_similar( + obs1: TimedObservation, obs2: TimedObservation, lerobot_features: dict[str, dict], atol: float = 1 +) -> bool: + """Check if two observations are similar, under a tolerance threshold. Measures distance between + observations as the difference in joint-space between the two observations. + + NOTE(fracapuano): This is a very simple check, and it is enough for the current use case. + An immediate next step is to use (fast) perceptual difference metrics comparing some camera views, + to surpass this joint-space similarity check. + """ + obs1_state = extract_state_from_raw_observation( + make_lerobot_observation(obs1.get_observation(), lerobot_features) + ) + obs2_state = extract_state_from_raw_observation( + make_lerobot_observation(obs2.get_observation(), lerobot_features) + ) + + return _compare_observation_states(obs1_state, obs2_state, atol=atol) + + +def send_bytes_in_chunks( + buffer: bytes, + message_class: Any, + log_prefix: str = "", + silent: bool = True, + chunk_size: int = 3 * 1024 * 1024, +): + # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we + # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the + # chunk size as I am using it to send image observations. + buffer = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(buffer) + + sent_bytes = 0 + + logging_method = logging.info if not silent else logging.debug + + logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") + + while sent_bytes < size_in_bytes: + transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE + + if sent_bytes + chunk_size >= size_in_bytes: + transfer_state = async_inference_pb2.TransferState.TRANSFER_END + elif sent_bytes == 0: + transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN + + size_to_read = min(chunk_size, size_in_bytes - sent_bytes) + chunk = buffer.read(size_to_read) + + yield message_class(transfer_state=transfer_state, data=chunk) + sent_bytes += size_to_read + logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") + + logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") + + +def receive_bytes_in_chunks( + iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = "" +): # type: ignore + # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we + # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving + # is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown) + bytes_buffer = io.BytesIO() + step = 0 + + logger.info(f"{log_prefix} Starting receiver") + for item in iterator: + logger.debug(f"{log_prefix} Received item") + if not continue_receiving.is_set(): + logger.info(f"{log_prefix} Shutting down receiver") + return + + if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN: + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + bytes_buffer.write(item.data) + logger.debug(f"{log_prefix} Received data at step 0") + + elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE: + bytes_buffer.write(item.data) + step += 1 + logger.debug(f"{log_prefix} Received data at step {step}") + + elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END: + bytes_buffer.write(item.data) + logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") + + complete_bytes = bytes_buffer.getvalue() + + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + + logger.debug(f"{log_prefix} Queue updated") + return complete_bytes + + else: + logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") + raise ValueError(f"Received unknown transfer state {item.transfer_state}") diff --git a/src/lerobot/scripts/server/policy_server.py b/src/lerobot/scripts/server/policy_server.py new file mode 100644 index 000000000..669ccc58e --- /dev/null +++ b/src/lerobot/scripts/server/policy_server.py @@ -0,0 +1,403 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: +```shell +python src/lerobot/scripts/server/policy_server.py \ + --host=127.0.0.1 \ + --port=8080 \ + --fps=30 \ + --inference_latency=0.033 \ + --obs_queue_timeout=1 +``` +""" + +import logging +import pickle # nosec +import threading +import time +from concurrent import futures +from dataclasses import asdict +from pprint import pformat +from queue import Empty, Queue + +import draccus +import grpc +import torch + +from lerobot.policies.factory import get_policy_class +from lerobot.scripts.server.configs import PolicyServerConfig +from lerobot.scripts.server.constants import SUPPORTED_POLICIES +from lerobot.scripts.server.helpers import ( + FPSTracker, + Observation, + RemotePolicyConfig, + TimedAction, + TimedObservation, + get_logger, + observations_similar, + raw_observation_to_observation, + receive_bytes_in_chunks, +) +from lerobot.transport import ( + async_inference_pb2, # type: ignore + async_inference_pb2_grpc, # type: ignore +) + + +class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): + prefix = "policy_server" + logger = get_logger(prefix) + + def __init__(self, config: PolicyServerConfig): + self.config = config + self._running_event = threading.Event() + + # FPS measurement + self.fps_tracker = FPSTracker(target_fps=config.fps) + + self.observation_queue = Queue(maxsize=1) + + self._predicted_timesteps_lock = threading.Lock() + self._predicted_timesteps = set() + + self.last_processed_obs = None + + # Attributes will be set by SendPolicyInstructions + self.device = None + self.policy_type = None + self.lerobot_features = None + self.actions_per_chunk = None + self.policy = None + + @property + def running(self): + return self._running_event.is_set() + + @property + def policy_image_features(self): + return self.policy.config.image_features + + def _reset_server(self) -> None: + """Flushes server state when new client connects.""" + # only running inference on the latest observation received by the server + self._running_event.clear() + self.observation_queue = Queue(maxsize=1) + + with self._predicted_timesteps_lock: + self._predicted_timesteps = set() + + def Ready(self, request, context): # noqa: N802 + client_id = context.peer() + self.logger.info(f"Client {client_id} connected and ready") + self._reset_server() + self._running_event.set() + + return async_inference_pb2.Empty() + + def SendPolicyInstructions(self, request, context): # noqa: N802 + """Receive policy instructions from the robot client""" + + if not self.running: + self.logger.warning("Server is not running. Ignoring policy instructions.") + return async_inference_pb2.Empty() + + client_id = context.peer() + + policy_specs = pickle.loads(request.data) # nosec + + if not isinstance(policy_specs, RemotePolicyConfig): + raise TypeError(f"Policy specs must be a RemotePolicyConfig. Got {type(policy_specs)}") + + if policy_specs.policy_type not in SUPPORTED_POLICIES: + raise ValueError( + f"Policy type {policy_specs.policy_type} not supported. " + f"Supported policies: {SUPPORTED_POLICIES}" + ) + + self.logger.info( + f"Receiving policy instructions from {client_id} | " + f"Policy type: {policy_specs.policy_type} | " + f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | " + f"Actions per chunk: {policy_specs.actions_per_chunk} | " + f"Device: {policy_specs.device}" + ) + + self.device = policy_specs.device + self.policy_type = policy_specs.policy_type # act, pi0, etc. + self.lerobot_features = policy_specs.lerobot_features + self.actions_per_chunk = policy_specs.actions_per_chunk + + policy_class = get_policy_class(self.policy_type) + + start = time.perf_counter() + self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) + self.policy.to(self.device) + end = time.perf_counter() + + self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") + + return async_inference_pb2.Empty() + + def SendObservations(self, request_iterator, context): # noqa: N802 + """Receive observations from the robot client""" + client_id = context.peer() + self.logger.debug(f"Receiving observations from {client_id}") + + receive_time = time.time() # comparing timestamps so need time.time() + start_deserialize = time.perf_counter() + received_bytes = receive_bytes_in_chunks( + request_iterator, self._running_event, self.logger + ) # blocking call while looping over request_iterator + timed_observation = pickle.loads(received_bytes) # nosec + deserialize_time = time.perf_counter() - start_deserialize + + self.logger.debug(f"Received observation #{timed_observation.get_timestep()}") + + obs_timestep = timed_observation.get_timestep() + obs_timestamp = timed_observation.get_timestamp() + + # Calculate FPS metrics + fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp) + + self.logger.info( + f"Received observation #{obs_timestep} | " + f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client + f"Target: {fps_metrics['target_fps']:.2f} | " + f"One-way latency: {(receive_time - obs_timestamp) * 1000:.2f}ms" + ) + + self.logger.debug( + f"Server timestamp: {receive_time:.6f} | " + f"Client timestamp: {obs_timestamp:.6f} | " + f"Deserialization time: {deserialize_time:.6f}s" + ) + + if not self._enqueue_observation( + timed_observation # wrapping a RawObservation + ): + self.logger.info(f"Observation #{obs_timestep} has been filtered out") + + return async_inference_pb2.Empty() + + def GetActions(self, request, context): # noqa: N802 + """Returns actions to the robot client. Actions are sent as a single + chunk, containing multiple actions.""" + client_id = context.peer() + self.logger.debug(f"Client {client_id} connected for action streaming") + + # Generate action based on the most recent observation and its timestep + try: + getactions_starts = time.perf_counter() + obs = self.observation_queue.get(timeout=self.config.obs_queue_timeout) + self.logger.info( + f"Running inference for observation #{obs.get_timestep()} (must_go: {obs.must_go})" + ) + + with self._predicted_timesteps_lock: + self._predicted_timesteps.add(obs.get_timestep()) + + start_time = time.perf_counter() + action_chunk = self._predict_action_chunk(obs) + inference_time = time.perf_counter() - start_time + + start_time = time.perf_counter() + actions_bytes = pickle.dumps(action_chunk) # nosec + serialize_time = time.perf_counter() - start_time + + # Create and return the action chunk + actions = async_inference_pb2.Actions(data=actions_bytes) + + self.logger.info( + f"Action chunk #{obs.get_timestep()} generated | " + f"Total time: {(inference_time + serialize_time) * 1000:.2f}ms" + ) + + self.logger.debug( + f"Action chunk #{obs.get_timestep()} generated | " + f"Inference time: {inference_time:.2f}s |" + f"Serialize time: {serialize_time:.2f}s |" + f"Total time: {inference_time + serialize_time:.2f}s" + ) + + time.sleep( + max(0, self.config.inference_latency - max(0, time.perf_counter() - getactions_starts)) + ) # sleep controls inference latency + + return actions + + except Empty: # no observation added to queue in obs_queue_timeout + return async_inference_pb2.Empty() + + except Exception as e: + self.logger.error(f"Error in StreamActions: {e}") + + return async_inference_pb2.Empty() + + def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool: + """Check if the observation is valid to be processed by the policy""" + with self._predicted_timesteps_lock: + predicted_timesteps = self._predicted_timesteps + + if obs.get_timestep() in predicted_timesteps: + self.logger.debug(f"Skipping observation #{obs.get_timestep()} - Timestep predicted already!") + return False + + elif observations_similar(obs, previous_obs, lerobot_features=self.lerobot_features): + self.logger.debug( + f"Skipping observation #{obs.get_timestep()} - Observation too similar to last obs predicted!" + ) + return False + + else: + return True + + def _enqueue_observation(self, obs: TimedObservation) -> bool: + """Enqueue an observation if it must go through processing, otherwise skip it. + Observations not in queue are never run through the policy network""" + + if ( + obs.must_go + or self.last_processed_obs is None + or self._obs_sanity_checks(obs, self.last_processed_obs) + ): + last_obs = self.last_processed_obs.get_timestep() if self.last_processed_obs else "None" + self.logger.debug( + f"Enqueuing observation. Must go: {obs.must_go} | Last processed obs: {last_obs}" + ) + + # If queue is full, get the old observation to make room + if self.observation_queue.full(): + # pops from queue + _ = self.observation_queue.get_nowait() + self.logger.debug("Observation queue was full, removed oldest observation") + + # Now put the new observation (never blocks as queue is non-full here) + self.observation_queue.put(obs) + return True + + return False + + def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]: + """Turn a chunk of actions into a list of TimedAction instances, + with the first action corresponding to t_0 and the rest corresponding to + t_0 + i*environment_dt for i in range(len(action_chunk)) + """ + return [ + TimedAction(timestamp=t_0 + i * self.config.environment_dt, timestep=i_0 + i, action=action) + for i, action in enumerate(action_chunk) + ] + + def _prepare_observation(self, observation_t: TimedObservation) -> Observation: + """ + Prepare observation, ready for policy inference. + E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the + client and then convert them to float32 [0,1] images here, before running inference. + """ + # RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape + observation: Observation = raw_observation_to_observation( + observation_t.get_observation(), + self.lerobot_features, + self.policy_image_features, + self.device, + ) + # processed Observation - right keys, right dtype, right image shape + + return observation + + def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: + """Get an action chunk from the policy. The chunk contains only""" + chunk = self.policy.predict_action_chunk(observation) + if chunk.ndim != 3: + chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim) + + return chunk[:, : self.actions_per_chunk, :] + torch.randn_like(chunk[:, : self.actions_per_chunk, :]) + + def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: + """Predict an action chunk based on an observation""" + inference_starts = time.perf_counter() + + """1. Prepare observation""" + start_time = time.perf_counter() + observation = self._prepare_observation(observation_t) + preprocessing_time = time.perf_counter() - start_time + + self.last_processed_obs: TimedObservation = observation_t + + """2. Get action chunk""" + start_time = time.perf_counter() + action_tensor = self._get_action_chunk(observation) + inference_time = time.perf_counter() - start_time + + """3. Post-inference processing""" + start_time = time.perf_counter() + # Move to CPU before serializing + action_tensor = action_tensor.cpu().squeeze(0) + + action_chunk = self._time_action_chunk( + observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() + ) + postprocessing_time = time.perf_counter() - start_time + inference_stops = time.perf_counter() + + self.logger.info( + f"Observation {observation_t.get_timestep()} |" + f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms" + ) + + # full-process latency breakdown for debugging purposes + self.logger.debug( + f"Observation {observation_t.get_timestep()} | " + f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | " + f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | " + f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | " + f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms" + ) + + return action_chunk + + def stop(self): + """Stop the server""" + self._reset_server() + self.logger.info("Server stopping...") + + +@draccus.wrap() +def serve(cfg: PolicyServerConfig): + """Start the PolicyServer with the given configuration. + + Args: + config: PolicyServerConfig instance. If None, uses default configuration. + """ + logging.info(pformat(asdict(cfg))) + + # Create the server instance first + policy_server = PolicyServer(cfg) + + # Setup and start gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + server.add_insecure_port(f"{cfg.host}:{cfg.port}") + + policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}") + server.start() + + server.wait_for_termination() + + policy_server.logger.info("Server terminated") + + +if __name__ == "__main__": + serve() diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py new file mode 100644 index 000000000..a6d7b7242 --- /dev/null +++ b/src/lerobot/scripts/server/robot_client.py @@ -0,0 +1,509 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example command: +```shell +python src/lerobot/scripts/server/robot_client.py \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ + --robot.id=black \ + --task="dummy" \ + --server_address=127.0.0.1:8080 \ + --policy_type=act \ + --pretrained_name_or_path=user/model \ + --policy_device=mps \ + --actions_per_chunk=50 \ + --chunk_size_threshold=0.5 \ + --aggregate_fn_name=weighted_average \ + --debug_visualize_queue_size=True +``` +""" + +import logging +import pickle # nosec +import threading +import time +from dataclasses import asdict +from pprint import pformat +from queue import Queue +from typing import Any, Callable, Optional + +import draccus +import grpc +import torch + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs.policies import PreTrainedConfig +from lerobot.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.scripts.server.configs import RobotClientConfig +from lerobot.scripts.server.constants import SUPPORTED_ROBOTS +from lerobot.scripts.server.helpers import ( + Action, + FPSTracker, + Observation, + RawObservation, + RemotePolicyConfig, + TimedAction, + TimedObservation, + get_logger, + map_robot_keys_to_lerobot_features, + send_bytes_in_chunks, + validate_robot_cameras_for_policy, + visualize_action_queue_size, +) +from lerobot.transport import ( + async_inference_pb2, # type: ignore + async_inference_pb2_grpc, # type: ignore +) + + +class RobotClient: + prefix = "robot_client" + logger = get_logger(prefix) + + def __init__(self, config: RobotClientConfig): + """Initialize RobotClient with unified configuration. + + Args: + config: RobotClientConfig containing all configuration parameters + """ + # Store configuration + self.config = config + self.robot = make_robot_from_config(config.robot) + self.robot.connect() + + lerobot_features = map_robot_keys_to_lerobot_features(self.robot) + + if config.verify_robot_cameras: + # Load policy config for validation + policy_config = PreTrainedConfig.from_pretrained(config.pretrained_name_or_path) + policy_image_features = policy_config.image_features + + # The cameras specified for inference must match the one supported by the policy chosen + validate_robot_cameras_for_policy(lerobot_features, policy_image_features) + + # Use environment variable if server_address is not provided in config + self.server_address = config.server_address + + self.policy_config = RemotePolicyConfig( + config.policy_type, + config.pretrained_name_or_path, + lerobot_features, + config.actions_per_chunk, + config.policy_device, + ) + self.channel = grpc.insecure_channel(self.server_address) + self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + self.logger.info(f"Initializing client to connect to server at {self.server_address}") + + self._running_event = threading.Event() + + # Initialize client side variables + self.latest_action_lock = threading.Lock() + self.latest_action = -1 + self.action_chunk_size = -1 + + self._chunk_size_threshold = config.chunk_size_threshold + + self.action_queue = Queue() + self.action_queue_lock = threading.Lock() # Protect queue operations + self.action_queue_size = [] + self.start_barrier = threading.Barrier(2) # 2 threads: action receiver, control loop + + # FPS measurement + self.fps_tracker = FPSTracker(target_fps=self.config.fps) + + self.logger.info("Robot connected and ready") + + # Use an event for thread-safe coordination + self.must_go = threading.Event() + self.must_go.set() # Initially set - observations qualify for direct processing + + @property + def running(self): + return self._running_event.is_set() + + def start(self): + """Start the robot client and connect to the policy server""" + try: + # client-server handshake + start_time = time.perf_counter() + self.stub.Ready(async_inference_pb2.Empty()) + end_time = time.perf_counter() + self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s") + + # send policy instructions + policy_config_bytes = pickle.dumps(self.policy_config) + policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes) + + self.logger.info("Sending policy instructions to policy server") + self.logger.debug( + f"Policy type: {self.policy_config.policy_type} | " + f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | " + f"Device: {self.policy_config.device}" + ) + + self.stub.SendPolicyInstructions(policy_setup) + + self._running_event.set() + + return True + + except grpc.RpcError as e: + self.logger.error(f"Failed to connect to policy server: {e}") + return False + + def stop(self): + """Stop the robot client""" + self._running_event.clear() + + self.robot.disconnect() + self.logger.debug("Robot disconnected") + + self.channel.close() + self.logger.debug("Client stopped, channel closed") + + def send_observation( + self, + obs: TimedObservation, + ) -> bool: + """Send observation to the policy server. + Returns True if the observation was sent successfully, False otherwise.""" + if not self.running: + raise RuntimeError("Client not running. Run RobotClient.start() before sending observations.") + + if not isinstance(obs, TimedObservation): + raise ValueError("Input observation needs to be a TimedObservation!") + + start_time = time.perf_counter() + observation_bytes = pickle.dumps(obs) + serialize_time = time.perf_counter() - start_time + self.logger.debug(f"Observation serialization time: {serialize_time:.6f}s") + + try: + observation_iterator = send_bytes_in_chunks( + observation_bytes, + async_inference_pb2.Observation, + log_prefix="[CLIENT] Observation", + silent=True, + ) + _ = self.stub.SendObservations(observation_iterator) + obs_timestep = obs.get_timestep() + self.logger.info(f"Sent observation #{obs_timestep} | ") + + return True + + except grpc.RpcError as e: + self.logger.error(f"Error sending observation #{obs.get_timestep()}: {e}") + return False + + def _inspect_action_queue(self): + with self.action_queue_lock: + queue_size = self.action_queue.qsize() + timestamps = sorted([action.get_timestep() for action in self.action_queue.queue]) + self.logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}") + return queue_size, timestamps + + def _aggregate_action_queues( + self, + incoming_actions: list[TimedAction], + aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + """Finds the same timestep actions in the queue and aggregates them using the aggregate_fn""" + if aggregate_fn is None: + # default aggregate function: take the latest action + def aggregate_fn(x1, x2): + return x2 + + future_action_queue = Queue() + with self.action_queue_lock: + internal_queue = self.action_queue.queue + + current_action_queue = {action.get_timestep(): action.get_action() for action in internal_queue} + + for new_action in incoming_actions: + with self.latest_action_lock: + latest_action = self.latest_action + + # New action is older than the latest action in the queue, skip it + if new_action.get_timestep() <= latest_action: + continue + + # If the new action's timestep is not in the current action queue, add it directly + elif new_action.get_timestep() not in current_action_queue: + future_action_queue.put(new_action) + continue + + # If the new action's timestep is in the current action queue, aggregate it + # TODO: There is probably a way to do this with broadcasting of the two action tensors + future_action_queue.put( + TimedAction( + timestamp=new_action.get_timestamp(), + timestep=new_action.get_timestep(), + action=aggregate_fn( + current_action_queue[new_action.get_timestep()], new_action.get_action() + ), + ) + ) + + with self.action_queue_lock: + self.action_queue = future_action_queue + + def receive_actions(self, verbose: bool = False): + """Receive actions from the policy server""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + self.logger.info("Action receiving thread starting") + + while self.running: + try: + # Use StreamActions to get a stream of actions from the server + actions_chunk = self.stub.GetActions(async_inference_pb2.Empty()) + if len(actions_chunk.data) == 0: + continue # received `Empty` from server, wait for next call + + receive_time = time.time() + + # Deserialize bytes back into list[TimedAction] + deserialize_start = time.perf_counter() + timed_actions = pickle.loads(actions_chunk.data) # nosec + deserialize_time = time.perf_counter() - deserialize_start + + self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) + + # Calculate network latency if we have matching observations + if len(timed_actions) > 0 and verbose: + with self.latest_action_lock: + latest_action = self.latest_action + + self.logger.debug(f"Current latest action: {latest_action}") + + # Get queue state before changes + old_size, old_timesteps = self._inspect_action_queue() + if not old_timesteps: + old_timesteps = [latest_action] # queue was empty + + # Get queue state before changes + old_size, old_timesteps = self._inspect_action_queue() + if not old_timesteps: + old_timesteps = [latest_action] # queue was empty + + # Log incoming actions + incoming_timesteps = [a.get_timestep() for a in timed_actions] + + first_action_timestep = timed_actions[0].get_timestep() + server_to_client_latency = (receive_time - timed_actions[0].get_timestamp()) * 1000 + + self.logger.info( + f"Received action chunk for step #{first_action_timestep} | " + f"Latest action: #{latest_action} | " + f"Incoming actions: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | " + f"Network latency (server->client): {server_to_client_latency:.2f}ms | " + f"Deserialization time: {deserialize_time * 1000:.2f}ms" + ) + + # Update action queue + start_time = time.perf_counter() + self._aggregate_action_queues(timed_actions, self.config.aggregate_fn) + queue_update_time = time.perf_counter() - start_time + + self.must_go.set() # after receiving actions, next empty queue triggers must-go processing! + + if verbose: + # Get queue state after changes + new_size, new_timesteps = self._inspect_action_queue() + + with self.latest_action_lock: + latest_action = self.latest_action + + self.logger.info( + f"Latest action: {latest_action} | " + f"Old action steps: {old_timesteps[0]}:{old_timesteps[-1]} | " + f"Incoming action steps: {incoming_timesteps[0]}:{incoming_timesteps[-1]} | " + f"Updated action steps: {new_timesteps[0]}:{new_timesteps[-1]}" + ) + self.logger.debug( + f"Queue update complete ({queue_update_time:.6f}s) | " + f"Before: {old_size} items | " + f"After: {new_size} items | " + ) + + except grpc.RpcError as e: + self.logger.error(f"Error receiving actions: {e}") + + def actions_available(self): + """Check if there are actions available in the queue""" + with self.action_queue_lock: + return not self.action_queue.empty() + + def _action_tensor_to_action_dict(self, action_tensor: torch.Tensor) -> dict[str, float]: + action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} + return action + + def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: + """Reading and performing actions in local queue""" + + # Lock only for queue operations + get_start = time.perf_counter() + with self.action_queue_lock: + self.action_queue_size.append(self.action_queue.qsize()) + # Get action from queue + timed_action = self.action_queue.get_nowait() + get_end = time.perf_counter() - get_start + + _performed_action = self.robot.send_action( + self._action_tensor_to_action_dict(timed_action.get_action()) + ) + with self.latest_action_lock: + self.latest_action = timed_action.get_timestep() + + if verbose: + with self.action_queue_lock: + current_queue_size = self.action_queue.qsize() + + self.logger.debug( + f"Ts={timed_action.get_timestamp()} | " + f"Action #{timed_action.get_timestep()} performed | " + f"Queue size: {current_queue_size}" + ) + + self.logger.debug( + f"Popping action from queue to perform took {get_end:.6f}s | Queue size: {current_queue_size}" + ) + + return _performed_action + + def _ready_to_send_observation(self): + """Flags when the client is ready to send an observation""" + with self.action_queue_lock: + return self.action_queue.qsize() / self.action_chunk_size <= self._chunk_size_threshold + + def control_loop_observation(self, task: str, verbose: bool = False) -> RawObservation: + try: + # Get serialized observation bytes from the function + start_time = time.perf_counter() + + raw_observation: RawObservation = self.robot.get_observation() + raw_observation["task"] = task + + with self.latest_action_lock: + latest_action = self.latest_action + + observation = TimedObservation( + timestamp=time.time(), # need time.time() to compare timestamps across client and server + observation=raw_observation, + timestep=max(latest_action, 0), + ) + + obs_capture_time = time.perf_counter() - start_time + + # If there are no actions left in the queue, the observation must go through processing! + with self.action_queue_lock: + observation.must_go = self.must_go.is_set() and self.action_queue.empty() + current_queue_size = self.action_queue.qsize() + + _ = self.send_observation(observation) + + self.logger.debug(f"QUEUE SIZE: {current_queue_size} (Must go: {observation.must_go})") + if observation.must_go: + # must-go event will be set again after receiving actions + self.must_go.clear() + + if verbose: + # Calculate comprehensive FPS metrics + fps_metrics = self.fps_tracker.calculate_fps_metrics(observation.get_timestamp()) + + self.logger.info( + f"Obs #{observation.get_timestep()} | " + f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " + f"Target: {fps_metrics['target_fps']:.2f}" + ) + + self.logger.debug( + f"Ts={observation.get_timestamp():.6f} | Capturing observation took {obs_capture_time:.6f}s" + ) + + return raw_observation + + except Exception as e: + self.logger.error(f"Error in observation sender: {e}") + + def control_loop(self, task: str, verbose: bool = False) -> tuple[Observation, Action]: + """Combined function for executing actions and streaming observations""" + # Wait at barrier for synchronized start + self.start_barrier.wait() + self.logger.info("Control loop thread starting") + + _performed_action = None + _captured_observation = None + + while self.running: + control_loop_start = time.perf_counter() + """Control loop: (1) Performing actions, when available""" + if self.actions_available(): + _performed_action = self.control_loop_action(verbose) + + """Control loop: (2) Streaming observations to the remote policy server""" + if self._ready_to_send_observation(): + _captured_observation = self.control_loop_observation(task, verbose) + + self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}") + # Dynamically adjust sleep time to maintain the desired control frequency + time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start))) + + return _captured_observation, _performed_action + + +@draccus.wrap() +def async_client(cfg: RobotClientConfig): + logging.info(pformat(asdict(cfg))) + + if cfg.robot.type not in SUPPORTED_ROBOTS: + raise ValueError(f"Robot {cfg.robot.type} not yet supported!") + + client = RobotClient(cfg) + + if client.start(): + client.logger.info("Starting action receiver thread...") + + # Create and start action receiver thread + action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True) + + # Start action receiver thread + action_receiver_thread.start() + + try: + # The main thread runs the control loop + client.control_loop(task=cfg.task) + + finally: + client.stop() + action_receiver_thread.join() + if cfg.debug_visualize_queue_size: + visualize_action_queue_size(client.action_queue_size) + client.logger.info("Client stopped") + + +if __name__ == "__main__": + async_client() # run the client diff --git a/src/lerobot/transport/async_inference.proto b/src/lerobot/transport/async_inference.proto new file mode 100644 index 000000000..434f3142b --- /dev/null +++ b/src/lerobot/transport/async_inference.proto @@ -0,0 +1,59 @@ +// fmt: off +// flake8: noqa +// !/usr/bin/env python + +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +package async_inference; + +// AsyncInference: from Robot perspective +// Robot send observations to & executes action received from a remote Policy server +service AsyncInference { + // Robot -> Policy to share observations with a remote inference server + // Policy -> Robot to share actions predicted for given observations + rpc SendObservations(stream Observation) returns (Empty); + rpc GetActions(Empty) returns (Actions); + rpc SendPolicyInstructions(PolicySetup) returns (Empty); + rpc Ready(Empty) returns (Empty); + rpc Stop(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Observation { + // sent by Robot, to remote Policy + TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size + bytes data = 2; +} + +message Actions { + // sent by remote Policy, to Robot + bytes data = 1; +} + +message PolicySetup { + // sent by Robot to remote server, to init Policy + bytes data = 1; +} + +message Empty {} diff --git a/src/lerobot/transport/async_inference_pb2.py b/src/lerobot/transport/async_inference_pb2.py new file mode 100644 index 000000000..59c8eb488 --- /dev/null +++ b/src/lerobot/transport/async_inference_pb2.py @@ -0,0 +1,45 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: async_inference.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'async_inference.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=190 + _globals['_TRANSFERSTATE']._serialized_end=286 + _globals['_OBSERVATION']._serialized_start=42 + _globals['_OBSERVATION']._serialized_end=125 + _globals['_ACTIONS']._serialized_start=127 + _globals['_ACTIONS']._serialized_end=150 + _globals['_POLICYSETUP']._serialized_start=152 + _globals['_POLICYSETUP']._serialized_end=179 + _globals['_EMPTY']._serialized_start=181 + _globals['_EMPTY']._serialized_end=188 + _globals['_ASYNCINFERENCE']._serialized_start=289 + _globals['_ASYNCINFERENCE']._serialized_end=638 +# @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/async_inference_pb2_grpc.py b/src/lerobot/transport/async_inference_pb2_grpc.py new file mode 100644 index 000000000..3042db0db --- /dev/null +++ b/src/lerobot/transport/async_inference_pb2_grpc.py @@ -0,0 +1,277 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from lerobot.transport import async_inference_pb2 as async__inference__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in async_inference_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class AsyncInferenceStub: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendObservations = channel.stream_unary( + '/async_inference.AsyncInference/SendObservations', + request_serializer=async__inference__pb2.Observation.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.GetActions = channel.unary_unary( + '/async_inference.AsyncInference/GetActions', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Actions.FromString, + _registered_method=True) + self.SendPolicyInstructions = channel.unary_unary( + '/async_inference.AsyncInference/SendPolicyInstructions', + request_serializer=async__inference__pb2.PolicySetup.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/async_inference.AsyncInference/Ready', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + self.Stop = channel.unary_unary( + '/async_inference.AsyncInference/Stop', + request_serializer=async__inference__pb2.Empty.SerializeToString, + response_deserializer=async__inference__pb2.Empty.FromString, + _registered_method=True) + + +class AsyncInferenceServicer: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def SendObservations(self, request_iterator, context): + """Robot -> Policy to share observations with a remote inference server + Policy -> Robot to share actions predicted for given observations + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetActions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendPolicyInstructions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Stop(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsyncInferenceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendObservations': grpc.stream_unary_rpc_method_handler( + servicer.SendObservations, + request_deserializer=async__inference__pb2.Observation.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'GetActions': grpc.unary_unary_rpc_method_handler( + servicer.GetActions, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Actions.SerializeToString, + ), + 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( + servicer.SendPolicyInstructions, + request_deserializer=async__inference__pb2.PolicySetup.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + 'Stop': grpc.unary_unary_rpc_method_handler( + servicer.Stop, + request_deserializer=async__inference__pb2.Empty.FromString, + response_serializer=async__inference__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'async_inference.AsyncInference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AsyncInference: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + @staticmethod + def SendObservations(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/async_inference.AsyncInference/SendObservations', + async__inference__pb2.Observation.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetActions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/GetActions', + async__inference__pb2.Empty.SerializeToString, + async__inference__pb2.Actions.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendPolicyInstructions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/SendPolicyInstructions', + async__inference__pb2.PolicySetup.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/Ready', + async__inference__pb2.Empty.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Stop(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/async_inference.AsyncInference/Stop', + async__inference__pb2.Empty.SerializeToString, + async__inference__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py new file mode 100644 index 000000000..d7b68e66b --- /dev/null +++ b/tests/async_inference/test_e2e.py @@ -0,0 +1,177 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end test of the asynchronous inference stack (client ↔ server). + +This test spins up a lightweight gRPC `PolicyServer` instance with a stubbed +policy network and launches a `RobotClient` that uses a `MockRobot`. The goal +is to exercise the full communication loop: + +1. Client sends policy specification → Server +2. Client streams observations → Server +3. Server streams action chunks → Client +4. Client executes received actions + +The test succeeds if at least one action is executed and the server records at +least one predicted timestep - demonstrating that the gRPC round-trip works +end-to-end using real (but lightweight) protocol messages. +""" + +from __future__ import annotations + +import threading +from concurrent import futures + +import pytest +import torch + +# Skip entire module if grpc is not available +pytest.importorskip("grpc") + +# ----------------------------------------------------------------------------- +# End-to-end test +# ----------------------------------------------------------------------------- + + +def test_async_inference_e2e(monkeypatch): + """Tests the full asynchronous inference pipeline.""" + # Import grpc-dependent modules inside the test function + import grpc + + from lerobot.robots.utils import make_robot_from_config + from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig + from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features + from lerobot.scripts.server.policy_server import PolicyServer + from lerobot.scripts.server.robot_client import RobotClient + from lerobot.transport import ( + async_inference_pb2, # type: ignore + async_inference_pb2_grpc, # type: ignore + ) + from tests.mocks.mock_robot import MockRobotConfig + + # Create a stub policy similar to test_policy_server.py + class MockPolicy: + """A minimal mock for an actual policy, returning zeros.""" + + class _Config: + robot_type = "dummy_robot" + + @property + def image_features(self): + """Empty image features since this test doesn't use images.""" + return {} + + def __init__(self): + self.config = self._Config() + + def to(self, *args, **kwargs): + return self + + def model(self, batch): + # Return a chunk of 20 dummy actions. + batch_size = len(batch["robot_type"]) + return torch.zeros(batch_size, 20, 6) + + # ------------------------------------------------------------------ + # 1. Create PolicyServer instance with mock policy + # ------------------------------------------------------------------ + policy_server_config = PolicyServerConfig(host="localhost", port=9999) + policy_server = PolicyServer(policy_server_config) + # Replace the real policy with our fast, deterministic stub. + policy_server.policy = MockPolicy() + policy_server.actions_per_chunk = 20 + policy_server.device = "cpu" + + # Set up robot config and features + robot_config = MockRobotConfig() + mock_robot = make_robot_from_config(robot_config) + + lerobot_features = map_robot_keys_to_lerobot_features(mock_robot) + policy_server.lerobot_features = lerobot_features + + # Force server to produce deterministic action chunks in test mode + policy_server.policy_type = "act" + + def _fake_get_action_chunk(_self, _obs, _type="test"): + action_dim = 6 + batch_size = 1 + actions_per_chunk = policy_server.actions_per_chunk + + return torch.zeros(batch_size, actions_per_chunk, action_dim) + + monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True) + + # Bypass potentially heavy model loading inside SendPolicyInstructions + def _fake_send_policy_instructions(self, request, context): # noqa: N802 + return async_inference_pb2.Empty() + + monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True) + + # Build gRPC server running a PolicyServer + server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server")) + async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + + # Use the host/port specified in the fixture's config + server_address = f"{policy_server.config.host}:{policy_server.config.port}" + server.add_insecure_port(server_address) + server.start() + + # ------------------------------------------------------------------ + # 2. Create a RobotClient around the MockRobot + # ------------------------------------------------------------------ + client_config = RobotClientConfig( + server_address=server_address, + robot=robot_config, + chunk_size_threshold=0.0, + policy_type="test", + pretrained_name_or_path="test", + actions_per_chunk=20, + verify_robot_cameras=False, + ) + + client = RobotClient(client_config) + assert client.start(), "Client failed initial handshake with the server" + + # Track action chunks received without modifying RobotClient + action_chunks_received = {"count": 0} + original_aggregate = client._aggregate_action_queues + + def counting_aggregate(*args, **kwargs): + action_chunks_received["count"] += 1 + return original_aggregate(*args, **kwargs) + + monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate) + + # Start client threads + action_thread = threading.Thread(target=client.receive_actions, daemon=True) + control_thread = threading.Thread(target=client.control_loop, args=({"task": ""}), daemon=True) + action_thread.start() + control_thread.start() + + # ------------------------------------------------------------------ + # 3. System exchanges a few messages + # ------------------------------------------------------------------ + # Wait for 5 seconds + server.wait_for_termination(timeout=5) + + assert action_chunks_received["count"] > 0, "Client did not receive any action chunks" + assert len(policy_server._predicted_timesteps) > 0, "Server did not record any predicted timesteps" + + # ------------------------------------------------------------------ + # 4. Stop the system + # ------------------------------------------------------------------ + client.stop() + action_thread.join() + control_thread.join() + policy_server.stop() + server.stop(grace=None) diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py new file mode 100644 index 000000000..e0b797371 --- /dev/null +++ b/tests/async_inference/test_helpers.py @@ -0,0 +1,459 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import pickle +import time + +import numpy as np +import torch + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.scripts.server.helpers import ( + FPSTracker, + TimedAction, + TimedObservation, + observations_similar, + prepare_image, + prepare_raw_observation, + raw_observation_to_observation, + resize_robot_observation_image, +) + +# --------------------------------------------------------------------- +# FPSTracker +# --------------------------------------------------------------------- + + +def test_fps_tracker_first_observation(): + """First observation should initialize timestamp and return 0 FPS.""" + tracker = FPSTracker(target_fps=30.0) + timestamp = 1000.0 + + metrics = tracker.calculate_fps_metrics(timestamp) + + assert tracker.first_timestamp == timestamp + assert tracker.total_obs_count == 1 + assert metrics["avg_fps"] == 0.0 + assert metrics["target_fps"] == 30.0 + + +def test_fps_tracker_single_interval(): + """Two observations 1 second apart should give 1 FPS.""" + tracker = FPSTracker(target_fps=30.0) + + # First observation at t=0 + metrics1 = tracker.calculate_fps_metrics(0.0) + assert metrics1["avg_fps"] == 0.0 + + # Second observation at t=1 (1 second later) + metrics2 = tracker.calculate_fps_metrics(1.0) + expected_fps = 1.0 # (2-1) observations / 1.0 seconds = 1 FPS + assert math.isclose(metrics2["avg_fps"], expected_fps, rel_tol=1e-6) + + +def test_fps_tracker_multiple_intervals(): + """Multiple observations should calculate correct average FPS.""" + tracker = FPSTracker(target_fps=30.0) + + # Simulate 5 observations over 2 seconds (should be 2 FPS average) + timestamps = [0.0, 0.5, 1.0, 1.5, 2.0] + + for i, ts in enumerate(timestamps): + metrics = tracker.calculate_fps_metrics(ts) + + if i == 0: + assert metrics["avg_fps"] == 0.0 + elif i == len(timestamps) - 1: + # After 5 observations over 2 seconds: (5-1)/2 = 2 FPS + expected_fps = 2.0 + assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6) + + +def test_fps_tracker_irregular_intervals(): + """FPS calculation should work with irregular time intervals.""" + tracker = FPSTracker(target_fps=30.0) + + # Irregular timestamps: 0, 0.1, 0.5, 2.0, 3.0 seconds + timestamps = [0.0, 0.1, 0.5, 2.0, 3.0] + + for ts in timestamps: + metrics = tracker.calculate_fps_metrics(ts) + + # 5 observations over 3 seconds: (5-1)/3 = 1.333... FPS + expected_fps = 4.0 / 3.0 + assert math.isclose(metrics["avg_fps"], expected_fps, rel_tol=1e-6) + + +# --------------------------------------------------------------------- +# TimedData helpers +# --------------------------------------------------------------------- + + +def test_timed_action_getters(): + """TimedAction stores & returns timestamp, action tensor and timestep.""" + ts = time.time() + action = torch.arange(10) + ta = TimedAction(timestamp=ts, action=action, timestep=0) + + assert math.isclose(ta.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) + torch.testing.assert_close(ta.get_action(), action) + assert ta.get_timestep() == 0 + + +def test_timed_observation_getters(): + """TimedObservation stores & returns timestamp, dict and timestep.""" + ts = time.time() + obs_dict = {"observation.state": torch.ones(6)} + to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) + + assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) + assert to.get_observation() is obs_dict + assert to.get_timestep() == 0 + + +def test_timed_data_deserialization_data_getters(): + """TimedAction / TimedObservation survive a round-trip through ``pickle``. + + The async-inference stack uses ``pickle.dumps`` to move these objects across + the gRPC boundary (see RobotClient.send_observation and PolicyServer.StreamActions). + This test ensures that the payload keeps its content intact after + the (de)serialization round-trip. + """ + ts = time.time() + + # ------------------------------------------------------------------ + # TimedAction + # ------------------------------------------------------------------ + original_action = torch.randn(6) + ta_in = TimedAction(timestamp=ts, action=original_action, timestep=13) + + # Serialize → bytes → deserialize + ta_bytes = pickle.dumps(ta_in) # nosec + ta_out: TimedAction = pickle.loads(ta_bytes) # nosec B301 + + # Identity & content checks + assert math.isclose(ta_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) + assert ta_out.get_timestep() == 13 + torch.testing.assert_close(ta_out.get_action(), original_action) + + # ------------------------------------------------------------------ + # TimedObservation + # ------------------------------------------------------------------ + obs_dict = {"observation.state": torch.arange(4).float()} + to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) + + to_bytes = pickle.dumps(to_in) # nosec + to_out: TimedObservation = pickle.loads(to_bytes) # nosec B301 + + assert math.isclose(to_out.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) + assert to_out.get_timestep() == 7 + assert to_out.must_go is True + assert to_out.get_observation().keys() == obs_dict.keys() + torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) + + +# --------------------------------------------------------------------- +# observations_similar() +# --------------------------------------------------------------------- + + +def _make_obs(state: torch.Tensor) -> TimedObservation: + """Create a TimedObservation with raw robot observation format.""" + return TimedObservation( + timestamp=time.time(), + observation={ + "shoulder": state[0].item() if len(state) > 0 else 0.0, + "elbow": state[1].item() if len(state) > 1 else 0.0, + "wrist": state[2].item() if len(state) > 2 else 0.0, + "gripper": state[3].item() if len(state) > 3 else 0.0, + }, + timestep=0, + ) + + +def test_observations_similar_true(): + """Distance below atol → observations considered similar.""" + # Create mock lerobot features for the similarity check + lerobot_features = { + "observation.state": { + "dtype": "float32", + "shape": [4], + "names": ["shoulder", "elbow", "wrist", "gripper"], + } + } + + obs1 = _make_obs(torch.zeros(4)) + obs2 = _make_obs(0.5 * torch.ones(4)) + assert observations_similar(obs1, obs2, lerobot_features, atol=2.0) + + obs3 = _make_obs(2.0 * torch.ones(4)) + assert not observations_similar(obs1, obs3, lerobot_features, atol=2.0) + + +# --------------------------------------------------------------------- +# raw_observation_to_observation and helpers +# --------------------------------------------------------------------- + + +def _create_mock_robot_observation(): + """Create a mock robot observation with motor positions and camera images.""" + return { + "shoulder": 1.0, + "elbow": 2.0, + "wrist": 3.0, + "gripper": 0.5, + "laptop": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8), + "phone": np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8), + } + + +def _create_mock_lerobot_features(): + """Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" + return { + "observation.state": { + "dtype": "float32", + "shape": [4], + "names": ["shoulder", "elbow", "wrist", "gripper"], + }, + "observation.images.laptop": { + "dtype": "image", + "shape": [480, 640, 3], + "names": ["height", "width", "channels"], + }, + "observation.images.phone": { + "dtype": "image", + "shape": [480, 640, 3], + "names": ["height", "width", "channels"], + }, + } + + +def _create_mock_policy_image_features(): + """Create mock policy image features with different resolutions.""" + return { + "observation.images.laptop": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), # Policy expects smaller resolution + ), + "observation.images.phone": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 160, 160), # Different resolution for second camera + ), + } + + +def test_prepare_image(): + """Test image preprocessing: int8 → float32, normalization to [0,1].""" + # Create mock int8 image data + image_int8 = torch.randint(0, 256, size=(3, 224, 224), dtype=torch.uint8) + + processed = prepare_image(image_int8) + + # Check dtype conversion + assert processed.dtype == torch.float32 + + # Check normalization range + assert processed.min() >= 0.0 + assert processed.max() <= 1.0 + + # Check that values are scaled correctly (255 → 1.0, 0 → 0.0) + if image_int8.max() == 255: + assert torch.isclose(processed.max(), torch.tensor(1.0), atol=1e-6) + if image_int8.min() == 0: + assert torch.isclose(processed.min(), torch.tensor(0.0), atol=1e-6) + + # Check memory contiguity + assert processed.is_contiguous() + + +def test_resize_robot_observation_image(): + """Test image resizing from robot resolution to policy resolution.""" + # Create mock image: (H=480, W=640, C=3) + original_image = torch.randint(0, 256, size=(480, 640, 3), dtype=torch.uint8) + target_shape = (3, 224, 224) # (C, H, W) + + resized = resize_robot_observation_image(original_image, target_shape) + + # Check output shape matches target + assert resized.shape == target_shape + + # Check that original image had different dimensions + assert original_image.shape != resized.shape + + # Check that resizing preserves value range + assert resized.min() >= 0 + assert resized.max() <= 255 + + +def test_prepare_raw_observation(): + """Test the preparation of raw robot observation to lerobot format.""" + robot_obs = _create_mock_robot_observation() + lerobot_features = _create_mock_lerobot_features() + policy_image_features = _create_mock_policy_image_features() + + prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) + + # Check that state is properly extracted and batched + assert "observation.state" in prepared + state = prepared["observation.state"] + assert isinstance(state, torch.Tensor) + assert state.shape == (1, 4) # Batched state + + # Check that images are processed and resized + assert "observation.images.laptop" in prepared + assert "observation.images.phone" in prepared + + laptop_img = prepared["observation.images.laptop"] + phone_img = prepared["observation.images.phone"] + + # Check image shapes match policy requirements + assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape + assert phone_img.shape == policy_image_features["observation.images.phone"].shape + + # Check that images are tensors + assert isinstance(laptop_img, torch.Tensor) + assert isinstance(phone_img, torch.Tensor) + + +def test_raw_observation_to_observation_basic(): + """Test the main raw_observation_to_observation function.""" + robot_obs = _create_mock_robot_observation() + lerobot_features = _create_mock_lerobot_features() + policy_image_features = _create_mock_policy_image_features() + device = "cpu" + + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + + # Check that all expected keys are present + assert "observation.state" in observation + assert "observation.images.laptop" in observation + assert "observation.images.phone" in observation + + # Check state processing + state = observation["observation.state"] + assert isinstance(state, torch.Tensor) + assert state.device.type == device + assert state.shape == (1, 4) # Batched + + # Check image processing + laptop_img = observation["observation.images.laptop"] + phone_img = observation["observation.images.phone"] + + # Images should have batch dimension: (B, C, H, W) + assert laptop_img.shape == (1, 3, 224, 224) + assert phone_img.shape == (1, 3, 160, 160) + + # Check device placement + assert laptop_img.device.type == device + assert phone_img.device.type == device + + # Check image dtype and range (should be float32 in [0, 1]) + assert laptop_img.dtype == torch.float32 + assert phone_img.dtype == torch.float32 + assert laptop_img.min() >= 0.0 and laptop_img.max() <= 1.0 + assert phone_img.min() >= 0.0 and phone_img.max() <= 1.0 + + +def test_raw_observation_to_observation_with_non_tensor_data(): + """Test that non-tensor data (like task strings) is preserved.""" + robot_obs = _create_mock_robot_observation() + robot_obs["task"] = "pick up the red cube" # Add string instruction + + lerobot_features = _create_mock_lerobot_features() + policy_image_features = _create_mock_policy_image_features() + device = "cpu" + + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + + # Check that task string is preserved + assert "task" in observation + assert observation["task"] == "pick up the red cube" + assert isinstance(observation["task"], str) + + +@torch.no_grad() +def test_raw_observation_to_observation_device_handling(): + """Test that tensors are properly moved to the specified device.""" + device = "mps" if torch.backends.mps.is_available() else "cpu" + + robot_obs = _create_mock_robot_observation() + lerobot_features = _create_mock_lerobot_features() + policy_image_features = _create_mock_policy_image_features() + + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + + # Check that all tensors are on the correct device + for key, value in observation.items(): + if isinstance(value, torch.Tensor): + assert value.device.type == device, f"Tensor {key} not on {device}" + + +def test_raw_observation_to_observation_deterministic(): + """Test that the function produces consistent results for the same input.""" + robot_obs = _create_mock_robot_observation() + lerobot_features = _create_mock_lerobot_features() + policy_image_features = _create_mock_policy_image_features() + device = "cpu" + + # Run twice with same input + obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + + # Results should be identical + assert set(obs1.keys()) == set(obs2.keys()) + + for key in obs1: + if isinstance(obs1[key], torch.Tensor): + torch.testing.assert_close(obs1[key], obs2[key]) + else: + assert obs1[key] == obs2[key] + + +def test_image_processing_pipeline_preserves_content(): + """Test that the image processing pipeline preserves recognizable patterns.""" + # Create an image with a specific pattern + original_img = np.zeros((100, 100, 3), dtype=np.uint8) + original_img[25:75, 25:75, :] = 255 # White square in center + + robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} + lerobot_features = { + "observation.state": { + "dtype": "float32", + "shape": [4], + "names": ["shoulder", "elbow", "wrist", "gripper"], + }, + "observation.images.laptop": { + "dtype": "image", + "shape": [100, 100, 3], + "names": ["height", "width", "channels"], + }, + } + policy_image_features = { + "observation.images.laptop": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 50, 50), # Downsamples from 100x100 + ) + } + + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") + + processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim + + # Check that the center region has higher values than corners + # Due to bilinear interpolation, exact values will change but pattern should remain + center_val = processed_img[:, 25, 25].mean() # Center of 50x50 image + corner_val = processed_img[:, 5, 5].mean() # Corner + + assert center_val > corner_val, "Image processing should preserve recognizable patterns" diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py new file mode 100644 index 000000000..5c795e7ec --- /dev/null +++ b/tests/async_inference/test_policy_server.py @@ -0,0 +1,215 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit-tests for the `PolicyServer` core logic. +Monkey-patch the `policy` attribute with a stub so that no real model inference is performed. +""" + +from __future__ import annotations + +import time + +import pytest +import torch + +from lerobot.configs.types import PolicyFeature +from tests.utils import require_package + +# ----------------------------------------------------------------------------- +# Test fixtures +# ----------------------------------------------------------------------------- + + +class MockPolicy: + """A minimal mock for an actual policy, returning zeros. + Refer to tests/policies for tests of the individual policies supported.""" + + class _Config: + robot_type = "dummy_robot" + + @property + def image_features(self) -> dict[str, PolicyFeature]: + """Empty image features since this test doesn't use images.""" + return {} + + def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: + """Return a chunk of 20 dummy actions.""" + batch_size = len(observation["observation.state"]) + return torch.zeros(batch_size, 20, 6) + + def __init__(self): + self.config = self._Config() + + def to(self, *args, **kwargs): + # The server calls `policy.to(device)`. This stub ignores it. + return self + + def model(self, batch: dict) -> torch.Tensor: + # Return a chunk of 20 dummy actions. + batch_size = len(batch["robot_type"]) + return torch.zeros(batch_size, 20, 6) + + +@pytest.fixture +@require_package("grpc") +def policy_server(): + """Fresh `PolicyServer` instance with a stubbed-out policy model.""" + # Import only when the test actually runs (after decorator check) + from lerobot.scripts.server.configs import PolicyServerConfig + from lerobot.scripts.server.policy_server import PolicyServer + + test_config = PolicyServerConfig(host="localhost", port=9999) + server = PolicyServer(test_config) + # Replace the real policy with our fast, deterministic stub. + server.policy = MockPolicy() + server.actions_per_chunk = 20 + server.device = "cpu" + + # Add mock lerobot_features that the observation similarity functions need + server.lerobot_features = { + "observation.state": { + "dtype": "float32", + "shape": [6], + "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + } + } + + return server + + +# ----------------------------------------------------------------------------- +# Helper utilities for tests +# ----------------------------------------------------------------------------- + + +def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False): + """Create a TimedObservation with a given state vector.""" + # Import only when needed + from lerobot.scripts.server.helpers import TimedObservation + + return TimedObservation( + observation={ + "joint1": state[0].item() if len(state) > 0 else 0.0, + "joint2": state[1].item() if len(state) > 1 else 0.0, + "joint3": state[2].item() if len(state) > 2 else 0.0, + "joint4": state[3].item() if len(state) > 3 else 0.0, + "joint5": state[4].item() if len(state) > 4 else 0.0, + "joint6": state[5].item() if len(state) > 5 else 0.0, + }, + timestamp=time.time(), + timestep=timestep, + must_go=must_go, + ) + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +def test_time_action_chunk(policy_server): + """Verify that `_time_action_chunk` assigns correct timestamps and timesteps.""" + start_ts = time.time() + start_t = 10 + # A chunk of 3 action tensors. + action_tensors = [torch.randn(6) for _ in range(3)] + + timed_actions = policy_server._time_action_chunk(start_ts, action_tensors, start_t) + + assert len(timed_actions) == 3 + # Check timesteps + assert [ta.get_timestep() for ta in timed_actions] == [10, 11, 12] + # Check timestamps + expected_timestamps = [ + start_ts, + start_ts + policy_server.config.environment_dt, + start_ts + 2 * policy_server.config.environment_dt, + ] + for ta, expected_ts in zip(timed_actions, expected_timestamps, strict=True): + assert abs(ta.get_timestamp() - expected_ts) < 1e-6 + + +def test_maybe_enqueue_observation_must_go(policy_server): + """An observation with `must_go=True` is always enqueued.""" + obs = _make_obs(torch.zeros(6), must_go=True) + assert policy_server._enqueue_observation(obs) is True + assert policy_server.observation_queue.qsize() == 1 + assert policy_server.observation_queue.get_nowait() is obs + + +def test_maybe_enqueue_observation_dissimilar(policy_server): + """A dissimilar observation (not `must_go`) is enqueued.""" + # Set a last predicted observation. + policy_server.last_processed_obs = _make_obs(torch.zeros(6)) + # Create a new, dissimilar observation. + new_obs = _make_obs(torch.ones(6) * 5) # High norm difference + + assert policy_server._enqueue_observation(new_obs) is True + assert policy_server.observation_queue.qsize() == 1 + + +def test_maybe_enqueue_observation_is_skipped(policy_server): + """A similar observation (not `must_go`) is skipped.""" + # Set a last predicted observation. + policy_server.last_processed_obs = _make_obs(torch.zeros(6)) + # Create a new, very similar observation. + new_obs = _make_obs(torch.zeros(6) + 1e-4) + + assert policy_server._enqueue_observation(new_obs) is False + assert policy_server.observation_queue.empty() is True + + +def test_obs_sanity_checks(policy_server): + """Unit-test the private `_obs_sanity_checks` helper.""" + prev = _make_obs(torch.zeros(6), timestep=0) + + # Case 1 – timestep already predicted + policy_server._predicted_timesteps.add(1) + obs_same_ts = _make_obs(torch.ones(6), timestep=1) + assert policy_server._obs_sanity_checks(obs_same_ts, prev) is False + + # Case 2 – observation too similar + policy_server._predicted_timesteps.clear() + obs_similar = _make_obs(torch.zeros(6) + 1e-4, timestep=2) + assert policy_server._obs_sanity_checks(obs_similar, prev) is False + + # Case 3 – genuinely new & dissimilar observation passes + obs_ok = _make_obs(torch.ones(6) * 5, timestep=3) + assert policy_server._obs_sanity_checks(obs_ok, prev) is True + + +def test_predict_action_chunk(monkeypatch, policy_server): + """End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk.""" + # Import only when needed + from lerobot.scripts.server.policy_server import PolicyServer + + # Force server to act-style policy; patch method to return deterministic tensor + policy_server.policy_type = "act" + action_dim = 6 + batch_size = 1 + actions_per_chunk = policy_server.actions_per_chunk + + def _fake_get_action_chunk(_self, _obs, _type="act"): + return torch.zeros(batch_size, actions_per_chunk, action_dim) + + monkeypatch.setattr(PolicyServer, "_get_action_chunk", _fake_get_action_chunk, raising=True) + + obs = _make_obs(torch.zeros(6), timestep=5) + timed_actions = policy_server._predict_action_chunk(obs) + + assert len(timed_actions) == actions_per_chunk + assert [ta.get_timestep() for ta in timed_actions] == list(range(5, 5 + actions_per_chunk)) + + for i, ta in enumerate(timed_actions): + expected_ts = obs.get_timestamp() + i * policy_server.config.environment_dt + assert abs(ta.get_timestamp() - expected_ts) < 1e-6 diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py new file mode 100644 index 000000000..d1273ae63 --- /dev/null +++ b/tests/async_inference/test_robot_client.py @@ -0,0 +1,234 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC). + +We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that +no real hardware is accessed. Only the queue-update mechanism is verified. +""" + +from __future__ import annotations + +import time +from queue import Queue + +import pytest +import torch + +# Skip entire module if grpc is not available +pytest.importorskip("grpc") + +# ----------------------------------------------------------------------------- +# Test fixtures +# ----------------------------------------------------------------------------- + + +@pytest.fixture() +def robot_client(): + """Fresh `RobotClient` instance for each test case (no threads started). + Uses DummyRobot.""" + # Import only when the test actually runs (after decorator check) + from lerobot.scripts.server.configs import RobotClientConfig + from lerobot.scripts.server.robot_client import RobotClient + from tests.mocks.mock_robot import MockRobotConfig + + test_config = MockRobotConfig() + + # gRPC channel is not actually used in tests, so using a dummy address + test_config = RobotClientConfig( + robot=test_config, + server_address="localhost:9999", + policy_type="test", + pretrained_name_or_path="test", + actions_per_chunk=20, + verify_robot_cameras=False, + ) + + client = RobotClient(test_config) + + # Initialize attributes that are normally set in start() method + client.chunks_received = 0 + client.available_actions_size = [] + + yield client + + if client.robot.is_connected: + client.stop() + + +# ----------------------------------------------------------------------------- +# Helper utilities for tests +# ----------------------------------------------------------------------------- + + +def _make_actions(start_ts: float, start_t: int, count: int): + """Generate `count` consecutive TimedAction objects starting at timestep `start_t`.""" + from lerobot.scripts.server.helpers import TimedAction + + fps = 30 # emulates most common frame-rate + actions = [] + for i in range(count): + timestep = start_t + i + timestamp = start_ts + i * (1 / fps) + action_tensor = torch.full((6,), timestep, dtype=torch.float32) + actions.append(TimedAction(action=action_tensor, timestep=timestep, timestamp=timestamp)) + return actions + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +def test_update_action_queue_discards_stale(robot_client): + """`_update_action_queue` must drop actions with `timestep` <= `latest_action`.""" + + # Pretend we already executed up to action #4 + robot_client.latest_action = 4 + + # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept. + incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7 + + robot_client._aggregate_action_queues(incoming) + + # Extract timesteps from queue + resulting_timesteps = [a.get_timestep() for a in robot_client.action_queue.queue] + + assert resulting_timesteps == [5, 6, 7] + + +@pytest.mark.parametrize( + "weight_old, weight_new", + [ + (1.0, 0.0), + (0.0, 1.0), + (0.5, 0.5), + (0.2, 0.8), + (0.8, 0.2), + (0.1, 0.9), + (0.9, 0.1), + ], +) +def test_aggregate_action_queues_combines_actions_in_overlap( + robot_client, weight_old: float, weight_new: float +): + """`_aggregate_action_queues` must combine actions on overlapping timesteps according + to the provided aggregate_fn, here tested with multiple coefficients.""" + from lerobot.scripts.server.helpers import TimedAction + + robot_client.chunks_received = 0 + + # Pretend we already executed up to action #4, and queue contains actions for timesteps 5..6 + robot_client.latest_action = 4 + current_actions = _make_actions( + start_ts=time.time(), start_t=5, count=2 + ) # actions are [torch.ones(6), torch.ones(6), ...] + current_actions = [ + TimedAction(action=10 * a.get_action(), timestep=a.get_timestep(), timestamp=a.get_timestamp()) + for a in current_actions + ] + + for a in current_actions: + robot_client.action_queue.put(a) + + # Incoming chunk contains timesteps 3..7 -> expect 5,6,7 kept. + incoming = _make_actions(start_ts=time.time(), start_t=3, count=5) # 3,4,5,6,7 + + overlap_timesteps = [5, 6] # properly tested in test_aggregate_action_queues_discards_stale + nonoverlap_timesteps = [7] + + robot_client._aggregate_action_queues( + incoming, aggregate_fn=lambda x1, x2: weight_old * x1 + weight_new * x2 + ) + + queue_overlap_actions = [] + queue_non_overlap_actions = [] + for a in robot_client.action_queue.queue: + if a.get_timestep() in overlap_timesteps: + queue_overlap_actions.append(a) + elif a.get_timestep() in nonoverlap_timesteps: + queue_non_overlap_actions.append(a) + + queue_overlap_actions = sorted(queue_overlap_actions, key=lambda x: x.get_timestep()) + queue_non_overlap_actions = sorted(queue_non_overlap_actions, key=lambda x: x.get_timestep()) + + assert torch.allclose( + queue_overlap_actions[0].get_action(), + weight_old * current_actions[0].get_action() + weight_new * incoming[-3].get_action(), + ) + assert torch.allclose( + queue_overlap_actions[1].get_action(), + weight_old * current_actions[1].get_action() + weight_new * incoming[-2].get_action(), + ) + assert torch.allclose(queue_non_overlap_actions[0].get_action(), incoming[-1].get_action()) + + +@pytest.mark.parametrize( + "chunk_size, queue_len, expected", + [ + (20, 12, False), # 12 / 20 = 0.6 > g=0.5 threshold, not ready to send + (20, 8, True), # 8 / 20 = 0.4 <= g=0.5, ready to send + (10, 5, True), + (10, 6, False), + ], +) +def test_ready_to_send_observation(robot_client, chunk_size: int, queue_len: int, expected: bool): + """Validate `_ready_to_send_observation` ratio logic for various sizes.""" + + robot_client.action_chunk_size = chunk_size + + # Clear any existing actions then fill with `queue_len` dummy entries ---- + robot_client.action_queue = Queue() + + dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len) + for act in dummy_actions: + robot_client.action_queue.put(act) + + assert robot_client._ready_to_send_observation() is expected + + +@pytest.mark.parametrize( + "g_threshold, expected", + [ + # The condition is `queue_size / chunk_size <= g`. + # Here, ratio = 6 / 10 = 0.6. + (0.0, False), # 0.6 <= 0.0 is False + (0.1, False), + (0.2, False), + (0.3, False), + (0.4, False), + (0.5, False), + (0.6, True), # 0.6 <= 0.6 is True + (0.7, True), + (0.8, True), + (0.9, True), + (1.0, True), + ], +) +def test_ready_to_send_observation_with_varying_threshold(robot_client, g_threshold: float, expected: bool): + """Validate `_ready_to_send_observation` with fixed sizes and varying `g`.""" + # Fixed sizes for this test: ratio = 6 / 10 = 0.6 + chunk_size = 10 + queue_len = 6 + + robot_client.action_chunk_size = chunk_size + # This is the parameter we are testing + robot_client._chunk_size_threshold = g_threshold + + # Fill queue with dummy actions + robot_client.action_queue = Queue() + dummy_actions = _make_actions(start_ts=time.time(), start_t=0, count=queue_len) + for act in dummy_actions: + robot_client.action_queue.put(act) + + assert robot_client._ready_to_send_observation() is expected From abe51eeba3ad7fde864147c187fcf71ad7ae469c Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:24:40 +0200 Subject: [PATCH 006/158] Update async docs with blogpost (#1479) Co-authored-by: Michel Aractingi --- docs/source/async.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 0a0823cf0..6ff05a88a 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -18,7 +18,7 @@ This is fundamentally different from synchronous inference (sync), where the rob --- ## Getting started with async inference -You can read more information on asynchronous inference in our [blogpost](NOTE:blogpost). Here, we report a getting started guide meant to help you setup and run asynchronous inference in your setup. +You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment. First, install `lerobot` with the `async` tag, to install the extra dependencies required to run async inference. From d2645cb19fc521e5b117fe03d90a84f698d3d3f6 Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Thu, 10 Jul 2025 20:13:56 +0200 Subject: [PATCH 007/158] fix(docs): Record-Upload failed? Don't panic! (#1478) * fix: add instruction to manually upload dataset Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * fix: repo type is explicited --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: Michel Aractingi --- docs/source/il_robots.mdx | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index cfa0a2809..2e8ac3619 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -282,6 +282,12 @@ Your dataset will be automatically tagged with `LeRobot` for the community to fi You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). +You can also push your local dataset to the Hub manually, running: +```bash +huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset +``` + + #### Record function The `record` function provides a suite of tools for capturing and managing data during robot operation: From 519b76110efeea55a4f919895d0029dc0df41e8b Mon Sep 17 00:00:00 2001 From: Ben Zhang <5977478+ben-z@users.noreply.github.com> Date: Sun, 13 Jul 2025 12:58:05 -0700 Subject: [PATCH 008/158] Remove random noise injected by policy server (#1496) --- src/lerobot/scripts/server/policy_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/scripts/server/policy_server.py b/src/lerobot/scripts/server/policy_server.py index 669ccc58e..13ba976e2 100644 --- a/src/lerobot/scripts/server/policy_server.py +++ b/src/lerobot/scripts/server/policy_server.py @@ -323,7 +323,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): if chunk.ndim != 3: chunk = chunk.unsqueeze(0) # adding batch dimension, now shape is (B, chunk_size, action_dim) - return chunk[:, : self.actions_per_chunk, :] + torch.randn_like(chunk[:, : self.actions_per_chunk, :]) + return chunk[:, : self.actions_per_chunk, :] def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: """Predict an action chunk based on an observation""" From 91b110d8063afdf7f5086aad0be8eea0ac939892 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Tue, 15 Jul 2025 10:28:19 +0200 Subject: [PATCH 009/158] fix(mps): gradient exploding and nan loss issues with ACT (#1490) Co-authored-by: Michel Aractingi --- src/lerobot/policies/act/modeling_act.py | 15 ++++++--------- src/lerobot/scripts/train.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index f66c8ae82..aa81d3cd2 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -485,12 +485,10 @@ class ACT(nn.Module): self.encoder_env_state_input_proj(batch["observation.environment_state"]) ) - # Camera observation features and positional embeddings. if self.config.image_features: - all_cam_features = [] - all_cam_pos_embeds = [] - # For a list of images, the H and W may vary but H*W is constant. + # NOTE: If modifying this section, verify on MPS devices that + # gradients remain stable (no explosions or NaNs). for img in batch["observation.images"]: cam_features = self.backbone(img)["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) @@ -500,11 +498,10 @@ class ACT(nn.Module): cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") - all_cam_features.append(cam_features) - all_cam_pos_embeds.append(cam_pos_embed) - - encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0)) - encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0)) + # Extend immediately instead of accumulating and concatenating + # Convert to list to extend properly + encoder_in_tokens.extend(list(cam_features)) + encoder_in_pos_embed.extend(list(cam_pos_embed)) # Stack all tokens along the sequence dimension. encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 2f2e88de6..f09d231a8 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -180,7 +180,7 @@ def train(cfg: TrainPipelineConfig): batch_size=cfg.batch_size, shuffle=shuffle, sampler=sampler, - pin_memory=device.type != "cpu", + pin_memory=device.type == "cuda", drop_last=False, ) dl_iter = cycle(dataloader) @@ -207,7 +207,7 @@ def train(cfg: TrainPipelineConfig): for key in batch: if isinstance(batch[key], torch.Tensor): - batch[key] = batch[key].to(device, non_blocking=True) + batch[key] = batch[key].to(device, non_blocking=device.type == "cuda") train_tracker, output_dict = update_policy( train_tracker, From 724874e063ecfb892bbcbc4a5e16fde5860cb28c Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:27:01 +0200 Subject: [PATCH 010/158] Fix tests (#1510) --- tests/motors/test_feetech.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 7f9e5dd7b..c5a170dd9 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -219,7 +219,7 @@ def test__write(addr, length, id_, value, mock_motors, dummy_motors): comm, error = bus._write(addr, length, id_, value) - assert mock_motors.stubs[stub].called + assert mock_motors.stubs[stub].wait_called() assert comm == scs.COMM_SUCCESS assert error == 0 @@ -371,9 +371,9 @@ def test_reset_calibration(mock_motors, dummy_motors): bus.reset_calibration() - assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) - assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) - assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs) + assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs) + assert all(mock_motors.stubs[stub].wait_called() for stub in write_mins_stubs) + assert all(mock_motors.stubs[stub].wait_called() for stub in write_maxes_stubs) def test_set_half_turn_homings(mock_motors, dummy_motors): @@ -410,7 +410,7 @@ def test_set_half_turn_homings(mock_motors, dummy_motors): bus.reset_calibration.assert_called_once() assert mock_motors.stubs[read_pos_stub].called - assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) + assert all(mock_motors.stubs[stub].wait_called() for stub in write_homing_stubs) def test_record_ranges_of_motion(mock_motors, dummy_motors): From 1b878c9155d7ff9783929d55035c942dd8bb7933 Mon Sep 17 00:00:00 2001 From: aka <47563398+todateman@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:33:02 +0900 Subject: [PATCH 011/158] fix(record): Improve OpenCV backend handling for Windows systems (#1495) * fix(record): Improve OpenCV backend handling for Windows systems * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolved ruff's E402 error (import statements not at the beginning of the file): - Moved all import statements to the beginning of the file - Defined _fix_opencv_backend() as a function - Adjusted the timing of the fix call - Code structure conforming to ruff * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(record): Correct OpenCV backend for Windows systems * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(opencv): Set OpenCV environment variable for Windows systems * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(opencv): Refactor MSMF hardware transform environment variable setting for Windows * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/lerobot/cameras/opencv/camera_opencv.py | 4 ++++ src/lerobot/cameras/utils.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index fd99922a4..1d7a1645d 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -18,12 +18,16 @@ Provides the OpenCVCamera class for capturing frames from cameras using OpenCV. import logging import math +import os import platform import time from pathlib import Path from threading import Event, Lock, Thread from typing import Any, Dict, List +# Fix MSMF hardware transform compatibility for Windows before importing cv2 +if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ: + os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 import numpy as np diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index f8bbd6e70..1eb69840b 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -60,6 +60,8 @@ def get_cv2_backend() -> int: import cv2 if platform.system() == "Windows": - return cv2.CAP_AVFOUNDATION - else: + return cv2.CAP_MSMF # Use MSMF for Windows instead of AVFOUNDATION + # elif platform.system() == "Darwin": # macOS + # return cv2.CAP_AVFOUNDATION + else: # Linux and others return cv2.CAP_ANY From c4c0105a474008e6bfd0bfd4e6b4c8fa94684e82 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Jul 2025 12:28:22 +0200 Subject: [PATCH 012/158] [pre-commit.ci] pre-commit autoupdate (#1327) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/adhtruong/mirrors-typos: v1.33.1 → v1.34.0](https://github.com/adhtruong/mirrors-typos/compare/v1.33.1...v1.34.0) - [github.com/astral-sh/ruff-pre-commit: v0.11.13 → v0.12.3](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.13...v0.12.3) - [github.com/woodruffw/zizmor-pre-commit: v1.9.0 → v1.11.0](https://github.com/woodruffw/zizmor-pre-commit/compare/v1.9.0...v1.11.0) - [github.com/PyCQA/bandit: 1.8.3 → 1.8.6](https://github.com/PyCQA/bandit/compare/1.8.3...1.8.6) * Ignore B615 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- .pre-commit-config.yaml | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1f971d39..e25f33ee0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,7 +37,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/adhtruong/mirrors-typos - rev: v1.33.1 + rev: v1.34.0 hooks: - id: typos args: [--force-exclude] @@ -48,7 +48,7 @@ repos: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.13 + rev: v0.12.3 hooks: - id: ruff args: [--fix] @@ -62,12 +62,12 @@ repos: - id: gitleaks - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.9.0 + rev: v1.11.0 hooks: - id: zizmor - repo: https://github.com/PyCQA/bandit - rev: 1.8.3 + rev: 1.8.6 hooks: - id: bandit args: ["-c", "pyproject.toml"] diff --git a/pyproject.toml b/pyproject.toml index 81cb22a21..878a36dbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ exclude_dirs = [ "src/lerobot/policies/pi0/conversion_scripts", "src/lerobot/scripts/push_dataset_to_hub.py", ] -skips = ["B101", "B311", "B404", "B603"] +skips = ["B101", "B311", "B404", "B603", "B615"] [tool.typos] default.extend-ignore-re = [ From 1c0ac8e3415adff7411846c73c6e9dfb94941eb1 Mon Sep 17 00:00:00 2001 From: Ben Zhang <5977478+ben-z@users.noreply.github.com> Date: Tue, 15 Jul 2025 03:29:07 -0700 Subject: [PATCH 013/158] Parse draccus subclass overrides when using `--policy.path` (#1501) * Parse draccus subclass overrides when using --policy.path * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> --- src/lerobot/configs/policies.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 36e6ea2e5..05f3296b8 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import json import logging import os +import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Type, TypeVar @@ -183,8 +185,22 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}" ) from e - # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus + # HACK: Parse the original config to get the config subclass, so that we can + # apply cli overrides. + # This is very ugly, ideally we'd like to be able to do that natively with draccus # something like --policy.path (in addition to --policy.type) - cli_overrides = policy_kwargs.pop("cli_overrides", []) with draccus.config_type("json"): - return draccus.parse(cls, config_file, args=cli_overrides) + orig_config = draccus.parse(cls, config_file, args=[]) + + with open(config_file) as f: + config = json.load(f) + + config.pop("type") + with tempfile.NamedTemporaryFile("w+") as f: + json.dump(config, f) + config_file = f.name + f.flush() + + cli_overrides = policy_kwargs.pop("cli_overrides", []) + with draccus.config_type("json"): + return draccus.parse(orig_config.__class__, config_file, args=cli_overrides) From 3efb4410f10c0b21f6e07456b35a9a3b07a1399d Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 16 Jul 2025 02:23:00 +0700 Subject: [PATCH 014/158] Fix logging for mps in auto_select_torch_device (#1513) --- src/lerobot/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 2e94a9c93..7a9717dce 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -48,7 +48,7 @@ def auto_select_torch_device() -> torch.device: logging.info("Cuda backend detected, using cuda.") return torch.device("cuda") elif torch.backends.mps.is_available(): - logging.info("Metal backend detected, using cuda.") + logging.info("Metal backend detected, using mps.") return torch.device("mps") else: logging.warning("No accelerated backend detected. Using default cpu, this will be slow.") From dfb1571bcf3a84d19ac84855942e72ac2fe04431 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 16 Jul 2025 11:31:25 +0200 Subject: [PATCH 015/158] Added missing licenses (#1517) * Added missing liscenses --- src/lerobot/motors/dynamixel/__init__.py | 16 ++++++++++++++++ src/lerobot/motors/feetech/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/hope_jr/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/koch_follower/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/lekiwi/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/so100_follower/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/so101_follower/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/stretch3/__init__.py | 16 ++++++++++++++++ src/lerobot/robots/viperx/__init__.py | 16 ++++++++++++++++ src/lerobot/teleoperators/__init__.py | 16 ++++++++++++++++ src/lerobot/teleoperators/homunculus/__init__.py | 16 ++++++++++++++++ src/lerobot/teleoperators/keyboard/__init__.py | 16 ++++++++++++++++ .../teleoperators/koch_leader/__init__.py | 16 ++++++++++++++++ .../teleoperators/so100_leader/__init__.py | 16 ++++++++++++++++ .../teleoperators/so101_leader/__init__.py | 16 ++++++++++++++++ .../teleoperators/stretch3_gamepad/__init__.py | 16 ++++++++++++++++ src/lerobot/teleoperators/widowx/__init__.py | 16 ++++++++++++++++ tests/configs/test_plugin_loading.py | 16 ++++++++++++++++ tests/mocks/mock_dynamixel.py | 16 ++++++++++++++++ tests/mocks/mock_feetech.py | 16 ++++++++++++++++ tests/mocks/mock_motors_bus.py | 14 ++++++++++++++ tests/mocks/mock_robot.py | 16 ++++++++++++++++ tests/mocks/mock_serial_patch.py | 16 ++++++++++++++++ tests/mocks/mock_teleop.py | 16 ++++++++++++++++ tests/motors/test_dynamixel.py | 16 ++++++++++++++++ tests/motors/test_feetech.py | 16 ++++++++++++++++ tests/motors/test_motors_bus.py | 16 ++++++++++++++++ tests/robots/test_so100_follower.py | 16 ++++++++++++++++ tests/test_control_robot.py | 16 ++++++++++++++++ tests/utils/test_encoding_utils.py | 16 ++++++++++++++++ 30 files changed, 478 insertions(+) diff --git a/src/lerobot/motors/dynamixel/__init__.py b/src/lerobot/motors/dynamixel/__init__.py index 3e414557e..425f8538a 100644 --- a/src/lerobot/motors/dynamixel/__init__.py +++ b/src/lerobot/motors/dynamixel/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode from .tables import * diff --git a/src/lerobot/motors/feetech/__init__.py b/src/lerobot/motors/feetech/__init__.py index 911d1d19f..75da2d221 100644 --- a/src/lerobot/motors/feetech/__init__.py +++ b/src/lerobot/motors/feetech/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode from .tables import * diff --git a/src/lerobot/robots/hope_jr/__init__.py b/src/lerobot/robots/hope_jr/__init__.py index 324e3c8e8..26603ebb0 100644 --- a/src/lerobot/robots/hope_jr/__init__.py +++ b/src/lerobot/robots/hope_jr/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig from .hope_jr_arm import HopeJrArm from .hope_jr_hand import HopeJrHand diff --git a/src/lerobot/robots/koch_follower/__init__.py b/src/lerobot/robots/koch_follower/__init__.py index ae98a2c38..6271c4e55 100644 --- a/src/lerobot/robots/koch_follower/__init__.py +++ b/src/lerobot/robots/koch_follower/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_koch_follower import KochFollowerConfig from .koch_follower import KochFollower diff --git a/src/lerobot/robots/lekiwi/__init__.py b/src/lerobot/robots/lekiwi/__init__.py index e3d10c5c1..ada2ff368 100644 --- a/src/lerobot/robots/lekiwi/__init__.py +++ b/src/lerobot/robots/lekiwi/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig from .lekiwi import LeKiwi from .lekiwi_client import LeKiwiClient diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py index 63c3e1c17..b995aab13 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/robots/so100_follower/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig from .so100_follower import SO100Follower from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/src/lerobot/robots/so101_follower/__init__.py b/src/lerobot/robots/so101_follower/__init__.py index f6615b15b..9ff2baf45 100644 --- a/src/lerobot/robots/so101_follower/__init__.py +++ b/src/lerobot/robots/so101_follower/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_so101_follower import SO101FollowerConfig from .so101_follower import SO101Follower diff --git a/src/lerobot/robots/stretch3/__init__.py b/src/lerobot/robots/stretch3/__init__.py index e2a859cde..b3070bbd6 100644 --- a/src/lerobot/robots/stretch3/__init__.py +++ b/src/lerobot/robots/stretch3/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .configuration_stretch3 import Stretch3RobotConfig from .robot_stretch3 import Stretch3Robot diff --git a/src/lerobot/robots/viperx/__init__.py b/src/lerobot/robots/viperx/__init__.py index 522d02f1c..bfba07fc7 100644 --- a/src/lerobot/robots/viperx/__init__.py +++ b/src/lerobot/robots/viperx/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_viperx import ViperXConfig from .viperx import ViperX diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py index ec93547f7..56f48af7e 100644 --- a/src/lerobot/teleoperators/__init__.py +++ b/src/lerobot/teleoperators/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config import TeleoperatorConfig from .teleoperator import Teleoperator from .utils import make_teleoperator_from_config diff --git a/src/lerobot/teleoperators/homunculus/__init__.py b/src/lerobot/teleoperators/homunculus/__init__.py index 04b5c0f2b..b3c6c0bf5 100644 --- a/src/lerobot/teleoperators/homunculus/__init__.py +++ b/src/lerobot/teleoperators/homunculus/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig from .homunculus_arm import HomunculusArm from .homunculus_glove import HomunculusGlove diff --git a/src/lerobot/teleoperators/keyboard/__init__.py b/src/lerobot/teleoperators/keyboard/__init__.py index 5761bf788..72d01003a 100644 --- a/src/lerobot/teleoperators/keyboard/__init__.py +++ b/src/lerobot/teleoperators/keyboard/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig from .teleop_keyboard import KeyboardEndEffectorTeleop, KeyboardTeleop diff --git a/src/lerobot/teleoperators/koch_leader/__init__.py b/src/lerobot/teleoperators/koch_leader/__init__.py index ad2d6a0e4..1bf9d51db 100644 --- a/src/lerobot/teleoperators/koch_leader/__init__.py +++ b/src/lerobot/teleoperators/koch_leader/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_koch_leader import KochLeaderConfig from .koch_leader import KochLeader diff --git a/src/lerobot/teleoperators/so100_leader/__init__.py b/src/lerobot/teleoperators/so100_leader/__init__.py index 63c877e60..747416be2 100644 --- a/src/lerobot/teleoperators/so100_leader/__init__.py +++ b/src/lerobot/teleoperators/so100_leader/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_so100_leader import SO100LeaderConfig from .so100_leader import SO100Leader diff --git a/src/lerobot/teleoperators/so101_leader/__init__.py b/src/lerobot/teleoperators/so101_leader/__init__.py index 1f45170e9..11e277c91 100644 --- a/src/lerobot/teleoperators/so101_leader/__init__.py +++ b/src/lerobot/teleoperators/so101_leader/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_so101_leader import SO101LeaderConfig from .so101_leader import SO101Leader diff --git a/src/lerobot/teleoperators/stretch3_gamepad/__init__.py b/src/lerobot/teleoperators/stretch3_gamepad/__init__.py index ac45b6dd4..fa5a19974 100644 --- a/src/lerobot/teleoperators/stretch3_gamepad/__init__.py +++ b/src/lerobot/teleoperators/stretch3_gamepad/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .configuration_stretch3 import Stretch3GamePadConfig from .stretch3_gamepad import Stretch3GamePad diff --git a/src/lerobot/teleoperators/widowx/__init__.py b/src/lerobot/teleoperators/widowx/__init__.py index 122ee3290..42e312f49 100644 --- a/src/lerobot/teleoperators/widowx/__init__.py +++ b/src/lerobot/teleoperators/widowx/__init__.py @@ -1,2 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config_widowx import WidowXConfig from .widowx import WidowX diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py index 957574eb4..e81057c95 100644 --- a/tests/configs/test_plugin_loading.py +++ b/tests/configs/test_plugin_loading.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import sys from dataclasses import dataclass from pathlib import Path diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 64592439f..00403d146 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import abc from typing import Callable diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index e0b677d57..041c09421 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import abc from typing import Callable diff --git a/tests/mocks/mock_motors_bus.py b/tests/mocks/mock_motors_bus.py index 91e33473d..a499dbfee 100644 --- a/tests/mocks/mock_motors_bus.py +++ b/tests/mocks/mock_motors_bus.py @@ -1,3 +1,17 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # ruff: noqa: N802 from lerobot.motors.motors_bus import ( diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index 971fc00ad..8108c7c25 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random from dataclasses import dataclass, field from functools import cached_property diff --git a/tests/mocks/mock_serial_patch.py b/tests/mocks/mock_serial_patch.py index e39923188..bde0efae2 100644 --- a/tests/mocks/mock_serial_patch.py +++ b/tests/mocks/mock_serial_patch.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import threading import time diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index c29cc9219..e37d4a2c5 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random from dataclasses import dataclass from functools import cached_property diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index a54e49056..d990b5b0f 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re import sys from typing import Generator diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index c5a170dd9..d6ea1db20 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re import sys from typing import Generator diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py index 966af3fb0..27650ef1b 100644 --- a/tests/motors/test_motors_bus.py +++ b/tests/motors/test_motors_bus.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import re from unittest.mock import patch diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py index 498eec94b..d76b9591a 100644 --- a/tests/robots/test_so100_follower.py +++ b/tests/robots/test_so100_follower.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import contextmanager from unittest.mock import MagicMock, patch diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 2ec6a2905..e45688c14 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay diff --git a/tests/utils/test_encoding_utils.py b/tests/utils/test_encoding_utils.py index 9c9796762..813942862 100644 --- a/tests/utils/test_encoding_utils.py +++ b/tests/utils/test_encoding_utils.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from lerobot.utils.encoding_utils import ( From 816034948af0fcd2cdfb2ae51ea41eb42d98101a Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 16 Jul 2025 21:13:01 +0700 Subject: [PATCH 016/158] [Async Inference] Add gRPC retry mechanism to Async client (#1485) Co-authored-by: Michel Aractingi --- src/lerobot/scripts/rl/actor.py | 31 ++-------------- src/lerobot/scripts/rl/learner.py | 5 +-- src/lerobot/scripts/rl/learner_service.py | 1 - src/lerobot/scripts/server/robot_client.py | 5 ++- src/lerobot/transport/utils.py | 41 ++++++++++++++++++++++ 5 files changed, 50 insertions(+), 33 deletions(-) diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index cd5e286c0..1c8f9286b 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -63,12 +63,12 @@ from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl import learner_service from lerobot.scripts.rl.gym_manipulator import make_robot_env from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import ( bytes_to_state_dict, + grpc_channel_options, python_object_to_bytes, receive_bytes_in_chunks, send_bytes_in_chunks, @@ -399,8 +399,6 @@ def learner_service_client( host: str = "127.0.0.1", port: int = 50051, ) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: - import json - """ Returns a client for the learner service. @@ -408,34 +406,9 @@ def learner_service_client( So we need to create only one client and reuse it. """ - service_config = { - "methodConfig": [ - { - "name": [{}], # Applies to ALL methods in ALL services - "retryPolicy": { - "maxAttempts": 5, # Max retries (total attempts = 5) - "initialBackoff": "0.1s", # First retry after 0.1s - "maxBackoff": "2s", # Max wait time between retries - "backoffMultiplier": 2, # Exponential backoff factor - "retryableStatusCodes": [ - "UNAVAILABLE", - "DEADLINE_EXCEEDED", - ], # Retries on network failures - }, - } - ] - } - - service_config_json = json.dumps(service_config) - channel = grpc.insecure_channel( f"{host}:{port}", - options=[ - ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.enable_retries", 1), - ("grpc.service_config", service_config_json), - ], + grpc_channel_options(), ) stub = services_pb2_grpc.LearnerServiceStub(channel) logging.info("[ACTOR] Learner service client created") diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index edd2363b1..cb88895cf 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -77,6 +77,7 @@ from lerobot.scripts.rl import learner_service from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( + MAX_MESSAGE_SIZE, bytes_to_python_object, bytes_to_transitions, state_to_bytes, @@ -658,8 +659,8 @@ def start_learner( server = grpc.server( ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), options=[ - ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), - ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), ], ) diff --git a/src/lerobot/scripts/rl/learner_service.py b/src/lerobot/scripts/rl/learner_service.py index 198e52945..b07c296e6 100644 --- a/src/lerobot/scripts/rl/learner_service.py +++ b/src/lerobot/scripts/rl/learner_service.py @@ -23,7 +23,6 @@ from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks from lerobot.utils.queue import get_last_item_from_queue -MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index a6d7b7242..44d9cdf77 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -76,6 +76,7 @@ from lerobot.transport import ( async_inference_pb2, # type: ignore async_inference_pb2_grpc, # type: ignore ) +from lerobot.transport.utils import grpc_channel_options class RobotClient: @@ -113,7 +114,9 @@ class RobotClient: config.actions_per_chunk, config.policy_device, ) - self.channel = grpc.insecure_channel(self.server_address) + self.channel = grpc.insecure_channel( + self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s") + ) self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) self.logger.info(f"Initializing client to connect to server at {self.server_address}") diff --git a/src/lerobot/transport/utils.py b/src/lerobot/transport/utils.py index 1c6683262..bf1aab755 100644 --- a/src/lerobot/transport/utils.py +++ b/src/lerobot/transport/utils.py @@ -16,6 +16,7 @@ # limitations under the License. import io +import json import logging import pickle # nosec B403: Safe usage for internal serialization only from multiprocessing import Event, Queue @@ -27,6 +28,7 @@ from lerobot.transport import services_pb2 from lerobot.utils.transition import Transition CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB +MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB def bytes_buffer_size(buffer: io.BytesIO) -> int: @@ -139,3 +141,42 @@ def transitions_to_bytes(transitions: list[Transition]) -> bytes: buffer = io.BytesIO() torch.save(transitions, buffer) return buffer.getvalue() + + +def grpc_channel_options( + max_receive_message_length: int = MAX_MESSAGE_SIZE, + max_send_message_length: int = MAX_MESSAGE_SIZE, + enable_retries: bool = True, + initial_backoff: str = "0.1s", + max_attempts: int = 5, + backoff_multiplier: float = 2, + max_backoff: str = "2s", +): + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": max_attempts, # Max retries (total attempts = 5) + "initialBackoff": initial_backoff, # First retry after 0.1s + "maxBackoff": max_backoff, # Max wait time between retries + "backoffMultiplier": backoff_multiplier, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + retries_option = 1 if enable_retries else 0 + + return [ + ("grpc.max_receive_message_length", max_receive_message_length), + ("grpc.max_send_message_length", max_send_message_length), + ("grpc.enable_retries", retries_option), + ("grpc.service_config", service_config_json), + ] From 0938a1d816c70d689ab6df3298a71572798c15a6 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:50:36 +0200 Subject: [PATCH 017/158] Feat/add bimanual so100 robot (#1509) --- src/lerobot/record.py | 24 +++ src/lerobot/replay.py | 15 +- .../robots/bi_so100_follower/__init__.py | 18 ++ .../bi_so100_follower/bi_so100_follower.py | 163 ++++++++++++++++++ .../config_bi_so100_follower.py | 39 +++++ .../so100_follower/config_so100_follower.py | 4 +- src/lerobot/robots/utils.py | 4 + src/lerobot/teleoperate.py | 23 +++ .../teleoperators/bi_so100_leader/__init__.py | 18 ++ .../bi_so100_leader/bi_so100_leader.py | 121 +++++++++++++ .../bi_so100_leader/config_bi_so100_leader.py | 26 +++ src/lerobot/teleoperators/utils.py | 4 + 12 files changed, 457 insertions(+), 2 deletions(-) create mode 100644 src/lerobot/robots/bi_so100_follower/__init__.py create mode 100644 src/lerobot/robots/bi_so100_follower/bi_so100_follower.py create mode 100644 src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py create mode 100644 src/lerobot/teleoperators/bi_so100_leader/__init__.py create mode 100644 src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py create mode 100644 src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 9fc0dc7ed..c8184d40b 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -33,6 +33,28 @@ python -m lerobot.record \ # <- Policy optional if you want to record with a policy \ # --policy.path=${HF_USER}/my_policy \ ``` + +Example recording with bimanual so100: +```shell +python -m lerobot.record \ + --robot.type=bi_so100_follower \ + --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ + --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ + --robot.id=bimanual_follower \ + --robot.cameras='{ + left: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}, + top: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30}, + right: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30} + }' \ + --teleop.type=bi_so100_leader \ + --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \ + --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \ + --teleop.id=bimanual_leader \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/bimanual-so100-handover-cube \ + --dataset.num_episodes=25 \ + --dataset.single_task="Grab and handover the red cube to the other arm" +``` """ import logging @@ -57,6 +79,7 @@ from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so100_follower, hope_jr, koch_follower, make_robot_from_config, @@ -66,6 +89,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_so100_leader, homunculus, koch_leader, make_teleoperator_from_config, diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index c51c55cee..afe54d90f 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -15,7 +15,7 @@ """ Replays the actions of an episode from a dataset on a robot. -Example: +Examples: ```shell python -m lerobot.replay \ @@ -25,6 +25,18 @@ python -m lerobot.replay \ --dataset.repo_id=aliberts/record-test \ --dataset.episode=2 ``` + +Example replay with bimanual so100: +```shell +python -m lerobot.replay \ + --robot.type=bi_so100_follower \ + --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ + --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ + --robot.id=bimanual_follower \ + --dataset.repo_id=${HF_USER}/bimanual-so100-handover-cube \ + --dataset.episode=0 +``` + """ import logging @@ -39,6 +51,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so100_follower, hope_jr, koch_follower, make_robot_from_config, diff --git a/src/lerobot/robots/bi_so100_follower/__init__.py b/src/lerobot/robots/bi_so100_follower/__init__.py new file mode 100644 index 000000000..90f56516b --- /dev/null +++ b/src/lerobot/robots/bi_so100_follower/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_so100_follower import BiSO100Follower +from .config_bi_so100_follower import BiSO100FollowerConfig diff --git a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py new file mode 100644 index 000000000..7992b79fd --- /dev/null +++ b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.robots.so100_follower import SO100Follower +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig + +from ..robot import Robot +from .config_bi_so100_follower import BiSO100FollowerConfig + +logger = logging.getLogger(__name__) + + +class BiSO100Follower(Robot): + """ + [Bimanual SO-100 Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + This bimanual robot can also be easily adapted to use SO-101 follower arms, just replace the SO100Follower class with SO101Follower and SO100FollowerConfig with SO101FollowerConfig. + """ + + config_class = BiSO100FollowerConfig + name = "bi_so100_follower" + + def __init__(self, config: BiSO100FollowerConfig): + super().__init__(config) + self.config = config + + left_arm_config = SO100FollowerConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_port, + disable_torque_on_disconnect=config.left_arm_disable_torque_on_disconnect, + max_relative_target=config.left_arm_max_relative_target, + use_degrees=config.left_arm_use_degrees, + cameras={}, + ) + + right_arm_config = SO100FollowerConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_port, + disable_torque_on_disconnect=config.right_arm_disable_torque_on_disconnect, + max_relative_target=config.right_arm_max_relative_target, + use_degrees=config.right_arm_use_degrees, + cameras={}, + ) + + self.left_arm = SO100Follower(left_arm_config) + self.right_arm = SO100Follower(right_arm_config) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | { + f"right_{motor}.pos": float for motor in self.right_arm.bus.motors + } + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return ( + self.left_arm.bus.is_connected + and self.right_arm.bus.is_connected + and all(cam.is_connected for cam in self.cameras.values()) + ) + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + for cam in self.cameras.values(): + cam.connect() + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + self.left_arm.setup_motors() + self.right_arm.setup_motors() + + def get_observation(self) -> dict[str, Any]: + obs_dict = {} + + # Add "left_" prefix + left_obs = self.left_arm.get_observation() + obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) + + # Add "right_" prefix + right_obs = self.right_arm.get_observation() + obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) + + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + # Remove "left_" prefix + left_action = { + key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") + } + # Remove "right_" prefix + right_action = { + key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_") + } + + send_action_left = self.left_arm.send_action(left_action) + send_action_right = self.right_arm.send_action(right_action) + + # Add prefixes back + prefixed_send_action_left = {f"left_{key}": value for key, value in send_action_left.items()} + prefixed_send_action_right = {f"right_{key}": value for key, value in send_action_right.items()} + + return {**prefixed_send_action_left, **prefixed_send_action_right} + + def disconnect(self): + self.left_arm.disconnect() + self.right_arm.disconnect() + + for cam in self.cameras.values(): + cam.disconnect() diff --git a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py new file mode 100644 index 000000000..00643b85f --- /dev/null +++ b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("bi_so100_follower") +@dataclass +class BiSO100FollowerConfig(RobotConfig): + left_arm_port: str + right_arm_port: str + + # Optional + left_arm_disable_torque_on_disconnect: bool = True + left_arm_max_relative_target: int | None = None + left_arm_use_degrees: bool = False + right_arm_disable_torque_on_disconnect: bool = True + right_arm_max_relative_target: int | None = None + right_arm_use_degrees: bool = False + + # cameras (shared between both arms) + cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index 7cd23d340..ea8b9f1c2 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 911d40465..7486ee499 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -57,6 +57,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .hope_jr import HopeJrArm return HopeJrArm(config) + elif config.type == "bi_so100_follower": + from .bi_so100_follower import BiSO100Follower + + return BiSO100Follower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 168f898c4..9836f1393 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -28,6 +28,27 @@ python -m lerobot.teleoperate \ --teleop.id=blue \ --display_data=true ``` + +Example teleoperation with bimanual so100: + +```shell +python -m lerobot.teleoperate \ + --robot.type=bi_so100_follower \ + --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ + --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ + --robot.id=bimanual_follower \ + --robot.cameras='{ + left: {"type": "opencv", "index_or_path": 0, "width": 1920, "height": 1080, "fps": 30}, + top: {"type": "opencv", "index_or_path": 1, "width": 1920, "height": 1080, "fps": 30}, + right: {"type": "opencv", "index_or_path": 2, "width": 1920, "height": 1080, "fps": 30} + }' \ + --teleop.type=bi_so100_leader \ + --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \ + --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \ + --teleop.id=bimanual_leader \ + --display_data=true +``` + """ import logging @@ -43,6 +64,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so100_follower, hope_jr, koch_follower, make_robot_from_config, @@ -52,6 +74,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_so100_leader, gamepad, homunculus, koch_leader, diff --git a/src/lerobot/teleoperators/bi_so100_leader/__init__.py b/src/lerobot/teleoperators/bi_so100_leader/__init__.py new file mode 100644 index 000000000..34313a61e --- /dev/null +++ b/src/lerobot/teleoperators/bi_so100_leader/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_so100_leader import BiSO100Leader +from .config_bi_so100_leader import BiSO100LeaderConfig diff --git a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py new file mode 100644 index 000000000..769669655 --- /dev/null +++ b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import cached_property + +from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader + +from ..teleoperator import Teleoperator +from .config_bi_so100_leader import BiSO100LeaderConfig + +logger = logging.getLogger(__name__) + + +class BiSO100Leader(Teleoperator): + """ + [Bimanual SO-100 Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + This bimanual leader arm can also be easily adapted to use SO-101 leader arms, just replace the SO100Leader class with SO101Leader and SO100LeaderConfig with SO101LeaderConfig. + """ + + config_class = BiSO100LeaderConfig + name = "bi_so100_leader" + + def __init__(self, config: BiSO100LeaderConfig): + super().__init__(config) + self.config = config + + left_arm_config = SO100LeaderConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_port, + ) + + right_arm_config = SO100LeaderConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_port, + ) + + self.left_arm = SO100Leader(left_arm_config) + self.right_arm = SO100Leader(right_arm_config) + + @cached_property + def action_features(self) -> dict[str, type]: + return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | { + f"right_{motor}.pos": float for motor in self.right_arm.bus.motors + } + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + self.left_arm.setup_motors() + self.right_arm.setup_motors() + + def get_action(self) -> dict[str, float]: + action_dict = {} + + # Add "left_" prefix + left_action = self.left_arm.get_action() + action_dict.update({f"left_{key}": value for key, value in left_action.items()}) + + # Add "right_" prefix + right_action = self.right_arm.get_action() + action_dict.update({f"right_{key}": value for key, value in right_action.items()}) + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + # Remove "left_" prefix + left_feedback = { + key.removeprefix("left_"): value for key, value in feedback.items() if key.startswith("left_") + } + # Remove "right_" prefix + right_feedback = { + key.removeprefix("right_"): value for key, value in feedback.items() if key.startswith("right_") + } + + if left_feedback: + self.left_arm.send_feedback(left_feedback) + if right_feedback: + self.right_arm.send_feedback(right_feedback) + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py b/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py new file mode 100644 index 000000000..117e09913 --- /dev/null +++ b/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("bi_so100_leader") +@dataclass +class BiSO100LeaderConfig(TeleoperatorConfig): + left_arm_port: str + right_arm_port: str diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 8a667fd41..344a95d72 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -61,5 +61,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .homunculus import HomunculusArm return HomunculusArm(config) + elif config.type == "bi_so100_leader": + from .bi_so100_leader import BiSO100Leader + + return BiSO100Leader(config) else: raise ValueError(config.type) From 378e1f0338749bae1d2d1d50d01e944b8a2d9f55 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 17 Jul 2025 14:30:20 +0200 Subject: [PATCH 018/158] Update pre-commit-config.yaml + pyproject.toml + ceil rerun & transformer dependencies version (#1520) * chore: update .gitignore * chore: update pre-commit * chore(deps): update pyproject * fix(ci): multiple fixes * chore: pre-commit apply * chore: address review comments * Update pyproject.toml Co-authored-by: Ben Zhang <5977478+ben-z@users.noreply.github.com> Signed-off-by: Steven Palma * chore(deps): add todo --------- Signed-off-by: Steven Palma Co-authored-by: Ben Zhang <5977478+ben-z@users.noreply.github.com> --- .github/PULL_REQUEST_TEMPLATE.md | 15 +- .gitignore | 282 +++++++++--------- .pre-commit-config.yaml | 53 +++- CODE_OF_CONDUCT.md | 21 +- CONTRIBUTING.md | 55 ++-- README.md | 51 +++- benchmarks/video/README.md | 45 ++- docs/README.md | 4 +- docs/source/async.mdx | 78 +++-- docs/source/backwardcomp.mdx | 29 +- docs/source/cameras.mdx | 59 +++- docs/source/hilserl.mdx | 115 +++++-- docs/source/hilserl_sim.mdx | 24 +- docs/source/il_robots.mdx | 56 +++- docs/source/il_sim.mdx | 38 ++- docs/source/index.mdx | 6 +- docs/source/installation.mdx | 27 +- docs/source/integrate_hardware.mdx | 44 ++- docs/source/notebooks.mdx | 4 +- docs/source/smolvla.mdx | 35 ++- examples/4_train_policy_with_script.md | 43 ++- pyproject.toml | 202 +++++++++---- src/lerobot/cameras/camera.py | 4 +- src/lerobot/cameras/opencv/camera_opencv.py | 4 +- .../cameras/realsense/camera_realsense.py | 4 +- .../realsense/configuration_realsense.py | 6 +- src/lerobot/configs/parser.py | 9 +- src/lerobot/configs/policies.py | 6 +- src/lerobot/configs/train.py | 4 +- src/lerobot/datasets/card_template.md | 3 +- src/lerobot/datasets/lerobot_dataset.py | 2 +- .../datasets/push_dataset_to_hub/utils.py | 3 +- src/lerobot/datasets/sampler.py | 4 +- src/lerobot/datasets/transforms.py | 9 +- src/lerobot/envs/configs.py | 36 +-- src/lerobot/find_cameras.py | 20 +- src/lerobot/motors/motors_bus.py | 2 +- src/lerobot/policies/act/modeling_act.py | 4 +- .../policies/diffusion/modeling_diffusion.py | 2 +- .../policies/pi0/paligemma_with_expert.py | 13 +- src/lerobot/policies/pretrained.py | 10 +- src/lerobot/policies/sac/modeling_sac.py | 3 +- .../policies/smolvla/smolvlm_with_expert.py | 13 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 2 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 4 +- src/lerobot/policies/vqbet/vqbet_utils.py | 6 +- src/lerobot/record.py | 3 +- src/lerobot/robots/hope_jr/hope_jr.mdx | 29 +- src/lerobot/robots/koch_follower/koch.mdx | 43 ++- src/lerobot/robots/lekiwi/lekiwi.mdx | 53 +++- src/lerobot/robots/lekiwi/lekiwi_client.py | 24 +- src/lerobot/robots/robot.py | 5 +- src/lerobot/robots/so100_follower/so100.mdx | 233 ++++++++++++--- src/lerobot/robots/so101_follower/so101.mdx | 103 +++++-- src/lerobot/robots/stretch3/README.md | 18 +- src/lerobot/robots/viperx/README.md | 28 +- src/lerobot/scripts/eval.py | 2 +- src/lerobot/scripts/rl/crop_dataset_roi.py | 5 +- src/lerobot/scripts/rl/gym_manipulator.py | 3 +- src/lerobot/scripts/rl/learner.py | 4 +- src/lerobot/scripts/server/configs.py | 2 +- src/lerobot/scripts/server/robot_client.py | 5 +- src/lerobot/scripts/visualize_dataset.py | 2 +- .../homunculus/homunculus_arm.py | 10 +- .../homunculus/homunculus_glove.py | 10 +- src/lerobot/teleoperators/teleoperator.py | 5 +- .../templates/lerobot_modelcard_template.md | 7 +- src/lerobot/utils/benchmark.py | 2 + src/lerobot/utils/buffer.py | 3 +- src/lerobot/utils/hub.py | 5 +- src/lerobot/utils/random_utils.py | 3 +- src/lerobot/utils/utils.py | 8 +- tests/configs/test_plugin_loading.py | 2 +- tests/mocks/mock_dynamixel.py | 2 +- tests/mocks/mock_feetech.py | 2 +- tests/motors/test_dynamixel.py | 2 +- tests/motors/test_feetech.py | 2 +- tests/utils/test_replay_buffer.py | 2 +- 78 files changed, 1450 insertions(+), 636 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index df2e2db29..22f1ee3d7 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,33 +1,40 @@ ## What this does + Explain what this PR does. Feel free to tag your PR with the appropriate label(s). Examples: -| Title | Label | +| Title | Label | |----------------------|-----------------| -| Fixes #[issue] | (🐛 Bug) | -| Adds new dataset | (🗃️ Dataset) | -| Optimizes something | (⚡️ Performance) | +| Fixes #[issue] | (🐛 Bug) | +| Adds new dataset | (🗃️ Dataset) | +| Optimizes something | (⚡️ Performance) | ## How it was tested + Explain/show how you tested your changes. Examples: + - Added `test_something` in `tests/test_stuff.py`. - Added `new_feature` and checked that training converges with policy X on dataset/environment Y. - Optimized `some_function`, it now runs X times faster than previously. ## How to checkout & try? (for the reviewer) + Provide a simple way for the reviewer to try out your changes. Examples: + ```bash pytest -sx tests/test_stuff.py::test_something ``` + ```bash python -m lerobot.scripts.train --some.option=true ``` ## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR + **Note**: Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. Try to avoid tagging more than 3 people. diff --git a/.gitignore b/.gitignore index 4ab886933..c4d1f769f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,164 +12,164 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Dev scripts -.dev - -# Logging -logs -tmp -wandb - -# Data -data -outputs - -# Apple -.DS_Store - -# VS Code -.vscode -.devcontainer - -# HPC -nautilus/*.yaml -*.key - -# Slurm -sbatch*.sh - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# uv/poetry lock files -poetry.lock -uv.lock - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -!tests/artifacts -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Ignore .cache -.cache/* - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments +### Environments & Dependencies ### .env .venv env/ venv/ env.bak/ venv.bak/ +.python-version +__pypackages__/ +node_modules/ -# Spyder project settings +# Lock files +poetry.lock +uv.lock +Pipfile.lock + +### Build & Distribution ### +build/ +dist/ +sdist/ +wheels/ +downloads/ +eggs/ +.eggs/ +parts/ +var/ +pip-wheel-metadata/ +share/python-wheels/ +develop-eggs/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +lib/ +lib64/ + +# PyInstaller +*.manifest +*.spec + +### Compiled & Cached Files ### +__pycache__/ +*.py[cod] +*$py.class +*.so +*.sage.py +.cache/ +.ruff_cache/ +.mypy_cache/ +.pyre/ +.pytype/ +cython_debug/ + +### Testing & Coverage ### +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.pytest_cache/ +.hypothesis/ +nosetests.xml +coverage.xml +*.cover +*.py,cover +!tests/artifacts + +### Logs & Temporary Files ### +logs/ +tmp/ +*.log +pip-log.txt +pip-delete-this-directory.txt +celerybeat-schedule +celerybeat.pid + +### IDE & Editor Config ### +# VS Code +.vscode/ +.devcontainer/ + +# JetBrains / PyCharm +.idea/ + +# Spyder .spyderproject .spyproject -# Rope project settings +# Rope .ropeproject -# mkdocs documentation +# Vim +*.swp + +# Other +*~ + +### OS Specific ### +# macOS +.DS_Store + +# Windows +Thumbs.db + +### Framework & Tool Specific ### + +.Python + +# Django +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask +instance/ +.webassets-cache + +# Scrapy +.scrapy + +# Jupyter +.ipynb_checkpoints/ +profile_default/ +ipython_config.py + +# Sphinx +docs/_build/ + +# MkDocs /site +# PyBuilder +.pybuilder/ +target/ + # mypy -.mypy_cache/ .dmypy.json dmypy.json -# Pyre type checker -.pyre/ +### HPC & Slurm ### +nautilus/*.yaml +*.key +sbatch*.sh -# pytype static type analyzer -.pytype/ +### Miscellaneous ### +# W&B +wandb/ -# Cython debug symbols -cython_debug/ +# Dev scripts +.dev/ + +# Data folders +data/ +outputs/ + +# Translations +*.mo +*.pot + +# Dev folders +.cache/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e25f33ee0..e509d6d88 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -exclude: "tests/artifacts/.*\\.safetensors$" default_language_version: python: python3.10 + +exclude: "tests/artifacts/.*\\.safetensors$" + repos: ##### Meta ##### - repo: meta @@ -22,12 +24,12 @@ repos: - id: check-useless-excludes - id: check-hooks-apply - - ##### Style / Misc. ##### + ##### General Code Quality & Formatting ##### - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: - id: check-added-large-files + args: ['--maxkb=1024'] - id: debug-statements - id: check-merge-conflict - id: check-case-conflict @@ -36,7 +38,14 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/adhtruong/mirrors-typos + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 + hooks: + - id: ruff-format + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/crate-ci/typos rev: v1.34.0 hooks: - id: typos @@ -46,14 +55,16 @@ repos: rev: v3.20.0 hooks: - id: pyupgrade + args: [--py310-plus] - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.3 + ##### Markdown Quality ##### + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 hooks: - - id: ruff - args: [--fix] - - id: ruff-format - + - id: prettier + name: Format Markdown with Prettier + types_or: [markdown, mdx] + args: [--prose-wrap=preserve] ##### Security ##### - repo: https://github.com/gitleaks/gitleaks @@ -72,3 +83,25 @@ repos: - id: bandit args: ["-c", "pyproject.toml"] additional_dependencies: ["bandit[toml]"] + + # TODO(Steven): Uncomment when ready to use + ##### Static Analysis & Typing ##### + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.16.0 + # hooks: + # - id: mypy + # args: [--python-version=3.10] + + ##### Docstring Checks ##### + # - repo: https://github.com/akaihola/darglint2 + # rev: v1.8.2 + # hooks: + # - id: darglint2 + # args: ["--docstring-style", "google", "-v", "2"] + # exclude: ^tests/.*$ + + # - repo: https://github.com/econchick/interrogate + # rev: 1.7.0 + # hooks: + # - id: interrogate + # args: ["-vv", "--config=pyproject.toml"] diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 04a052753..c0fdac843 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,4 +1,3 @@ - # Contributor Covenant Code of Conduct ## Our Pledge @@ -18,23 +17,23 @@ diverse, inclusive, and healthy community. Examples of behavior that contributes to a positive environment for our community include: -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our mistakes, +- Demonstrating empathy and kindness toward other people +- Being respectful of differing opinions, viewpoints, and experiences +- Giving and gracefully accepting constructive feedback +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the overall +- Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: -* The use of sexualized language or imagery, and sexual attention or advances of +- The use of sexualized language or imagery, and sexual attention or advances of any kind -* Trolling, insulting or derogatory comments, and personal or political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email address, +- Trolling, insulting or derogatory comments, and personal or political attacks +- Public or private harassment +- Publishing others' private information, such as a physical or email address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in a +- Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a354e1346..369af602b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,10 +15,11 @@ Whichever way you choose to contribute, please be mindful to respect our ## You can contribute in so many ways! Some of the ways you can contribute to 🤗 LeRobot: -* Fixing outstanding issues with the existing code. -* Implementing new models, datasets or simulation environments. -* Contributing to the examples or to the documentation. -* Submitting issues related to bugs or desired new features. + +- Fixing outstanding issues with the existing code. +- Implementing new models, datasets or simulation environments. +- Contributing to the examples or to the documentation. +- Submitting issues related to bugs or desired new features. Following the guides below, feel free to open issues and PRs and to coordinate your efforts with the community on our [Discord Channel](https://discord.gg/VjFz58wn3R). For specific inquiries, reach out to [Remi Cadene](mailto:remi.cadene@huggingface.co). @@ -40,24 +41,26 @@ already reported** (use the search bar on Github under Issues). Did not find it? :( So we can act quickly on it, please follow these steps: -* Include your **OS type and version**, the versions of **Python** and **PyTorch**. -* A short, self-contained, code snippet that allows us to reproduce the bug in +- Include your **OS type and version**, the versions of **Python** and **PyTorch**. +- A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s. -* The full traceback if an exception is raised. -* Attach any other additional information, like screenshots, you think may help. +- The full traceback if an exception is raised. +- Attach any other additional information, like screenshots, you think may help. ### Do you want a new feature? A good feature request addresses the following points: 1. Motivation first: -* Is it related to a problem/frustration with the library? If so, please explain + +- Is it related to a problem/frustration with the library? If so, please explain why. Providing a code snippet that demonstrates the problem is best. -* Is it related to something you would need for a project? We'd love to hear +- Is it related to something you would need for a project? We'd love to hear about it! -* Is it something you worked on and think could benefit the community? +- Is it something you worked on and think could benefit the community? Awesome! Tell us what problem it solved for you. -2. Write a *paragraph* describing the feature. + +2. Write a _paragraph_ describing the feature. 3. Provide a **code snippet** that demonstrates its future use. 4. In case this is related to a paper, please attach a link. 5. Attach any additional information (drawings, screenshots, etc.) you think may help. @@ -74,12 +77,15 @@ environments ([aloha](https://github.com/huggingface/gym-aloha), and follow the same api design. When implementing a new dataset loadable with LeRobotDataset follow these steps: + - Update `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new environment (e.g. `gym_aloha`), follow these steps: + - Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py` When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: + - Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py` - Set the required `name` class attribute. - Update variables in `tests/test_available.py` by importing your new Policy class @@ -133,11 +139,13 @@ Follow these steps to start contributing: Follow the instructions to [install poetry](https://python-poetry.org/docs/#installation) (use a version >=2.1.0) or to [install uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) if you don't have one of them already. Set up a development environment with conda or miniconda: + ```bash conda create -y -n lerobot-dev python=3.10 && conda activate lerobot-dev ``` If you're using `uv`, it can manage python versions so you can instead do: + ```bash uv venv --python 3.10 && source .venv/bin/activate ``` @@ -145,11 +153,13 @@ Follow these steps to start contributing: To develop on 🤗 LeRobot, you will at least need to install the `dev` and `test` extras dependencies along with the core library: using `poetry` + ```bash poetry sync --extras "dev test" ``` using `uv` + ```bash uv sync --extra dev --extra test ``` @@ -157,43 +167,48 @@ Follow these steps to start contributing: You can also install the project with all its dependencies (including environments): using `poetry` + ```bash poetry sync --all-extras ``` using `uv` + ```bash uv sync --all-extras ``` - > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they *will* be tested in the CI. In general, we advise you to install everything and test locally before pushing. + > **Note:** If you don't install simulation environments with `--all-extras`, the tests that require them will be skipped when running the pytest suite locally. However, they _will_ be tested in the CI. In general, we advise you to install everything and test locally before pushing. Whichever command you chose to install the project (e.g. `poetry sync --all-extras`), you should run it again when pulling code with an updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the new dependencies. The equivalent of `pip install some-package`, would just be: using `poetry` + ```bash poetry add some-package ``` using `uv` + ```bash uv add some-package ``` When making changes to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies. using `poetry` + ```bash poetry lock ``` using `uv` + ```bash uv lock ``` - 5. Develop the features on your branch. As you work on the features, you should make sure that the test suite @@ -211,11 +226,13 @@ Follow these steps to start contributing: automatically as Git commit hooks. Install `pre-commit` hooks: + ```bash pre-commit install ``` You can run these hooks whenever you need on staged files with: + ```bash pre-commit ``` @@ -229,6 +246,7 @@ Follow these steps to start contributing: ``` Note, if you already committed some changes that have a wrong formatting, you can use: + ```bash pre-commit run --all-files ``` @@ -249,16 +267,15 @@ Follow these steps to start contributing: git push -u origin a-descriptive-name-for-my-changes ``` -6. Once you are satisfied (**and the checklist below is happy too**), go to the +7. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review. -7. It's ok if maintainers ask you for changes. It happens to core contributors +8. It's ok if maintainers ask you for changes. It happens to core contributors too! So everyone can see the changes in the Pull request, work in your local branch and push the changes to your fork. They will automatically appear in the pull request. - ### Checklist 1. The title of your pull request should be a summary of its contribution; @@ -277,18 +294,21 @@ An extensive test suite is included to test the library behavior and several exa Install [git lfs](https://git-lfs.com/) to retrieve test artifacts (if you don't have it already). On Mac: + ```bash brew install git-lfs git lfs install ``` On Ubuntu: + ```bash sudo apt-get install git-lfs git lfs install ``` Pull artifacts if they're not in [tests/artifacts](tests/artifacts) + ```bash git lfs pull ``` @@ -300,6 +320,5 @@ repository, here's how to run tests with `pytest` for the library: python -m pytest -sv ./tests ``` - You can specify a smaller set of tests in order to test only the feature you're working on. diff --git a/README.md b/README.md index ff7a92384..1d7cbcad4 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,6 @@ /> -

Meet the updated SO100, the SO-101 – Just €114 per arm!

Train it in minutes with a few simple moves on your laptop.

Then sit back and watch your creation act autonomously! 🤯

@@ -120,52 +119,61 @@ - Thanks to Antonio Loquercio and Ashish Kumar for their early support. - Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). - ## Installation Download our source code: + ```bash git clone https://github.com/huggingface/lerobot.git cd lerobot ``` Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html): + ```bash conda create -y -n lerobot python=3.10 conda activate lerobot ``` When using `miniconda`, install `ffmpeg` in your environment: + ```bash conda install ffmpeg -c conda-forge ``` > **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can: -> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: -> ```bash -> conda install ffmpeg=7.1.1 -c conda-forge -> ``` -> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. +> +> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: +> +> ```bash +> conda install ffmpeg=7.1.1 -c conda-forge +> ``` +> +> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. Install 🤗 LeRobot: + ```bash pip install -e . ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: -`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) +> `sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras: + - [aloha](https://github.com/huggingface/gym-aloha) - [xarm](https://github.com/huggingface/gym-xarm) - [pusht](https://github.com/huggingface/gym-pusht) For instance, to install 🤗 LeRobot with aloha and pusht, use: + ```bash pip install -e ".[aloha, pusht]" ``` To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with + ```bash wandb login ``` @@ -177,6 +185,7 @@ wandb login Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: + ```bash python -m lerobot.scripts.visualize_dataset \ --repo-id lerobot/pusht \ @@ -184,6 +193,7 @@ python -m lerobot.scripts.visualize_dataset \ ``` or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`) + ```bash python -m lerobot.scripts.visualize_dataset \ --repo-id lerobot/pusht \ @@ -192,19 +202,17 @@ python -m lerobot.scripts.visualize_dataset \ --episode-index 0 ``` - It will open `rerun.io` and display the camera streams, robot states and actions, like this: https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144 - Our script can also visualize datasets stored on a distant server. See `python -m lerobot.scripts.visualize_dataset --help` for more instructions. ### The `LeRobotDataset` format A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model. -A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. +A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor. @@ -239,6 +247,7 @@ dataset attributes: ``` A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely: + - hf_dataset stored using Hugging Face datasets library serialization to parquet - videos are stored in mp4 format to save space - metadata are stored in plain json/jsonl files @@ -250,6 +259,7 @@ Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht): + ```bash python -m lerobot.scripts.eval \ --policy.path=lerobot/diffusion_pusht \ @@ -284,9 +294,11 @@ Note: For efficiency, during training every checkpoint is evaluated on a low num We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances. You can reproduce their training by loading the config from their run. Simply running: + ```bash python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht ``` + reproduces SOTA results for Diffusion Policy on the PushT task. ## Contribute @@ -313,27 +325,29 @@ See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions If your dataset format is not supported, implement your own in `lerobot/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/xarm_pkl_format.py). --> - ### Add a pretrained policy Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)). You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain: + - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config). - `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format. - `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility. To upload these to the hub, run the following: + ```bash huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model ``` See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy. - ### Improve your code with profiling An example of a code snippet to profile the evaluation of a policy: + + ```python from torch.profiler import profile, record_function, ProfilerActivity @@ -354,10 +368,12 @@ with profile( prof.step() # insert code to profile, potentially whole body of eval_policy function ``` + ## Citation If you want, you can cite this work with: + ```bibtex @misc{cadene2024lerobot, author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, @@ -368,7 +384,9 @@ If you want, you can cite this work with: ``` Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below: + - [SmolVLA](https://arxiv.org/abs/2506.01844) + ```bibtex @article{shukor2025smolvla, title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, @@ -379,6 +397,7 @@ Additionally, if you are using any of the particular policy architecture, pretra ``` - [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) + ```bibtex @article{chi2024diffusionpolicy, author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, @@ -387,7 +406,9 @@ Additionally, if you are using any of the particular policy architecture, pretra year = {2024}, } ``` + - [ACT or ALOHA](https://tonyzhaozh.github.io/aloha) + ```bibtex @article{zhao2023learning, title={Learning fine-grained bimanual manipulation with low-cost hardware}, @@ -409,6 +430,7 @@ Additionally, if you are using any of the particular policy architecture, pretra ``` - [VQ-BeT](https://sjlee.cc/vq-bet/) + ```bibtex @article{lee2024behavior, title={Behavior generation with latent actions}, @@ -418,8 +440,8 @@ Additionally, if you are using any of the particular policy architecture, pretra } ``` - - [HIL-SERL](https://hil-serl.github.io/) + ```bibtex @Article{luo2024hilserl, title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, @@ -430,6 +452,7 @@ archivePrefix={arXiv}, primaryClass={cs.RO} } ``` + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index daa3e1f48..490a4b495 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -1,28 +1,32 @@ # Video benchmark - ## Questions + What is the optimal trade-off between: + - maximizing loading time with random access, - minimizing memory space on disk, - maximizing success rate of policies, - compatibility across devices/platforms for decoding videos (e.g. video players, web browsers). How to encode videos? + - Which video codec (`-vcodec`) to use? h264, h265, AV1? - What pixel format to use (`-pix_fmt`)? `yuv444p` or `yuv420p`? - How much compression (`-crf`)? No compression with `0`, intermediate compression with `25` or extreme with `50+`? - Which frequency to chose for key frames (`-g`)? A key frame every `10` frames? How to decode videos? + - Which `decoder`? `torchvision`, `torchaudio`, `ffmpegio`, `decord`, or `nvc`? - What scenarios to use for the requesting timestamps during benchmark? (`timestamps_mode`) - ## Variables + **Image content & size** We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution). For these reasons, we run this benchmark on four representative datasets: + - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera. - `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. - `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. @@ -34,8 +38,9 @@ Note: The datasets used for this benchmark need to be image datasets, not video We might revisit this benchmark and find better settings if we train our policies with various data augmentations to make them more robust (e.g. robust to color changes, compression, etc.). ### Encoding parameters + | parameter | values | -|-------------|--------------------------------------------------------------| +| ----------- | ------------------------------------------------------------ | | **vcodec** | `libx264`, `libx265`, `libsvtav1` | | **pix_fmt** | `yuv444p`, `yuv420p` | | **g** | `1`, `2`, `3`, `4`, `5`, `6`, `10`, `15`, `20`, `40`, `None` | @@ -44,19 +49,23 @@ We might revisit this benchmark and find better settings if we train our policie Note that `crf` value might be interpreted differently by various video codecs. In other words, the same value used with one codec doesn't necessarily translate into the same compression level with another codec. In fact, the default value (`None`) isn't the same amongst the different video codecs. Importantly, it is also the case for many other ffmpeg arguments like `g` which specifies the frequency of the key frames. For a comprehensive list and documentation of these parameters, see the ffmpeg documentation depending on the video codec used: + - h264: https://trac.ffmpeg.org/wiki/Encode/H.264 - h265: https://trac.ffmpeg.org/wiki/Encode/H.265 - AV1: https://trac.ffmpeg.org/wiki/Encode/AV1 ### Decoding parameters + **Decoder** We tested two video decoding backends from torchvision: + - `pyav` - `video_reader` (requires to build torchvision from source) **Requested timestamps** Given the way video decoding works, once a keyframe has been loaded, the decoding of subsequent frames is fast. This of course is affected by the `-g` parameter during encoding, which specifies the frequency of the keyframes. Given our typical use cases in robotics policies which might request a few timestamps in different random places, we want to replicate these use cases with the following scenarios: + - `1_frame`: 1 frame, - `2_frames`: 2 consecutive frames (e.g. `[t, t + 1 / fps]`), - `6_frames`: 6 consecutive frames (e.g. `[t + i / fps for i in range(6)]`) @@ -64,12 +73,13 @@ This of course is affected by the `-g` parameter during encoding, which specifie Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`. Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario: + - `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`), However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded. - ## Metrics + **Data compression ratio (lower is better)** `video_images_size_ratio` is the ratio of the memory space on disk taken by the encoded video over the memory space taken by the original images. For instance, `video_images_size_ratio=25%` means that the video takes 4 times less memory space on disk compared to the original images. @@ -87,18 +97,18 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes. h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility: + - `yuv420p` is more widely supported across various platforms, including web browsers. - `yuv444p` offers higher color fidelity but might not be supported as broadly. - - ## How the benchmark works + The benchmark evaluates both encoding and decoding of video frames on the first episode of each dataset. **Encoding:** for each `vcodec` and `pix_fmt` pair, we use a default value for `g` and `crf` upon which we change a single value (either `g` or `crf`) to one of the specified values (we don't test every combination of those as this would be computationally too heavy). @@ -110,15 +120,18 @@ Intermediate results saved for each `vcodec` and `pix_fmt` combination in csv ta These are then all concatenated to a single table ready for analysis. ## Caveats + We tried to measure the most impactful parameters for both encoding and decoding. However, for computational reasons we can't test out every combination. Additional encoding parameters exist that are not included in this benchmark. In particular: + - `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1. - `-tune` which allows to optimize the encoding for certain aspects (e.g. film quality, fast decoding, etc.). See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters. Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few: + - `torchaudio` - `ffmpegio` - `decord` @@ -127,16 +140,17 @@ Similarly on the decoding side, other decoders exist but are not implemented in Note as well that since we are mostly interested in the performance at decoding time (also because encoding is done only once before uploading a dataset), we did not measure encoding times nor have any metrics regarding encoding. However, besides the necessity to build ffmpeg from source, encoding did not pose any issue and it didn't take a significant amount of time during this benchmark. - ## Install + Building ffmpeg from source is required to include libx265 and libaom/libsvtav1 (av1) video codecs ([compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu)). **Note:** While you still need to build torchvision with a conda-installed `ffmpeg<4.3` to use the `video_reader` decoder (as described in [#220](https://github.com/huggingface/lerobot/pull/220)), you also need another version which is custom-built with all the video codecs for encoding. For the script to then use that version, you can prepend the command above with `PATH="$HOME/bin:$PATH"`, which is where ffmpeg should be built. - ## Adding a video decoder + Right now, we're only benchmarking the two video decoder available with torchvision: `pyav` and `video_reader`. You can easily add a new decoder to benchmark by adding it to this function in the script: + ```diff def decode_video_frames( video_path: str, @@ -156,9 +170,10 @@ def decode_video_frames( raise NotImplementedError(backend) ``` - ## Example + For a quick run, you can try these parameters: + ```bash python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ @@ -176,11 +191,12 @@ python benchmark/video/run_video_benchmark.py \ --save-frames 0 ``` - ## Results ### Reproduce + We ran the benchmark with the following parameters: + ```bash # h264 and h265 encodings python benchmark/video/run_video_benchmark.py \ @@ -221,9 +237,10 @@ python benchmark/video/run_video_benchmark.py \ The full results are available [here](https://docs.google.com/spreadsheets/d/1OYJB43Qu8fC26k_OyoMFgGBBKfQRCi4BIuYitQnq3sw/edit?usp=sharing) - ### Parameters selected for LeRobotDataset + Considering these results, we chose what we think is the best set of encoding parameter: + - vcodec: `libsvtav1` - pix-fmt: `yuv420p` - g: `2` @@ -236,7 +253,7 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav` | video_images_size_ratio | vcodec | pix_fmt | | | | -|------------------------------------|------------|---------|-----------|-----------|-----------| +| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- | | | libx264 | | libx265 | | libsvtav1 | | repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | | lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | @@ -245,7 +262,7 @@ These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_ | aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | | video_images_load_time_ratio | vcodec | pix_fmt | | | | -|------------------------------------|---------|---------|----------|---------|-----------| +| ---------------------------------- | ------- | ------- | -------- | ------- | --------- | | | libx264 | | libx265 | | libsvtav1 | | repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | | lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | @@ -254,7 +271,7 @@ These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_ | aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | | | | vcodec | pix_fmt | | | | -|------------------------------------|----------|----------|--------------|----------|-----------|--------------| +| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | | | | libx264 | | libx265 | | libsvtav1 | | repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | | lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | diff --git a/docs/README.md b/docs/README.md index 275fee46b..967de7b84 100644 --- a/docs/README.md +++ b/docs/README.md @@ -26,6 +26,7 @@ pip install -e ".[docs]" You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download) --- + **NOTE** You only need to generate the documentation to inspect it locally (if you're planning changes and want to @@ -63,6 +64,7 @@ doc-builder preview lerobot docs/source/ The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives. --- + **NOTE** The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again). @@ -89,6 +91,7 @@ Sections that were moved: [ Section A ] ``` + and of course, if you moved it to another file, then: ``` @@ -119,7 +122,6 @@ and objects like True, None or any strings should usually be put in `code`. Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown: - ```` ``` # first line of code diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 6ff05a88a..397c513cf 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -5,17 +5,18 @@ In this tutorial, we'll show how to use asynchronous inference (_async inference **Try async inference with all the policies** supported by LeRobot! **What you'll learn:** + 1. Why asynchronous inference matters and how it compares to, more traditional, sequential inference. 2. How to spin-up a `PolicyServer` and connect a `RobotClient` from the same machine, and even over the network. 3. How to tune key parameters (`actions_per_chunk`, `chunk_size_threshold`) for your robot and policy. If you get stuck, hop into our [Discord community](https://discord.gg/s3KuuzsPFb)! - -In a nutshell: with *async inference*, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours. +In a nutshell: with _async inference_, your robot keeps acting while the policy server is already busy computing the next chunk of actions---eliminating "wait-for-inference" lags and unlocking smoother, more reactive behaviours. This is fundamentally different from synchronous inference (sync), where the robot stays idle while the policy computes the next chunk of actions. --- + ## Getting started with async inference You can read more information on asynchronous inference in our [blogpost](https://huggingface.co/blog/async-robot-inference). This guide is designed to help you quickly set up and run asynchronous inference in your environment. @@ -53,40 +54,53 @@ python src/lerobot/scripts/server/robot_client.py \ --aggregate_fn_name=weighted_average \ # CLIENT: the function to aggregate actions on overlapping portions --debug_visualize_queue_size=True # CLIENT: whether to visualize the queue size at runtime ``` + In summary, you need to specify instructions for: + - `SERVER`: the address and port of the policy server - `ROBOT`: the type of robot to connect to, the port to connect to, and the local `id` of the robot - `POLICY`: the type of policy to run, and the model name/path on server to the checkpoint to run. You also need to specify which device should the sever be using, and how many actions to output at once (capped at the policy max actions value). - `CLIENT`: the threshold for the chunk size before sending a new observation to the server, and the function to aggregate actions on overlapping portions. Optionally, you can also visualize the queue size at runtime, to help you tune the `CLIENT` parameters. Importantly, + - `actions_per_chunk` and `chunk_size_threshold` are key parameters to tune for your setup. - `aggregate_fn_name` is the function to aggregate actions on overlapping portions. You can either add a new one to a registry of functions, or add your own in `robot_client.py` (see [here](NOTE:addlinktoLOC)) - `debug_visualize_queue_size` is a useful tool to tune the `CLIENT` parameters. -Done! You should see your robot moving around by now 😉 ---- +## Done! You should see your robot moving around by now 😉 ## Async vs. synchronous inference -Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in *idle frames*, frames where the robot awaits idle the policy's output: a new action chunk. +Synchronous inference relies on interleaving action chunk prediction and action execution. This inherently results in _idle frames_, frames where the robot awaits idle the policy's output: a new action chunk. In turn, inference is plagued by evident real-time lags, where the robot simply stops acting due to the lack of available actions. With robotics models increasing in size, this problem risks becoming only more severe.

- + +

+

+ Synchronous inference makes the robot idle while the policy is + computing the next chunk of actions.

-

Synchronous inference makes the robot idle while the policy is computing the next chunk of actions.

To overcome this, we design async inference, a paradigm where action planning and execution are decoupled, resulting in (1) higher adaptability and, most importantly, (2) no idle frames. -Crucially, with async inference, the next action chunk is computed *before* the current one is exhausted, resulting in no idleness. +Crucially, with async inference, the next action chunk is computed _before_ the current one is exhausted, resulting in no idleness. Higher adaptability is ensured by aggregating the different action chunks on overlapping portions, obtaining an up-to-date plan and a tighter control loop.

- + +

+

+ Asynchronous inference results in no idleness because the next chunk is + computed before the current chunk is exhausted.

-

Asynchronous inference results in no idleness because the next chunk is computed before the current chunk is exhausted.

- --- @@ -105,6 +119,8 @@ python -m lerobot.scripts.server.policy_server \ ``` + + ```python from lerobot.scripts.server.configs import PolicyServerConfig from lerobot.scripts.server.policy_server import serve @@ -115,6 +131,8 @@ config = PolicyServerConfig( ) serve(config) ``` + + @@ -147,6 +165,8 @@ python src/lerobot/scripts/server/robot_client.py \ ``` + + ```python import threading from lerobot.robots.so100_follower import SO100FollowerConfig @@ -201,6 +221,8 @@ if client.start(): # (Optionally) plot the action queue size visualize_action_queue_size(client.action_queue_size) ``` + + @@ -216,20 +238,30 @@ The following two parameters are key in every setup: - actions_per_chunk + + actions_per_chunk + 50 - How many actions the policy outputs at once. Typical values: 10-50. + + How many actions the policy outputs at once. Typical values: 10-50. + - chunk_size_threshold + + chunk_size_threshold + 0.7 - When the queue is ≤ 50% full, the client sends a fresh observation. Value in [0, 1]. + + When the queue is ≤ 50% full, the client sends a fresh observation. + Value in [0, 1]. + -Different values of `actions_per_chunk` and `chunk_size_threshold` do result in different behaviours. + Different values of `actions_per_chunk` and `chunk_size_threshold` do result + in different behaviours. On the one hand, increasing the value of `actions_per_chunk` will result in reducing the likelihood of ending up with no actions to execute, as more actions will be available when the new chunk is computed. @@ -249,10 +281,18 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.

- + +

+

+ + The action queue size is plotted at runtime when the + `--debug-visualize-queue-size` flag is passed, for various levels of + `chunk_size_threshold` (`g` in the SmolVLA paper). +

-

The action queue size is plotted at runtime when the `--debug-visualize-queue-size` flag is passed, for various levels of `chunk_size_threshold` (`g` in the SmolVLA paper).

- --- diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx index 555239170..0e1d01636 100644 --- a/docs/source/backwardcomp.mdx +++ b/docs/source/backwardcomp.mdx @@ -6,21 +6,22 @@ PR [#777](https://github.com/huggingface/lerobot/pull/777) improves the LeRobot ### What changed? -| | Before PR #777 | After PR #777 | -| --------------------------------- | ------------------------------------------------- | --------------------------------------------------------------------------- | -| **Joint range** | Degrees `-180...180°` | **Normalised range** Joints: `–100...100` Gripper: `0...100` | -| **Zero position (SO100 / SO101)** | Arm fully extended horizontally | **In middle of the range for each joint** | -| **Boundary handling** | Software safeguards to detect ±180 ° wrap-arounds | No wrap-around logic needed due to mid-range zero | +| | Before PR #777 | After PR #777 | +| --------------------------------- | ------------------------------------------------- | ------------------------------------------------------------ | +| **Joint range** | Degrees `-180...180°` | **Normalised range** Joints: `–100...100` Gripper: `0...100` | +| **Zero position (SO100 / SO101)** | Arm fully extended horizontally | **In middle of the range for each joint** | +| **Boundary handling** | Software safeguards to detect ±180 ° wrap-arounds | No wrap-around logic needed due to mid-range zero | --- ### Impact on existing datasets -* Recorded trajectories created **before** PR #777 will replay incorrectly if loaded directly: - * Joint angles are offset and incorrectly normalized. -* Any models directly finetuned or trained on the old data will need their inputs and outputs converted. +- Recorded trajectories created **before** PR #777 will replay incorrectly if loaded directly: + - Joint angles are offset and incorrectly normalized. +- Any models directly finetuned or trained on the old data will need their inputs and outputs converted. ### Using datasets made with the previous calibration system + We provide a migration example script for replaying an episode recorded with the previous calibration here: `examples/backward_compatibility/replay.py`. Below we take you through the modifications that are done in the example script to make the previous calibration datasets work. @@ -33,20 +34,31 @@ Below we take you through the modifications that are done in the example script Let's break this down. New codebase uses `.pos` suffix for the position observations and we have removed `main_` prefix: + + ```python key = f"{name.removeprefix('main_')}.pos" ``` + For `"shoulder_lift"` (id = 2), the 0 position is changed by -90 degrees and the direction is reversed compared to old calibration/code. + + ```python action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) ``` + + For `"elbow_flex"` (id = 3), the 0 position is changed by -90 degrees compared to old calibration/code. + + ```python action["elbow_flex.pos"] -= 90 ``` + To use degrees normalization we then set the `--robot.use_degrees` option to `true`. + ```diff python examples/backward_compatibility/replay.py \ --robot.type=so101_follower \ @@ -63,6 +75,7 @@ Policies output actions in the same format as the datasets (`torch.Tensors`). Th To find these transformations, we recommend to first try and and replay an episode of the dataset your policy was trained on using the section above. Then, add these same transformations on your inference script (shown here in the `record.py` script): + ```diff action_values = predict_action( observation_frame, diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 313d5a7cd..604863d74 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -7,11 +7,13 @@ LeRobot offers multiple options for video capture, including phone cameras, buil To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system. To find the camera indices of the cameras plugged into your system, run the following script: + ```bash python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras ``` The output will look something like this if you have two cameras connected: + ``` --- Detected Cameras --- Camera #0: @@ -31,7 +33,6 @@ Camera #0: > [!WARNING] > When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable. - ## Use Cameras Below are two examples, demonstrating how to work with the API. @@ -39,10 +40,10 @@ Below are two examples, demonstrating how to work with the API. - **Asynchronous frame capture** using an OpenCV-based camera - **Color and depth capture** using an Intel RealSense camera - + ```python from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.camera_opencv import OpenCVCamera @@ -70,10 +71,12 @@ try: finally: camera.disconnect() ``` + + ```python from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig from lerobot.cameras.realsense.camera_realsense import RealSenseCamera @@ -103,15 +106,18 @@ try: finally: camera.disconnect() ``` + + - ## Use your phone + To use your iPhone as a camera on macOS, enable the Continuity Camera feature: + - Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later. - Sign in both devices with the same Apple ID. - Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection. @@ -125,40 +131,67 @@ Your iPhone should be detected automatically when running the camera setup scrip If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera -1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: +1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: + + ```python sudo apt install v4l2loopback-dkms v4l-utils ``` -2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android. -3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): + + +2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android. +3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): + + ```python flatpak install flathub com.obsproject.Studio ``` -4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with: + + +4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with: + + ```python flatpak install flathub com.obsproject.Studio.Plugin.DroidCam ``` -5. *Start OBS Studio*. Launch with: + + +5. _Start OBS Studio_. Launch with: + + ```python flatpak run com.obsproject.Studio ``` -6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. -7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. -8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). -9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices: + + +6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. +7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. +8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). +9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices: + + ```python v4l2-ctl --list-devices ``` + + You should see an entry like: + ``` VirtualCam (platform:v4l2loopback-000): /dev/video1 ``` -10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. + +10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. + + ```python v4l2-ctl -d /dev/video1 --get-fmt-video ``` + + You should see an entry like: + ``` >>> Format Video Capture: >>> Width/Height : 640/480 diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index b3ab40c89..c647a58d5 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -4,18 +4,22 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. -It combines three key ingredients: - 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. - 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. - 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. +It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines.

- HIL-SERL workflow + HIL-SERL workflow

-

HIL-SERL workflow, Luo et al. 2024

+

+ HIL-SERL workflow, Luo et al. 2024 +

This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot. @@ -29,6 +33,7 @@ This guide provides step-by-step instructions for training a robot policy using ## What kind of tasks can I train? One can use HIL-SERL to train on a variety of manipulation tasks. Some recommendations: + - Start with a simple task to understand how the system works. - Push cube to a goal region - Pick and lift cube with the gripper @@ -53,6 +58,7 @@ pip install -e ".[hilserl]" The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as: + ```python class HILSerlRobotEnvConfig(EnvConfig): robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`) @@ -72,7 +78,7 @@ class HILSerlRobotEnvConfig(EnvConfig): reward_classifier_pretrained_path: str | None = None # For reward model number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier ``` - + ### Finding Robot Workspace Bounds @@ -131,6 +137,7 @@ Create a configuration file for recording demonstrations (or edit an existing on 5. Configure `robot`, `cameras`, and other hardware settings Example configuration section: + ```json "mode": "record", "repo_id": "username/pick_lift_cube", @@ -150,6 +157,7 @@ HIL-Serl learns actions in the end-effector space of the robot. Therefore, the t For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space. + ```python class SO100FollowerEndEffectorConfig(SO100FollowerConfig): """Configuration for the SO100FollowerEndEffector robot.""" @@ -172,6 +180,7 @@ class SO100FollowerEndEffectorConfig(SO100FollowerConfig): } ) ``` + The `Teleoperator` defines the teleoperation device. You can check the list of available teleoperators in `lerobot/teleoperators`. @@ -189,9 +198,16 @@ To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and defi ```

- Figure shows the control mappings on a Logitech gamepad. + Figure shows the control mappings on a Logitech gamepad. +

+

+ Gamepad button mapping for robot control and episode management

-

Gamepad button mapping for robot control and episode management

**Setting up the SO101 leader** @@ -215,7 +231,10 @@ During the online training, press `space` to take over the policy and `space` ag
@@ -231,6 +250,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e ``` During recording: + 1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` 2. Complete the task successfully 3. The episode ends with a reward of 1 when you press the "success" button @@ -239,13 +259,13 @@ During recording: 6. The process automatically continues to the next episode 7. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally - ### Processing the Dataset After collecting demonstrations, process them to determine optimal camera crops. Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area. Visual RL algorithms learn directly from pixel inputs, making them vulnerable to irrelevant visual information. Background elements like changing lighting, shadows, people moving, or objects outside the workspace can confuse the learning process. Good ROI selection should: + - Include only the essential workspace where the task happens - Capture the robot's end-effector and all objects involved in the task - Exclude unnecessary background elements and distractions @@ -267,6 +287,7 @@ python -m lerobot.scripts.rl.crop_dataset_roi --repo-id username/pick_lift_cube 5. The script outputs cropping parameters and creates a new cropped dataset Example output: + ``` Selected Rectangular Regions of Interest (top, left, height, width): observation.images.side: [180, 207, 180, 200] @@ -274,11 +295,15 @@ observation.images.front: [180, 250, 120, 150] ```

- +

-

Interactive cropping tool for selecting regions of interest

- +

+ Interactive cropping tool for selecting regions of interest +

**Updating Configuration** @@ -294,8 +319,7 @@ Add these crop parameters to your training configuration: **Recommended image resolution** -Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested. - +Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested. ### Training a Reward Classifier @@ -332,13 +356,13 @@ Example configuration section for data collection: ```json { - "mode": "record", - "repo_id": "hf_username/dataset_name", - "dataset_root": "data/your_dataset", - "num_episodes": 20, - "push_to_hub": true, - "fps": 10, - "number_of_steps_after_success": 15 + "mode": "record", + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "num_episodes": 20, + "push_to_hub": true, + "fps": 10, + "number_of_steps_after_success": 15 } ``` @@ -395,21 +419,25 @@ python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_co To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model: + ```python env_config = HILSerlRobotEnvConfig( reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", # Other environment parameters ) ``` + + or set the argument in the json config file. ```json { - "reward_classifier_pretrained_path": "path_to_your_pretrained_model" + "reward_classifier_pretrained_path": "path_to_your_pretrained_model" } ``` Run `gym_manipulator.py` to test the model. + ```bash python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config.json ``` @@ -422,11 +450,13 @@ The reward classifier will automatically provide rewards based on the visual inp Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). 2. **Collect a dataset**: + ```bash python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json ``` 3. **Train the classifier**: + ```bash python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json ``` @@ -459,6 +489,7 @@ python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_con ``` The learner: + - Initializes the policy network - Prepares replay buffers - Opens a `gRPC` server to communicate with actors @@ -473,6 +504,7 @@ python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_confi ``` The actor: + - Connects to the learner via `gRPC` - Initializes the environment - Execute rollouts of the policy to collect experience @@ -496,10 +528,19 @@ The training proceeds automatically: - A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard.

- Figure shows the control mappings on a Logitech gamepad. + Figure shows the control mappings on a Logitech gamepad.

-

Example showing how human interventions help guide policy learning over time

+

+ + Example showing how human interventions help guide policy learning over time + +

- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning. - The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions. @@ -510,7 +551,9 @@ The training proceeds automatically: If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard. ### Guide to Human Interventions + The learning process is very sensitive to the intervention strategy. It will takes a few runs to understand how to intervene effectively. Some tips and hints: + - Allow the policy to explore for a few episodes at the start of training. - Avoid intervening for long periods of time. Try to intervene in situation to correct the robot's behaviour when it goes off track. - Once the policy starts achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a simple grasping commands. @@ -518,26 +561,36 @@ The learning process is very sensitive to the intervention strategy. It will tak The ideal behaviour is that your intervention rate should drop gradually during training as shown in the figure below.

- Intervention rate + Intervention rate

-

Plot of the intervention rate during a training run on a pick and lift cube task

+

+ + Plot of the intervention rate during a training run on a pick and lift cube + task + +

### Key hyperparameters to tune Some configuration values have a disproportionate impact on training stability and speed: - **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. -- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in *seconds* between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. +- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in _seconds_ between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. - **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. - Congrats 🎉, you have finished this tutorial! > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). Paper citation: + ``` @article{luo2024precise, title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx index ad7a9584a..c739be835 100644 --- a/docs/source/hilserl_sim.mdx +++ b/docs/source/hilserl_sim.mdx @@ -11,7 +11,6 @@ This guide explains how to use the `gym_hil` simulation environments as an alter Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube. - ## Installation First, install the `gym_hil` package within the LeRobot environment: @@ -25,8 +24,6 @@ pip install -e ".[hilserl]" - A gamepad or keyboard to control the robot - A Nvidia GPU - - ## Configuration To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: @@ -35,14 +32,15 @@ To use `gym_hil` with LeRobot, you need to create a configuration file. An examp ```json { - "type": "hil", - "name": "franka_sim", - "task": "PandaPickCubeGamepad-v0", - "device": "cuda" + "type": "hil", + "name": "franka_sim", + "task": "PandaPickCubeGamepad-v0", + "device": "cuda" } ``` Available tasks: + - `PandaPickCubeBase-v0`: Basic environment - `PandaPickCubeGamepad-v0`: With gamepad control - `PandaPickCubeKeyboard-v0`: With keyboard control @@ -65,6 +63,7 @@ Available tasks: ``` Important parameters: + - `gripper_penalty`: Penalty for excessive gripper movement - `use_gripper`: Whether to enable gripper control - `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector @@ -76,40 +75,49 @@ Important parameters: To run the environment, set mode to null: + ```python python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` + ### Recording a Dataset To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: + ```python python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` + ### Training a Policy To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: + ```python python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json ``` + In a different terminal, run the learner server: + ```python python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json ``` + The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. Congrats 🎉, you have finished this tutorial! > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). Paper citation: + ``` @article{luo2024precise, title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 2e8ac3619..b18adb8f4 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -3,6 +3,7 @@ This tutorial will explain how to train a neural network to control a real robot autonomously. **You'll learn:** + 1. How to record and visualize your dataset. 2. How to train a policy using your data and prepare it for evaluation. 3. How to evaluate your policy and visualize the results. @@ -14,7 +15,10 @@ By following these steps, you'll be able to replicate tasks, such as picking up
@@ -51,6 +55,8 @@ python -m lerobot.teleoperate \ ```
+ + ```python from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower @@ -74,10 +80,13 @@ while True: action = teleop_device.get_action() robot.send_action(action) ``` + +
The teleoperate command will automatically: + 1. Identify any missing calibrations and initiate the calibration procedure. 2. Connect the robot and teleop device and start teleoperation. @@ -104,6 +113,8 @@ python -m lerobot.teleoperate \ ``` + + ```python from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader @@ -134,6 +145,8 @@ while True: action = teleop_device.get_action() robot.send_action(action) ``` + + @@ -144,11 +157,13 @@ Once you're familiar with teleoperation, you can record your first dataset. We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). Add your token to the CLI by running this command: + ```bash huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then store your Hugging Face repository name in a variable: + ```bash HF_USER=$(huggingface-cli whoami | head -n 1) echo $HF_USER @@ -174,6 +189,8 @@ python -m lerobot.record \ ``` + + ```python from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -270,40 +287,49 @@ robot.disconnect() teleop.disconnect() dataset.push_to_hub() ``` + + #### Dataset upload + Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: + ```bash echo https://huggingface.co/datasets/${HF_USER}/so101_test ``` + Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). You can also push your local dataset to the Hub manually, running: + ```bash huggingface-cli upload ${HF_USER}/record-test ~/.cache/huggingface/lerobot/{repo-id} --repo-type dataset ``` - #### Record function The `record` function provides a suite of tools for capturing and managing data during robot operation: ##### 1. Data Storage + - Data is stored using the `LeRobotDataset` format and is stored on disk during recording. - By default, the dataset is pushed to your Hugging Face page after recording. - To disable uploading, use `--dataset.push_to_hub=False`. ##### 2. Checkpointing and Resuming + - Checkpoints are automatically created during recording. - If an issue occurs, you can resume by re-running the same command with `--resume=true`. - To start recording from scratch, **manually delete** the dataset directory. ##### 3. Recording Parameters + Set the flow of data recording using command-line arguments: + - `--dataset.episode_time_s=60` Duration of each data recording episode (default: **60 seconds**). - `--dataset.reset_time_s=60` @@ -312,7 +338,9 @@ Set the flow of data recording using command-line arguments: Total number of episodes to record (default: **50**). ##### 4. Keyboard Controls During Recording + Control the data recording flow using keyboard shortcuts: + - Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next. - Press **Left Arrow (`←`)**: Cancel the current episode and re-record it. - Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset. @@ -327,13 +355,14 @@ Avoid adding too much variation too quickly, as it may hinder your results. If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. - #### Troubleshooting: + - On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). ## Visualize a dataset If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: + ```bash echo ${HF_USER}/so101_test ``` @@ -356,6 +385,8 @@ python -m lerobot.replay \ ``` + + ```python import time @@ -388,6 +419,8 @@ for idx in range(dataset.num_frames): robot.disconnect() ``` + + @@ -396,6 +429,7 @@ Your robot should replicate movements similar to those you recorded. For example ## Train a policy To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: + ```bash python -m lerobot.scripts.train \ --dataset.repo_id=${HF_USER}/so101_test \ @@ -408,14 +442,16 @@ python -m lerobot.scripts.train \ ``` Let's explain the command: + 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. +3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. Training should take several hours. You will find checkpoints in `outputs/train/act_so101_test/checkpoints`. To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy: + ```bash python -m lerobot.scripts.train \ --config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \ @@ -427,17 +463,20 @@ If you do not want to push your model to the hub after training use `--policy.pu Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit` #### Train using Collab + If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). #### Upload policy checkpoints Once training is done, upload the latest checkpoint with: + ```bash huggingface-cli upload ${HF_USER}/act_so101_test \ outputs/train/act_so101_test/checkpoints/last/pretrained_model ``` You can also upload intermediate checkpoints with: + ```bash CKPT=010000 huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \ @@ -467,6 +506,8 @@ python -m lerobot.record \ ``` + + ```python from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -539,9 +580,12 @@ for episode_idx in range(NUM_EPISODES): robot.disconnect() dataset.push_to_hub() ``` + + As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`). + +1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`). 2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`). diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 048d3147e..193b09b1b 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -3,6 +3,7 @@ This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning. **You'll learn:** + 1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset. 2. How to train a policy using your data. 3. How to evaluate your policy in simulation and visualize the results. @@ -55,13 +56,21 @@ Note that to teleoperate the robot you have to hold the "Human Take Over Pause P **Gamepad Controls**

- Figure shows the control mappings on a Logitech gamepad. + Figure shows the control mappings on a Logitech gamepad. +

+

+ Gamepad button mapping for robot control and episode management

-

Gamepad button mapping for robot control and episode management

**Keyboard controls** For keyboard controls use the `spacebar` to enable control and the following keys to move the robot: + ```bash Arrow keys: Move in X-Y plane Shift and Shift_R: Move in Z axis @@ -74,14 +83,21 @@ For keyboard controls use the `spacebar` to enable control and the following key If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id.

- Figure shows the dataset visualizer + Figure shows the dataset visualizer +

+

+ Dataset visualizer

-

Dataset visualizer

- ## Train a policy To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: + ```bash python -m lerobot.scripts.train \ --dataset.repo_id=${HF_USER}/il_gym \ @@ -93,25 +109,29 @@ python -m lerobot.scripts.train \ ``` Let's explain the command: + 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. +3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`. #### Train using Collab + If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). #### Upload policy checkpoints Once training is done, upload the latest checkpoint with: + ```bash huggingface-cli upload ${HF_USER}/il_sim_test \ outputs/train/il_sim_test/checkpoints/last/pretrained_model ``` You can also upload intermediate checkpoints with: + ```bash CKPT=010000 huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \ @@ -144,9 +164,9 @@ mjpython -m lerobot.scripts.rl.eval_policy --config_path=path/to/eval_config_gym > [!WARNING] -> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation. +> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation. Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim) > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/index.mdx b/docs/source/index.mdx index b8ff56ea7..a2f919e7d 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -1,6 +1,10 @@ diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 51474d8f7..13c3600b4 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -5,45 +5,56 @@ Currently only available from source. Download our source code: + ```bash git clone https://github.com/huggingface/lerobot.git cd lerobot ``` Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) + ```bash conda create -y -n lerobot python=3.10 ``` Then activate your conda environment, you have to do this each time you open a shell to use lerobot: + ```bash conda activate lerobot ``` When using `miniconda`, install `ffmpeg` in your environment: + ```bash conda install ffmpeg -c conda-forge ``` > [!TIP] > This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can: -> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: -> ```bash -> conda install ffmpeg=7.1.1 -c conda-forge -> ``` -> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. +> +> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: +> +> ```bash +> conda install ffmpeg=7.1.1 -c conda-forge +> ``` +> +> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. Install 🤗 LeRobot: + ```bash pip install -e . ``` ### Troubleshooting + If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. To install these for linux run: + ```bash sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config ``` + For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) ## Optional dependencies @@ -51,20 +62,26 @@ For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/ LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. ### Simulations + Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) Example: + ```bash pip install -e ".[aloha]" # or "[pusht]" for example ``` ### Motor Control + For Koch v1.1 install the Dynamixel SDK, for SO100/SO101/Moss install the Feetech SDK. + ```bash pip install -e ".[feetech]" # or "[dynamixel]" for example ``` ### Experiment Tracking + To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with + ```bash wandb login ``` diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 18d73d3cd..089126fcb 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -21,16 +21,13 @@ Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/ma For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/so101_follower/so101_follower.py) Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): + - Find an existing SDK in Python (or use bindings to C/C++) - Or implement a basic communication wrapper (e.g., via pyserial, socket, or CANopen) You're not alone—many community contributions use custom boards or firmware! -For Feetech and Dynamixel, we currently support these servos: - - Feetech: - - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` - - SCS series (protocol 1): `scs0009` - - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150` +For Feetech and Dynamixel, we currently support these servos: - Feetech: - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` - SCS series (protocol 1): `scs0009` - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150` If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do. @@ -41,6 +38,8 @@ In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for You’ll first need to specify the config class and a string identifier (`name`) for your robot. If your robot has special needs that you'd like to be able to change easily, it should go here (e.g. port/address, baudrate). Here, we'll add the port name and one camera by default for our robot: + + ```python from dataclasses import dataclass, field @@ -64,6 +63,7 @@ class MyCoolRobotConfig(RobotConfig): } ) ``` + Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera. @@ -71,6 +71,7 @@ Next, we'll create our actual robot class which inherits from `Robot`. This abst Here we'll create a simple 5-DoF robot with one camera. It could be a simple arm but notice that the `Robot` abstract class does not assume anything on your robot's form factor. You can let you imagination run wild when designing new robots! + ```python from lerobot.cameras import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode @@ -96,10 +97,11 @@ class MyCoolRobot(Robot): ) self.cameras = make_cameras_from_configs(config.cameras) ``` + ## Step 2: Define Observation and Action Features -These two properties define the *interface contract* between your robot and tools that consume it (such as data collection or learning pipelines). +These two properties define the _interface contract_ between your robot and tools that consume it (such as data collection or learning pipelines). > [!WARNING] > Note that these properties must be callable even if the robot is not yet connected, so avoid relying on runtime hardware state to define them. @@ -109,6 +111,8 @@ These two properties define the *interface contract* between your robot and tool This property should return a dictionary describing the structure of sensor outputs from your robot. The keys match what `get_observation()` returns, and the values describe either the shape (for arrays/images) or the type (for simple values). Example for our 5-DoF arm with one camera: + + ```python @property def _motors_ft(self) -> dict[str, type]: @@ -130,6 +134,8 @@ def _cameras_ft(self) -> dict[str, tuple]: def observation_features(self) -> dict: return {**self._motors_ft, **self._cameras_ft} ``` + + In this case, observations consist of a simple dict storing each motor's position and a camera image. ### `action_features` @@ -137,10 +143,13 @@ In this case, observations consist of a simple dict storing each motor's positio This property describes the commands your robot expects via `send_action()`. Again, keys must match the expected input format, and values define the shape/type of each command. Here, we simply use the same joints proprioceptive features (`self._motors_ft`) as with `observation_features`: the action sent will simply the goal position for each motor. + + ```python def action_features(self) -> dict: return self._motors_ft ``` + ## Step 3: Handle Connection and Disconnection @@ -150,16 +159,19 @@ These methods should handle opening and closing communication with your hardware This property should simply reflect that communication with the robot's hardware is established. When this property is `True`, it should be possible to read and write to the hardware using `get_observation()` and `send_action()`. + ```python @property def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) ``` + ### `connect()` This method should establish communication with the hardware. Moreover, if your robot needs calibration and is not calibrated, it should start a calibration procedure by default. If your robot needs some specific configuration, this should also be called here. + ```python def connect(self, calibrate: bool = True) -> None: self.bus.connect() @@ -171,25 +183,31 @@ def connect(self, calibrate: bool = True) -> None: self.configure() ``` + ### `disconnect()` This method should gracefully terminate communication with the hardware: free any related resources (threads or processes), close ports, etc. Here, we already handle this in our `MotorsBus` and `Camera` classes so we just need to call their own `disconnect()` methods: + + ```python def disconnect(self) -> None: self.bus.disconnect() for cam in self.cameras.values(): cam.disconnect() ``` + ## Step 4: Support Calibration and Configuration LeRobot supports saving and loading calibration data automatically. This is useful for joint offsets, zero positions, or sensor alignment. > Note that depending on your hardware, this may not apply. If that's the case, you can simply leave these methods as no-ops: -> ```python + + +```python > @property > def is_calibrated(self) -> bool: > return True @@ -202,7 +220,8 @@ LeRobot supports saving and loading calibration data automatically. This is usef This should reflect whether your robot has the required calibration loaded. -```python +``` +python @property def is_calibrated(self) -> bool: return self.bus.is_calibrated @@ -216,6 +235,8 @@ The goal of the calibration is twofold: It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this. + + ```python def calibrate(self) -> None: self.bus.disable_torque() @@ -245,11 +266,13 @@ def calibrate(self) -> None: self._save_calibration() print("Calibration saved to", self.calibration_fpath) ``` + ### `configure()` Use this to set up any configuration for your hardware (servos control modes, controller gains, etc.). This should usually be run at connection time and be idempotent. + ```python def configure(self) -> None: with self.bus.torque_disabled(): @@ -260,6 +283,7 @@ def configure(self) -> None: self.bus.write("I_Coefficient", motor, 0) self.bus.write("D_Coefficient", motor, 32) ``` + ## Step 5: Implement Sensors Reading and Action Sending @@ -269,6 +293,7 @@ These are the most important runtime functions: the core I/O loop. Returns a dictionary of sensor values from the robot. These typically include motor states, camera frames, various sensors, etc. In the LeRobot framework, these observations are what will be fed to a policy in order to predict the actions to take. The dictionary keys and structure must match `observation_features`. + ```python def get_observation(self) -> dict[str, Any]: if not self.is_connected: @@ -284,6 +309,7 @@ def get_observation(self) -> dict[str, Any]: return obs_dict ``` + ### `send_action()` @@ -291,6 +317,7 @@ Takes a dictionary that matches `action_features`, and sends it to your hardware For simplicity, we won't be adding any modification of the actions in our example here. + ```python def send_action(self, action: dict[str, Any]) -> dict[str, Any]: goal_pos = {key.removesuffix(".pos"): val for key, val in action.items()} @@ -300,6 +327,7 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]: return action ``` + ## Adding a Teleoperator diff --git a/docs/source/notebooks.mdx b/docs/source/notebooks.mdx index 729b31a99..6a9c3b103 100644 --- a/docs/source/notebooks.mdx +++ b/docs/source/notebooks.mdx @@ -10,8 +10,8 @@ This repository contains example notebooks for using LeRobot. These notebooks de We provide a ready-to-run Google Colab notebook to help you train ACT policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases. -| Notebook | Colab | -|:---------|:------| +| Notebook | Colab | +| :------------------------------------------------------------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [Train ACT with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | Expected training time for 100k steps: ~1.5 hours on an NVIDIA A100 GPU with batch size of `64`. diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index 17a2bdf18..880beaa1a 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -3,9 +3,18 @@ SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!

- SmolVLA architecture. -
- Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the robot’s current sensorimotor state, and (iii) a natural language instruction, encoded into contextual features used to condition the action expert when generating an action chunk. + SmolVLA architecture. +
+ + Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the + robot’s current sensorimotor state, and (iii) a natural language + instruction, encoded into contextual features used to condition the action + expert when generating an action chunk. +

## Set Up Your Environment @@ -32,6 +41,7 @@ We recommend checking out the dataset linked below for reference that was used i In this dataset, we recorded 50 episodes across 5 distinct cube positions. For each position, we collected 10 episodes of pick-and-place interactions. This structure, repeating each variation several times, helped the model generalize better. We tried similar dataset with 25 episodes, and it was not enough leading to a bad performance. So, the data quality and quantity is definitely a key. After you have your dataset available on the Hub, you are good to go to use our finetuning script to adapt SmolVLA to your application. + ## Finetune SmolVLA on your data @@ -56,7 +66,8 @@ cd lerobot && python -m lerobot.scripts.train \ ``` -You can start with a small batch size and increase it incrementally, if the GPU allows it, as long as loading times remain short. + You can start with a small batch size and increase it incrementally, if the + GPU allows it, as long as loading times remain short. Fine-tuning is an art. For a complete overview of the options for finetuning, run @@ -66,12 +77,20 @@ python -m lerobot.scripts.train --help ```

- Comparison of SmolVLA across task variations. -
- Figure 2: Comparison of SmolVLA across task variations. From left to right: (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place cube counting under perturbations, and (4) generalization on pick-and-place of the lego block with real-world SO101. + Comparison of SmolVLA across task variations. +
+ + Figure 2: Comparison of SmolVLA across task variations. From left to right: + (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place + cube counting under perturbations, and (4) generalization on pick-and-place + of the lego block with real-world SO101. +

- ## Evaluate the finetuned model and run it in real-time Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset). diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index f17411b75..d6cd6cc23 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -1,6 +1,6 @@ This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run. -> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu. +> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu. ## The training script @@ -15,17 +15,22 @@ LeRobot offers a training script at [`lerobot/scripts/train.py`](../src/lerobot/ ## Overview of the configuration system In the training script, the main function `train` expects a `TrainPipelineConfig` object: + + ```python # train.py @parser.wrap() def train(cfg: TrainPipelineConfig): ``` + You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../src/lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes: + + ```python @dataclass class TrainPipelineConfig: @@ -33,7 +38,11 @@ class TrainPipelineConfig: env: envs.EnvConfig | None = None policy: PreTrainedConfig | None = None ``` + + in which `DatasetConfig` for example is defined as such: + + ```python @dataclass class DatasetConfig: @@ -41,16 +50,17 @@ class DatasetConfig: episodes: list[int] | None = None video_backend: str = "pyav" ``` + This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`. From the command line, we can specify this value by using a very similar syntax `--dataset.repo_id=repo/id`. By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified. - ## Specifying values from the CLI Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: + ```bash python -m lerobot.scripts.train \ --dataset.repo_id=lerobot/pusht \ @@ -59,11 +69,13 @@ python -m lerobot.scripts.train \ ``` Let's break this down: + - To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`. - To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/policies](../src/lerobot/policies) - Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/envs/configs.py`](../src/lerobot/envs/configs.py) Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: + ```bash python -m lerobot.scripts.train \ --policy.type=act \ @@ -71,10 +83,12 @@ python -m lerobot.scripts.train \ --env.type=aloha \ --output_dir=outputs/train/act_aloha_insertion ``` + > Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`. We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task. Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: + ```bash python -m lerobot.scripts.train \ --policy.type=act \ @@ -87,6 +101,7 @@ python -m lerobot.scripts.train \ ## Loading from a config file Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used: + ```json { "dataset": { @@ -110,34 +125,40 @@ Now, let's assume that we want to reproduce the run just above. That run has pro ``` We can then simply load the config values from this file using: + ```bash python -m lerobot.scripts.train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 ``` + `--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly. Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.: + ```bash python -m lerobot.scripts.train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 --policy.n_action_steps=80 ``` + > Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next. `--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running: + ```bash python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht ``` -will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) +will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) ## Resume training Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to do that here. Let's reuse the command from the previous run and add a few more options: + ```bash python -m lerobot.scripts.train \ --policy.type=act \ @@ -150,19 +171,24 @@ python -m lerobot.scripts.train \ ``` Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal: + ``` INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100 ``` + Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with: + ```bash python -m lerobot.scripts.train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true ``` + You should see from the logging that your training picks up from where it left off. Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default. You could double the number of steps of the previous run with: + ```bash python -m lerobot.scripts.train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ @@ -171,7 +197,9 @@ python -m lerobot.scripts.train \ ``` ## Outputs of a run + In the output directory, there will be a folder called `checkpoints` with the following structure: + ```bash outputs/train/run_resumption/checkpoints ├── 000100 # checkpoint_dir for training step 100 @@ -194,6 +222,7 @@ outputs/train/run_resumption/checkpoints In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub. For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with: + ```bash python -m lerobot.scripts.train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ @@ -209,15 +238,19 @@ When doing so, keep in mind that the features of the fine-tuning dataset would h When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint. After that, you will see training log like this one: + ``` INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774 ``` + or evaluation log: + ``` INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266 ``` These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations: + - `smpl`: number of samples seen during training. - `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task. - `epch`: number of time all unique samples are seen (epoch). @@ -235,6 +268,7 @@ Some metrics are useful for initial performance profiling. For example, if you f We'll summarize here the main use cases to remember from this tutorial. #### Train a policy from scratch – CLI + ```bash python -m lerobot.scripts.train \ --policy.type=act \ # <- select 'act' policy @@ -243,6 +277,7 @@ python -m lerobot.scripts.train \ ``` #### Train a policy from scratch - config file + CLI + ```bash python -m lerobot.scripts.train \ --config_path=path/to/pretrained_model \ # <- can also be a repo_id @@ -250,6 +285,7 @@ python -m lerobot.scripts.train \ ``` #### Resume/continue a training run + ```bash python -m lerobot.scripts.train \ --config_path=checkpoint/pretrained_model/ \ @@ -258,6 +294,7 @@ python -m lerobot.scripts.train \ ``` #### Fine-tuning + ```bash python -m lerobot.scripts.train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint diff --git a/pyproject.toml b/pyproject.toml index 878a36dbe..e9539037b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + [project.urls] -homepage = "https://github.com/huggingface/lerobot" +homepage = "https://huggingface.co/lerobot" +documentation = "https://huggingface.co/docs/lerobot/index" +source = "https://github.com/huggingface/lerobot" issues = "https://github.com/huggingface/lerobot/issues" discord = "https://discord.gg/s3KuuzsPFb" @@ -21,109 +27,165 @@ discord = "https://discord.gg/s3KuuzsPFb" name = "lerobot" version = "0.1.0" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" +readme = "README.md" +license = { text = "Apache-2.0" } +requires-python = ">=3.10" authors = [ { name = "Rémi Cadène", email = "re.cadene@gmail.com" }, { name = "Simon Alibert", email = "alibert.sim@gmail.com" }, { name = "Alexander Soare", email = "alexander.soare159@gmail.com" }, { name = "Quentin Gallouédec", email = "quentin.gallouedec@ec-lyon.fr" }, - { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" }, - { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" }, { name = "Steven Palma", email = "imstevenpmwork@ieee.org" }, + { name = "Pepijn Kooijmans", email = "pepijnkooijmans@outlook.com"}, + { name = "Michel Aractingi", email = "michel.aractingi@gmail.com"}, + { name = "Adil Zouitine", email = "adilzouitinegm@gmail.com" }, + { name = "Dana Aubakirova", email = "danaaubakirova17@gmail.com"}, + { name = "Caroline Pascal", email = "caroline8.pascal@gmail.com"}, + { name = "Martino Russi", email = "nopyeps@gmail.com"}, + { name = "Thomas Wolf", email = "thomaswolfcontact@gmail.com" }, ] -readme = "README.md" -license = { text = "Apache-2.0" } -requires-python = ">=3.10" -keywords = ["robotics", "deep learning", "pytorch"] classifiers = [ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", - "Topic :: Software Development :: Build Tools", - "Topic :: Scientific/Engineering :: Artificial Intelligence", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.10", + "Topic :: Software Development :: Build Tools", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ] +keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artificial intelligence"] + dependencies = [ - "cmake>=3.29.0.1", - "datasets>=2.19.0,<=3.6.0", - "deepdiff>=7.0.1", + + # Hugging Face dependencies + "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency "diffusers>=0.27.2", - "draccus==0.10.0", + "huggingface-hub[hf-transfer,cli]>=0.27.1", + + # Core dependencies + "cmake>=3.29.0.1", "einops>=0.8.0", - "flask>=3.0.3", - "gdown>=5.1.0", - "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work - "h5py>=3.10.0", - "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'", - "imageio[ffmpeg]>=2.34.0", - "jsonlines>=4.0.0", - "numba>=0.59.0", - "omegaconf>=2.3.0", "opencv-python-headless>=4.9.0", - "packaging>=24.2", "av>=14.2.0", - "pymunk>=6.6.0,<7.0.0", - "pynput>=1.7.7", - "pyserial>=3.5", - "pyzmq>=26.2.1", - "rerun-sdk>=0.21.0", - "termcolor>=2.4.0", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", "torchvision>=0.21.0", + "jsonlines>=4.0.0", + "packaging>=24.2", + "pynput>=1.7.7", + "pyserial>=3.5", "wandb>=0.16.3", - "zarr>=2.17.0", + + "draccus==0.10.0", # TODO: Remove == + "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency + "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency + + # Support dependencies + "deepdiff>=7.0.1,<9.0.0", + "flask>=3.0.3,<4.0.0", + "imageio[ffmpeg]>=2.34.0,<3.0.0", + "termcolor>=2.4.0,<4.0.0", ] +# Optional dependencies [project.optional-dependencies] -aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] -docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] -dora = [ - "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'", -] -dynamixel = ["dynamixel-sdk>=3.7.31"] + +# Common +pygame-dep = ["pygame>=2.5.1"] +placo-dep = ["placo>=0.9.6"] +transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency +grpcio-dep = ["grpcio==1.71.0"] + +# Motors feetech = ["feetech-servo-sdk>=1.0.0"] -gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"] -hopejr = ["feetech-servo-sdk>=1.0.0", "pygame>=2.5.1"] -kinematics = ["placo>=0.9.6"] +dynamixel = ["dynamixel-sdk>=3.7.31"] + +# Robots +gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"] +hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] +lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"] +kinematics = ["lerobot[placo-dep]"] intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] -pi0 = ["transformers>=4.50.3"] -smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] -pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ - "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", + "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" -] -test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] -hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.9", "protobuf>=5.29.3", "grpcio==1.71.0", "placo>=0.9.6"] -umi = ["imagecodecs>=2024.1.1"] -video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] -xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] -async = ["grpcio==1.71.0", "matplotlib>=3.10.3"] +] # TODO: Currently not supported -[tool.poetry] -requires-poetry = ">=2.1" -packages = [ - { include = "lerobot", from = "src" } -] +# Policies +pi0 = ["lerobot[transformers-dep]"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] +hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] + +# Features +async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] + +# Development +docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] +test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] + +# Simulation +aloha = ["gym-aloha>=0.1.1"] +pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead +xarm = ["gym-xarm>=0.1.1"] + +# ---------------- Tool Configurations ---------------- +[tool.setuptools.packages.find] +where = ["src"] [tool.ruff] -line-length = 110 target-version = "py310" +line-length = 110 exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] +# E, W: pycodestyle errors and warnings +# F: PyFlakes +# I: isort +# UP: pyupgrade +# B: flake8-bugbear (good practices, potential bugs) +# C4: flake8-comprehensions (more concise comprehensions) +# A: flake8-builtins (shadowing builtins) +# SIM: flake8-simplify +# RUF: Ruff-specific rules +# D: pydocstyle (for docstring style/formatting) +# S: flake8-bandit (some security checks, complements Bandit) +# T20: flake8-print (discourage print statements in production code) +# N: pep8-naming +# TODO: Uncomment rules when ready to use +select = [ + "E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP" +] +ignore = [ + "E501", # Line too long + "T201", # Print statement found + "T203", # Pprint statement found + "B008", # Perform function call in argument defaults +] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401", "F403"] +[tool.ruff.lint.isort] +combine-as-imports = true +known-first-party = ["lerobot"] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "auto" +docstring-code-format = true + [tool.bandit] exclude_dirs = [ "tests", @@ -148,6 +210,24 @@ default.extend-ignore-identifiers-re = [ "ein", ] -[build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +# TODO: Uncomment when ready to use +# [tool.interrogate] +# ignore-init-module = true +# ignore-init-method = true +# ignore-nested-functions = false +# ignore-magic = false +# ignore-semiprivate = false +# ignore-private = false +# ignore-property-decorators = false +# ignore-module = false +# ignore-setters = false +# fail-under = 80 +# output-format = "term-missing" +# color = true +# paths = ["src/lerobot"] + +# [tool.mypy] +# python_version = "3.10" +# warn_return_any = true +# warn_unused_configs = true +# ignore_missing_imports = false diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index 1937205b1..e435c7309 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -15,7 +15,7 @@ # limitations under the License. import abc -from typing import Any, Dict, List +from typing import Any import numpy as np @@ -69,7 +69,7 @@ class Camera(abc.ABC): @staticmethod @abc.abstractmethod - def find_cameras() -> List[Dict[str, Any]]: + def find_cameras() -> list[dict[str, Any]]: """Detects available cameras connected to the system. Returns: List[Dict[str, Any]]: A list of dictionaries, diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 1d7a1645d..7ad9988cc 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -23,7 +23,7 @@ import platform import time from pathlib import Path from threading import Event, Lock, Thread -from typing import Any, Dict, List +from typing import Any # Fix MSMF hardware transform compatibility for Windows before importing cv2 if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ: @@ -245,7 +245,7 @@ class OpenCVCamera(Camera): ) @staticmethod - def find_cameras() -> List[Dict[str, Any]]: + def find_cameras() -> list[dict[str, Any]]: """ Detects available OpenCV cameras connected to the system. diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 96531b694..74b055fa4 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -19,7 +19,7 @@ Provides the RealSenseCamera class for capturing frames from Intel RealSense cam import logging import time from threading import Event, Lock, Thread -from typing import Any, Dict, List +from typing import Any import cv2 import numpy as np @@ -194,7 +194,7 @@ class RealSenseCamera(Camera): logger.info(f"{self} connected.") @staticmethod - def find_cameras() -> List[Dict[str, Any]]: + def find_cameras() -> list[dict[str, Any]]: """ Detects available Intel RealSense cameras connected to the system. diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py index 82e7c0d36..36a86876d 100644 --- a/src/lerobot/cameras/realsense/configuration_realsense.py +++ b/src/lerobot/cameras/realsense/configuration_realsense.py @@ -28,12 +28,12 @@ class RealSenseCameraConfig(CameraConfig): Example configurations for Intel RealSense D405: ```python # Basic configurations - RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS - RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS + RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS + RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS # Advanced configurations RealSenseCameraConfig("0123456789", 30, 640, 480, use_depth=True) # With depth sensing - RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation + RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation ``` Attributes: diff --git a/src/lerobot/configs/parser.py b/src/lerobot/configs/parser.py index 1da7ad83f..2296eaa20 100644 --- a/src/lerobot/configs/parser.py +++ b/src/lerobot/configs/parser.py @@ -16,9 +16,9 @@ import inspect import pkgutil import sys from argparse import ArgumentError +from collections.abc import Sequence from functools import wraps from pathlib import Path -from typing import Sequence import draccus @@ -76,9 +76,8 @@ def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict: - Values are the corresponding argument values Example: - >>> args = ['--env.discover_packages_path=my_package', - ... '--other_arg=value'] - >>> parse_plugin_args('discover_packages_path', args) + >>> args = ["--env.discover_packages_path=my_package", "--other_arg=value"] + >>> parse_plugin_args("discover_packages_path", args) {'env.discover_packages_path': 'my_package'} """ plugin_args = {} @@ -111,7 +110,7 @@ def load_plugin(plugin_path: str) -> None: PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid. Examples: - >>> load_plugin("external_plugin.core") # Loads plugin from external package + >>> load_plugin("external_plugin.core") # Loads plugin from external package Notes: - The plugin package should handle its own registration during import diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 05f3296b8..c5b2fa09e 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import builtins import json import logging import os import tempfile from dataclasses import dataclass, field from pathlib import Path -from typing import Type, TypeVar +from typing import TypeVar import draccus from huggingface_hub import hf_hub_download @@ -31,7 +32,6 @@ from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.hub import HubMixin from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available -# Generic variable that is either PreTrainedConfig or a subclass thereof T = TypeVar("T", bound="PreTrainedConfig") @@ -148,7 +148,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @classmethod def from_pretrained( - cls: Type[T], + cls: builtins.type[T], pretrained_name_or_path: str | Path, *, force_download: bool = False, diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index c088a5fa1..60a4d81d5 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import builtins import datetime as dt import os from dataclasses import dataclass, field from pathlib import Path -from typing import Type import draccus from huggingface_hub import hf_hub_download @@ -135,7 +135,7 @@ class TrainPipelineConfig(HubMixin): @classmethod def from_pretrained( - cls: Type["TrainPipelineConfig"], + cls: builtins.type["TrainPipelineConfig"], pretrained_name_or_path: str | Path, *, force_download: bool = False, diff --git a/src/lerobot/datasets/card_template.md b/src/lerobot/datasets/card_template.md index 7ee27df95..ee26a78f5 100644 --- a/src/lerobot/datasets/card_template.md +++ b/src/lerobot/datasets/card_template.md @@ -1,7 +1,8 @@ --- # For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/datasets-cards -{{ card_data }} +# prettier-ignore +{{card_data}} --- This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 1a3dd1e1b..46feed2bf 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -16,8 +16,8 @@ import contextlib import logging import shutil +from collections.abc import Callable from pathlib import Path -from typing import Callable import datasets import numpy as np diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py index 6aca7b03b..5f6363a77 100644 --- a/src/lerobot/datasets/push_dataset_to_hub/utils.py +++ b/src/lerobot/datasets/push_dataset_to_hub/utils.py @@ -16,7 +16,6 @@ import inspect from concurrent.futures import ThreadPoolExecutor from pathlib import Path -from typing import Dict import datasets import numpy @@ -77,7 +76,7 @@ def check_repo_id(repo_id: str) -> None: # TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: +def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: """ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 2f6c15c15..79ac7a4b2 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterator, Union +from collections.abc import Iterator import torch @@ -22,7 +22,7 @@ class EpisodeAwareSampler: def __init__( self, episode_data_index: dict, - episode_indices_to_use: Union[list, None] = None, + episode_indices_to_use: list | None = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index 3ac1d5771..f992275b7 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -14,13 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import Any, Callable, Sequence +from typing import Any import torch from torchvision.transforms import v2 -from torchvision.transforms.v2 import Transform -from torchvision.transforms.v2 import functional as F # noqa: N812 +from torchvision.transforms.v2 import ( + Transform, + functional as F, # noqa: N812 +) class RandomSubsetApply(Transform): diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index de969d618..ef381e9e7 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -14,7 +14,7 @@ import abc from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any import draccus @@ -179,10 +179,10 @@ class EnvTransformConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False add_ee_pose_to_observation: bool = False - crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None - resize_size: Optional[tuple[int, int]] = None + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None control_time_s: float = 20.0 - fixed_reset_joint_positions: Optional[Any] = None + fixed_reset_joint_positions: Any | None = None reset_time_s: float = 5.0 use_gripper: bool = True gripper_quantization_threshold: float | None = 0.8 @@ -195,21 +195,21 @@ class EnvTransformConfig: class HILSerlRobotEnvConfig(EnvConfig): """Configuration for the HILSerlRobotEnv environment.""" - robot: Optional[RobotConfig] = None - teleop: Optional[TeleoperatorConfig] = None - wrapper: Optional[EnvTransformConfig] = None + robot: RobotConfig | None = None + teleop: TeleoperatorConfig | None = None + wrapper: EnvTransformConfig | None = None fps: int = 10 name: str = "real_robot" mode: str = None # Either "record", "replay", None - repo_id: Optional[str] = None - dataset_root: Optional[str] = None + repo_id: str | None = None + dataset_root: str | None = None task: str = "" num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" push_to_hub: bool = True - pretrained_policy_name_or_path: Optional[str] = None - reward_classifier_pretrained_path: Optional[str] = None + pretrained_policy_name_or_path: str | None = None + reward_classifier_pretrained_path: str | None = None # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 @@ -248,18 +248,18 @@ class HILEnvConfig(EnvConfig): } ) ################# args from hilserlrobotenv - reward_classifier_pretrained_path: Optional[str] = None - robot_config: Optional[RobotConfig] = None - teleop_config: Optional[TeleoperatorConfig] = None - wrapper: Optional[EnvTransformConfig] = None + reward_classifier_pretrained_path: str | None = None + robot_config: RobotConfig | None = None + teleop_config: TeleoperatorConfig | None = None + wrapper: EnvTransformConfig | None = None mode: str = None # Either "record", "replay", None - repo_id: Optional[str] = None - dataset_root: Optional[str] = None + repo_id: str | None = None + dataset_root: str | None = None num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" push_to_hub: bool = True - pretrained_policy_name_or_path: Optional[str] = None + pretrained_policy_name_or_path: str | None = None # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 ############################ diff --git a/src/lerobot/find_cameras.py b/src/lerobot/find_cameras.py index aff2f8c19..be8f272ee 100644 --- a/src/lerobot/find_cameras.py +++ b/src/lerobot/find_cameras.py @@ -32,7 +32,7 @@ import concurrent.futures import logging import time from pathlib import Path -from typing import Any, Dict, List +from typing import Any import numpy as np from PIL import Image @@ -46,14 +46,14 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon logger = logging.getLogger(__name__) -def find_all_opencv_cameras() -> List[Dict[str, Any]]: +def find_all_opencv_cameras() -> list[dict[str, Any]]: """ Finds all available OpenCV cameras plugged into the system. Returns: A list of all available OpenCV cameras with their metadata. """ - all_opencv_cameras_info: List[Dict[str, Any]] = [] + all_opencv_cameras_info: list[dict[str, Any]] = [] logger.info("Searching for OpenCV cameras...") try: opencv_cameras = OpenCVCamera.find_cameras() @@ -66,14 +66,14 @@ def find_all_opencv_cameras() -> List[Dict[str, Any]]: return all_opencv_cameras_info -def find_all_realsense_cameras() -> List[Dict[str, Any]]: +def find_all_realsense_cameras() -> list[dict[str, Any]]: """ Finds all available RealSense cameras plugged into the system. Returns: A list of all available RealSense cameras with their metadata. """ - all_realsense_cameras_info: List[Dict[str, Any]] = [] + all_realsense_cameras_info: list[dict[str, Any]] = [] logger.info("Searching for RealSense cameras...") try: realsense_cameras = RealSenseCamera.find_cameras() @@ -88,7 +88,7 @@ def find_all_realsense_cameras() -> List[Dict[str, Any]]: return all_realsense_cameras_info -def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[str, Any]]: +def find_and_print_cameras(camera_type_filter: str | None = None) -> list[dict[str, Any]]: """ Finds available cameras based on an optional filter and prints their information. @@ -99,7 +99,7 @@ def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[s Returns: A list of all available cameras matching the filter, with their metadata. """ - all_cameras_info: List[Dict[str, Any]] = [] + all_cameras_info: list[dict[str, Any]] = [] if camera_type_filter: camera_type_filter = camera_type_filter.lower() @@ -153,7 +153,7 @@ def save_image( logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}") -def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None: +def create_camera_instance(cam_meta: dict[str, Any]) -> dict[str, Any] | None: """Create and connect to a camera instance based on metadata.""" cam_type = cam_meta.get("type") cam_id = cam_meta.get("id") @@ -190,7 +190,7 @@ def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None: def process_camera_image( - cam_dict: Dict[str, Any], output_dir: Path, current_time: float + cam_dict: dict[str, Any], output_dir: Path, current_time: float ) -> concurrent.futures.Future | None: """Capture and process an image from a single camera.""" cam = cam_dict["instance"] @@ -216,7 +216,7 @@ def process_camera_image( return None -def cleanup_cameras(cameras_to_use: List[Dict[str, Any]]): +def cleanup_cameras(cameras_to_use: list[dict[str, Any]]): """Disconnect all cameras.""" logger.info(f"Disconnecting {len(cameras_to_use)} cameras...") for cam_dict in cameras_to_use: diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 26522c7c9..597bcd3c4 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -224,7 +224,7 @@ class MotorsBus(abc.ABC): ```bash python -m lerobot.find_port.py >>> Finding all available ports for the MotorsBus. - >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] + >>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"] >>> Remove the usb cable from your MotorsBus and press Enter when done. >>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751. >>> Reconnect the usb cable. diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index aa81d3cd2..4a048e63d 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -21,8 +21,8 @@ The majority of changes here involve removing unused code, unifying naming, and import math from collections import deque +from collections.abc import Callable from itertools import chain -from typing import Callable import einops import numpy as np @@ -216,7 +216,7 @@ class ACTTemporalEnsembler: continue avg *= exp_weights[:i].sum() avg += item * exp_weights[i] - avg /= exp_weights[:i+1].sum() + avg /= exp_weights[: i + 1].sum() print("online", avg) ``` """ diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 6dad8fb89..24b273967 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -22,7 +22,7 @@ TODO(alexander-soare): import math from collections import deque -from typing import Callable +from collections.abc import Callable import einops import numpy as np diff --git a/src/lerobot/policies/pi0/paligemma_with_expert.py b/src/lerobot/policies/pi0/paligemma_with_expert.py index f0f5713e5..edc34b7c5 100644 --- a/src/lerobot/policies/pi0/paligemma_with_expert.py +++ b/src/lerobot/policies/pi0/paligemma_with_expert.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union import torch import torch.version @@ -228,12 +227,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel): # TODO: break down this huge forward into modules or functions def forward( self, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, - inputs_embeds: List[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - fill_kv_cache: Optional[bool] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | Cache | None = None, + inputs_embeds: list[torch.FloatTensor] = None, + use_cache: bool | None = None, + fill_kv_cache: bool | None = None, ): models = [self.paligemma.language_model, self.gemma_expert.model] diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index d18b798a8..d745c901c 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import builtins import logging import os from importlib.resources import files from pathlib import Path from tempfile import TemporaryDirectory -from typing import List, Type, TypeVar +from typing import TypeVar import packaging import safetensors from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError -from safetensors.torch import load_model as load_model_as_safetensor -from safetensors.torch import save_model as save_model_as_safetensor +from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn from lerobot.configs.policies import PreTrainedConfig @@ -67,7 +67,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): @classmethod def from_pretrained( - cls: Type[T], + cls: builtins.type[T], pretrained_name_or_path: str | Path, *, config: PreTrainedConfig | None = None, @@ -223,7 +223,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): logging.info(f"Model pushed to {commit_info.repo_url.url}") def generate_model_card( - self, dataset_repo_id: str, model_type: str, license: str | None, tags: List[str] | None + self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None ) -> ModelCard: base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index 93cfe6c93..878f3cdd8 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -16,8 +16,9 @@ # limitations under the License. import math +from collections.abc import Callable from dataclasses import asdict -from typing import Callable, Literal +from typing import Literal import einops import numpy as np diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index 07eae8089..f3d1a693a 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -13,7 +13,6 @@ # limitations under the License. import copy -from typing import List, Optional import torch from torch import nn @@ -403,12 +402,12 @@ class SmolVLMWithExpertModel(nn.Module): def forward( self, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: List[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - fill_kv_cache: Optional[bool] = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] = None, + use_cache: bool | None = None, + fill_kv_cache: bool | None = None, ): models = [self.get_vlm_model().text_model, self.lm_expert] model_layers = self.get_model_layers(models) diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index c27689387..664fe863d 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -24,9 +24,9 @@ The comments in this code may sometimes refer to these references: # ruff: noqa: N806 from collections import deque +from collections.abc import Callable from copy import deepcopy from functools import partial -from typing import Callable import einops import numpy as np diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 59c820a96..b271298a3 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -18,7 +18,7 @@ import warnings from collections import deque -from typing import Callable, List +from collections.abc import Callable import einops import numpy as np @@ -901,7 +901,7 @@ class MLP(torch.nn.Sequential): def __init__( self, in_channels: int, - hidden_channels: List[int], + hidden_channels: list[int], ): layers = [] in_dim = in_channels diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py index 03b02a280..e0afe5585 100644 --- a/src/lerobot/policies/vqbet/vqbet_utils.py +++ b/src/lerobot/policies/vqbet/vqbet_utils.py @@ -17,10 +17,10 @@ # limitations under the License. import math +from collections.abc import Callable from functools import partial from math import ceil from random import randrange -from typing import Callable import torch import torch.distributed as distributed @@ -198,7 +198,7 @@ class GPT(nn.Module): # report number of parameters n_params = sum(p.numel() for p in self.parameters()) - print("number of parameters: {:.2f}M".format(n_params / 1e6)) + print(f"number of parameters: {n_params / 1e6:.2f}M") def forward(self, input, targets=None): device = input.device @@ -255,7 +255,7 @@ class GPT(nn.Module): blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, _p in m.named_parameters(): - fpn = "{}.{}".format(mn, pn) if mn else pn # full param name + fpn = f"{mn}.{pn}" if mn else pn # full param name if pn.endswith("bias"): # all biases will not be decayed no_decay.add(fpn) diff --git a/src/lerobot/record.py b/src/lerobot/record.py index c8184d40b..0b1af192e 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -62,7 +62,6 @@ import time from dataclasses import asdict, dataclass from pathlib import Path from pprint import pformat -from typing import List from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 @@ -190,7 +189,7 @@ def record_loop( events: dict, fps: int, dataset: LeRobotDataset | None = None, - teleop: Teleoperator | List[Teleoperator] | None = None, + teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, control_time_s: int | None = None, single_task: str | None = None, diff --git a/src/lerobot/robots/hope_jr/hope_jr.mdx b/src/lerobot/robots/hope_jr/hope_jr.mdx index 2f9ec9d89..72aa8f923 100644 --- a/src/lerobot/robots/hope_jr/hope_jr.mdx +++ b/src/lerobot/robots/hope_jr/hope_jr.mdx @@ -9,6 +9,7 @@ Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot. Install LeRobot with HopeJR dependencies: + ```bash pip install -e ".[hopejr]" ``` @@ -40,35 +41,39 @@ python -m lerobot.calibrate \ When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows: **Thumb**: + - **CMC**: base joint connecting thumb to hand - **MCP**: knuckle joint - **PIP**: first finger joint - **DIP** : fingertip joint **Index, Middle, Ring, and Pinky fingers**: + - **Radial flexor**: Moves base of finger towards the thumb - **Ulnar flexor**: Moves base of finger towards the pinky - **PIP/DIP**: Flexes the distal and proximal phalanx of the finger Each one of these will need to be calibrated individually via the GUI. - Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement. +Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement.

- Setting boundaries in the hand calibration GUI - + width="100%" + >

Use the calibration interface to set the range boundaries for each joint as shown above.

- Saving calibration values - + width="100%" + >

Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors. @@ -122,16 +127,18 @@ python -m lerobot.calibrate \ ``` This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows: + - **Shoulder**: pitch, yaw, and roll - **Elbow**: flex - **Wrist**: pitch, yaw, and roll

- Setting boundaries in the arm calibration GUI - + width="100%" + >

Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration. @@ -169,6 +176,7 @@ Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions: ### Hand + ```bash python -m lerobot.teleoperate \ --robot.type=hope_jr_hand \ @@ -184,6 +192,7 @@ python -m lerobot.teleoperate \ ``` ### Arm + ```bash python -m lerobot.teleoperate \ --robot.type=hope_jr_arm \ diff --git a/src/lerobot/robots/koch_follower/koch.mdx b/src/lerobot/robots/koch_follower/koch.mdx index f70a1802c..d0b991e74 100644 --- a/src/lerobot/robots/koch_follower/koch.mdx +++ b/src/lerobot/robots/koch_follower/koch.mdx @@ -10,15 +10,16 @@ For a visual walkthrough of the assembly process, you can refer to [this video t > [!WARNING] > Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video: +> > - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors). > - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors). - ## Install LeRobot 🤗 To install LeRobot follow, our [Installation Guide](./installation) In addition to these instructions, you need to install the Dynamixel SDK: + ```bash pip install -e ".[dynamixel]" ``` @@ -28,6 +29,7 @@ pip install -e ".[dynamixel]" ### 1. Find the USB ports associated with each arm To find the port for each bus servo adapter, run this script: + ```bash python -m lerobot.find_port ``` @@ -54,6 +56,7 @@ Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your le On Linux, you might need to give access to the USB ports by running: + ```bash sudo chmod 666 /dev/ttyACM0 sudo chmod 666 /dev/ttyACM1 @@ -99,9 +102,11 @@ python -m lerobot.setup_motors \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` + + ```python from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig @@ -112,10 +117,13 @@ config = KochFollowerConfig( follower = KochFollower(config) follower.setup_motors() ``` + + You should see the following instruction. + ``` Connect the controller board to the 'gripper' motor only and press enter. ``` @@ -125,22 +133,26 @@ As instructed, plug the gripper's motor. Make sure it's the only motor connected
Troubleshooting - If you get an error at that point, check your cables and make sure they are plugged in properly: -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
+If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). - If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
You should then see the following message: + ``` 'gripper' motor id set to 6 ``` Followed by the next instruction: + ``` Connect the controller board to the 'wrist_roll' motor only and press enter. ``` @@ -155,6 +167,7 @@ Repeat the operation for each motor as instructed. When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. #### Leader + Do the same steps for the leader arm but modify the command or script accordingly. @@ -165,9 +178,11 @@ python -m lerobot.setup_motors \ --teleop.type=koch_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step ``` + + ```python from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig @@ -178,6 +193,8 @@ config = KochLeaderConfig( leader = KochLeader(config) leader.setup_motors() ``` + + @@ -199,9 +216,11 @@ python -m lerobot.calibrate \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name ``` + + ```python from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower @@ -215,6 +234,8 @@ follower.connect(calibrate=False) follower.calibrate() follower.disconnect() ``` + + @@ -233,9 +254,11 @@ python -m lerobot.calibrate \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name ``` + + ```python from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader @@ -249,10 +272,12 @@ leader.connect(calibrate=False) leader.calibrate() leader.disconnect() ``` + + Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/lekiwi/lekiwi.mdx b/src/lerobot/robots/lekiwi/lekiwi.mdx index 61b1c05c1..bb70fd26b 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.mdx +++ b/src/lerobot/robots/lekiwi/lekiwi.mdx @@ -8,31 +8,43 @@ Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains th And advise if it's your first time printing or if you don't own a 3D printer. ### Wired version + If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating. ## Install software on Pi + Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board. ### Install OS + For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu. ### Setup SSH + After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension. ### Install LeRobot on Pi 🤗 On your Raspberry Pi install LeRobot using our [Installation Guide](./installation) -In addition to these instructions, you need to install the Feetech sdk on your Pi: +In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your Pi: + ```bash -pip install -e ".[feetech]" +pip install -e ".[lekiwi]" ``` ## Install LeRobot locally + If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi. Follow our [Installation Guide](./installation) +In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your laptop/pc: + +```bash +pip install -e ".[lekiwi]" +``` + Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:. Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. @@ -46,6 +58,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba ### Find the USB ports associated with motor board To find the port for each bus servo adapter, run this script: + ```bash python -m lerobot.find_port ``` @@ -72,6 +85,7 @@ Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your bo On Linux, you might need to give access to the USB ports by running: + ```bash sudo chmod 666 /dev/ttyACM0 sudo chmod 666 /dev/ttyACM1 @@ -96,6 +110,7 @@ Where the found port is: `/dev/ttyACM0` corresponding to your board. ### Configure motors + The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9. You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) @@ -113,27 +128,36 @@ python -m lerobot.setup_motors \ If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue. #### 1. Verify IP Address Configuration + Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line): + ```bash hostname -I ``` #### 2. Check if Pi is reachable from laptop/pc + Try pinging the Raspberry Pi from your laptop: + ```bach ping ``` If the ping fails: + - Ensure the Pi is powered on and connected to the same network. - Check if SSH is enabled on the Pi. #### 3. Try SSH connection + If you can't SSH into the Pi, it might not be properly connected. Use: + ```bash ssh @ ``` + If you get a connection error: + - Ensure SSH is enabled on the Pi by running: ```bash sudo raspi-config @@ -158,10 +182,13 @@ python -m lerobot.calibrate \ We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). ### Wired version + If you have the **wired** LeKiwi version, please run all commands on your laptop. ### Calibrate leader arm + Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop: + @@ -171,9 +198,11 @@ python -m lerobot.calibrate \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name ``` + + ```python from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader @@ -187,6 +216,8 @@ leader.connect(calibrate=False) leader.calibrate() leader.disconnect() ``` + + @@ -196,6 +227,7 @@ leader.disconnect() > If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal. To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command: + ```bash python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi ``` @@ -206,7 +238,7 @@ Then on your laptop, also run `conda activate lerobot` and run the API example, python examples/lekiwi/teleoperate.py ``` -You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: +You should see on your laptop something like this: `[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: | Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | | ---------- | ------------------ | ---------------------- | @@ -214,7 +246,6 @@ You should see on your laptop something like this: ```[INFO] Connected to remote | Medium | 0.25 | 60 | | Slow | 0.1 | 30 | - | Key | Action | | --- | -------------- | | W | Move forward | @@ -227,9 +258,10 @@ You should see on your laptop something like this: ```[INFO] Connected to remote | F | Decrease speed | > [!TIP] -> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). +> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). ### Wired version + If you have the **wired** LeKiwi version, please run all commands on your laptop. ## Record a dataset @@ -239,26 +271,32 @@ Once you're familiar with teleoperation, you can record your first dataset. We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). Add your token to the CLI by running this command: + ```bash huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Then store your Hugging Face repository name in a variable: + ```bash HF_USER=$(huggingface-cli whoami | head -n 1) echo $HF_USER ``` Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`. + ```bash python examples/lekiwi/record.py ``` #### Dataset upload + Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: + ```bash echo https://huggingface.co/datasets/${HF_USER}/so101_test ``` + Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). @@ -274,14 +312,13 @@ Avoid adding too much variation too quickly, as it may hinder your results. If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. #### Troubleshooting: -- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). +- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). ## Replay an episode To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index. - ```bash python examples/lekiwi/replay.py ``` @@ -297,4 +334,4 @@ python examples/lekiwi/evaluate.py ``` > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 0ce259bb6..9a8001401 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -18,11 +18,10 @@ import base64 import json import logging from functools import cached_property -from typing import Any, Dict, Optional, Tuple +from typing import Any import cv2 import numpy as np -import zmq from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -35,6 +34,9 @@ class LeKiwiClient(Robot): name = "lekiwi_client" def __init__(self, config: LeKiwiClientConfig): + import zmq + + self._zmq = zmq super().__init__(config) self.config = config self.id = config.id @@ -117,6 +119,7 @@ class LeKiwiClient(Robot): "LeKiwi Daemon is already connected. Do not run `robot.connect()` twice." ) + zmq = self._zmq self.zmq_context = zmq.Context() self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH) zmq_cmd_locator = f"tcp://{self.remote_ip}:{self.port_zmq_cmd}" @@ -139,8 +142,9 @@ class LeKiwiClient(Robot): def calibrate(self) -> None: pass - def _poll_and_get_latest_message(self) -> Optional[str]: + def _poll_and_get_latest_message(self) -> str | None: """Polls the ZMQ socket for a limited time and returns the latest message string.""" + zmq = self._zmq poller = zmq.Poller() poller.register(self.zmq_observation_socket, zmq.POLLIN) @@ -167,7 +171,7 @@ class LeKiwiClient(Robot): return last_msg - def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]: + def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None: """Parses the JSON observation string.""" try: return json.loads(obs_string) @@ -175,7 +179,7 @@ class LeKiwiClient(Robot): logging.error(f"Error decoding JSON observation: {e}") return None - def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]: + def _decode_image_from_b64(self, image_b64: str) -> np.ndarray | None: """Decodes a base64 encoded image string to an OpenCV image.""" if not image_b64: return None @@ -191,18 +195,18 @@ class LeKiwiClient(Robot): return None def _remote_state_from_obs( - self, observation: Dict[str, Any] - ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + self, observation: dict[str, Any] + ) -> tuple[dict[str, np.ndarray], dict[str, Any]]: """Extracts frames, and state from the parsed observation.""" flat_state = {key: observation.get(key, 0.0) for key in self._state_order} state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: Dict[str, Any] = {**flat_state, "observation.state": state_vec} + obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec} # Decode images - current_frames: Dict[str, np.ndarray] = {} + current_frames: dict[str, np.ndarray] = {} for cam_name, image_b64 in observation.items(): if cam_name not in self._cameras_ft: continue @@ -212,7 +216,7 @@ class LeKiwiClient(Robot): return current_frames, obs_dict - def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]: + def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]: """ Polls the video socket for the latest observation data. diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index 6820645cc..2a9004380 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -13,8 +13,9 @@ # limitations under the License. import abc +import builtins from pathlib import Path -from typing import Any, Type +from typing import Any import draccus @@ -39,7 +40,7 @@ class Robot(abc.ABC): """ # Set these in ALL subclasses - config_class: Type[RobotConfig] + config_class: builtins.type[RobotConfig] name: str def __init__(self, config: RobotConfig): diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx index f5eea6aef..d9ff922c5 100644 --- a/src/lerobot/robots/so100_follower/so100.mdx +++ b/src/lerobot/robots/so100_follower/so100.mdx @@ -11,6 +11,7 @@ Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100 To install LeRobot, follow our [Installation Guide](./installation) In addition to these instructions, you need to install the Feetech SDK: + ```bash pip install -e ".[feetech]" ``` @@ -23,6 +24,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i ### 1. Find the USB ports associated with each arm To find the port for each bus servo adapter, run this script: + ```bash python -m lerobot.find_port ``` @@ -49,6 +51,7 @@ Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your le On Linux, you might need to give access to the USB ports by running: + ```bash sudo chmod 666 /dev/ttyACM0 sudo chmod 666 /dev/ttyACM1 @@ -94,9 +97,11 @@ python -m lerobot.setup_motors \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` + + ```python from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig @@ -107,10 +112,13 @@ config = SO100FollowerConfig( follower = SO100Follower(config) follower.setup_motors() ``` + + You should see the following instruction + ``` Connect the controller board to the 'gripper' motor only and press enter. ``` @@ -120,22 +128,26 @@ As instructed, plug the gripper's motor. Make sure it's the only motor connected
Troubleshooting - If you get an error at that point, check your cables and make sure they are plugged in properly: -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
+If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). +
You should then see the following message: + ``` 'gripper' motor id set to 6 ``` Followed by the next instruction: + ``` Connect the controller board to the 'wrist_roll' motor only and press enter. ``` @@ -150,6 +162,7 @@ Repeat the operation for each motor as instructed. When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. #### Leader + Do the same steps for the leader arm. @@ -162,6 +175,7 @@ python -m lerobot.setup_motors \
+ ```python from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig @@ -172,6 +186,8 @@ config = SO100LeaderConfig( leader = SO100Leader(config) leader.setup_motors() ``` + + @@ -184,7 +200,10 @@ leader.setup_motors()
@@ -193,6 +212,7 @@ leader.setup_motors() Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. ### Clean Parts + Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. ### Additional Guidance @@ -202,7 +222,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -216,75 +239,117 @@ This video provides visual guidance for assembling the arms, but it doesn't spec ### First Motor **Step 2: Insert Wires** + - Insert two wires into the first motor. - + **Step 3: Install in Base** + - Place the first motor into the base. - + **Step 4: Secure Motor** + - Fasten the motor with 4 screws. Two from the bottom and two from top. **Step 5: Attach Motor Holder** + - Slide over the first motor holder and fasten it using two screws (one on each side). - + **Step 6: Attach Motor Horns** + - Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears. - +
- Video adding motor horn + + Video adding motor horn +
**Step 7: Attach Shoulder Part** + - Route one wire to the back of the robot and the other to the left or towards you (see photo). - Attach the shoulder part. - + **Step 8: Secure Shoulder** + - Tighten the shoulder part with 4 screws on top and 4 on the bottom -*(access bottom holes by turning the shoulder).* + _(access bottom holes by turning the shoulder)._ --- ### Second Motor Assembly **Step 9: Install Motor 2** + - Slide the second motor in from the top and link the wire from motor 1 to motor 2. - + **Step 10: Attach Shoulder Holder** + - Add the shoulder motor holder. - Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo). - This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor.
- - - + + +
**Step 11: Secure Motor 2** + - Fasten the second motor with 4 screws. **Step 12: Attach Motor Horn** + - Attach both motor horns to motor 2, again use the horn screw. **Step 13: Attach Base** + - Install the base attachment using 2 screws. **Step 14: Attach Upper Arm** + - Attach the upper arm with 4 screws on each side. @@ -294,89 +359,144 @@ This video provides visual guidance for assembling the arms, but it doesn't spec ### Third Motor Assembly **Step 15: Install Motor 3** + - Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws. **Step 16: Attach Motor Horn** + - Attach both motor horns to motor 3 and secure one again with a horn screw. - + **Step 17: Attach Forearm** + - Connect the forearm to motor 3 using 4 screws on each side. - + --- ### Fourth Motor Assembly **Step 18: Install Motor 4** + - Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw.
- - + +
**Step 19: Attach Motor Holder 4** + - Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo). - + **Step 20: Secure Motor 4 & Attach Horn** + - Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw. - + --- ### Wrist Assembly **Step 21: Install Motor 5** + - Insert motor 5 into the wrist holder and secure it with 2 front screws. - + **Step 22: Attach Wrist** + - Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper. - Secure the wrist to motor 4 using 4 screws on both sides. - + **Step 23: Attach Wrist Horn** + - Install only one motor horn on the wrist motor and secure it with a horn screw. - + --- ### Follower Configuration **Step 24: Attach Gripper** + - Attach the gripper to motor 5. - + **Step 25: Install Gripper Motor** + - Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side. - + **Step 26: Attach Gripper Horn & Claw** + - Attach the motor horns and again use a horn screw. - Install the gripper claw and secure it with 4 screws on both sides. - + **Step 27: Mount Controller** + - Attach the motor controller to the back of the robot.
- - + +
-*Assembly complete – proceed to Leader arm assembly.* +_Assembly complete – proceed to Leader arm assembly._ --- @@ -385,31 +505,54 @@ This video provides visual guidance for assembling the arms, but it doesn't spec For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors. **Step 24: Attach Leader Holder** + - Mount the leader holder onto the wrist and secure it with a screw. - + **Step 25: Attach Handle** + - Attach the handle to motor 5 using 4 screws. - + **Step 26: Install Gripper Motor** + - Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire. - + **Step 27: Attach Trigger** + - Attach the follower trigger with 4 screws. - + **Step 28: Mount Controller** + - Attach the motor controller to the back of the robot.
- - + +
## Calibrate @@ -430,9 +573,11 @@ python -m lerobot.calibrate \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name ``` + + ```python from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower @@ -446,6 +591,8 @@ follower.connect(calibrate=False) follower.calibrate() follower.disconnect() ``` + + @@ -464,9 +611,11 @@ python -m lerobot.calibrate \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name ``` + + ```python from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader @@ -480,10 +629,12 @@ leader.connect(calibrate=False) leader.calibrate() leader.disconnect() ``` + + Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx index c49807d93..e84336e17 100644 --- a/src/lerobot/robots/so101_follower/so101.mdx +++ b/src/lerobot/robots/so101_follower/so101.mdx @@ -12,6 +12,7 @@ And advise if it's your first time printing or if you don't own a 3D printer. To install LeRobot, follow our [Installation Guide](./installation) In addition to these instructions, you need to install the Feetech SDK: + ```bash pip install -e ".[feetech]" ``` @@ -20,16 +21,17 @@ pip install -e ".[feetech]" The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below. -| Leader-Arm Axis | Motor | Gear Ratio | -|-----------------|:-------:|:----------:| -| Base / Shoulder Pan | 1 | 1 / 191 | -| Shoulder Lift | 2 | 1 / 345 | -| Elbow Flex | 3 | 1 / 191 | -| Wrist Flex | 4 | 1 / 147 | -| Wrist Roll | 5 | 1 / 147 | -| Gripper | 6 | 1 / 147 | +| Leader-Arm Axis | Motor | Gear Ratio | +| ------------------- | :---: | :--------: | +| Base / Shoulder Pan | 1 | 1 / 191 | +| Shoulder Lift | 2 | 1 / 345 | +| Elbow Flex | 3 | 1 / 191 | +| Wrist Flex | 4 | 1 / 147 | +| Wrist Roll | 5 | 1 / 147 | +| Gripper | 6 | 1 / 147 | ### Clean Parts + Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. ### Joint 1 @@ -44,7 +46,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -57,7 +62,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -69,7 +77,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -81,7 +92,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -93,7 +107,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -109,7 +126,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -123,7 +143,10 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi
@@ -135,6 +158,7 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi ### 1. Find the USB ports associated with each arm To find the port for each bus servo adapter, run this script: + ```bash python -m lerobot.find_port ``` @@ -161,6 +185,7 @@ Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your le On Linux, you might need to give access to the USB ports by running: + ```bash sudo chmod 666 /dev/ttyACM0 sudo chmod 666 /dev/ttyACM1 @@ -198,7 +223,10 @@ The video below shows the sequence of steps for setting the motor ids.
@@ -214,9 +242,11 @@ python -m lerobot.setup_motors \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` +
+ ```python from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig @@ -227,10 +257,13 @@ config = SO101FollowerConfig( follower = SO101Follower(config) follower.setup_motors() ``` + + You should see the following instruction + ```bash Connect the controller board to the 'gripper' motor only and press enter. ``` @@ -240,22 +273,26 @@ As instructed, plug the gripper's motor. Make sure it's the only motor connected
Troubleshooting - If you get an error at that point, check your cables and make sure they are plugged in properly: -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
+If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). - If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB).
You should then see the following message: + ```bash 'gripper' motor id set to 6 ``` Followed by the next instruction: + ```bash Connect the controller board to the 'wrist_roll' motor only and press enter. ``` @@ -270,6 +307,7 @@ Repeat the operation for each motor as instructed. When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. #### Leader + Do the same steps for the leader arm. @@ -280,9 +318,11 @@ python -m lerobot.setup_motors \ --teleop.type=so101_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` + + ```python from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig @@ -293,6 +333,8 @@ config = SO101LeaderConfig( leader = SO101Leader(config) leader.setup_motors() ``` + + @@ -314,9 +356,11 @@ python -m lerobot.calibrate \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name ``` + + ```python from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower @@ -330,6 +374,8 @@ follower.connect(calibrate=False) follower.calibrate() follower.disconnect() ``` + + @@ -339,7 +385,10 @@ The video below shows how to perform the calibration. First you need to move the
@@ -356,9 +405,11 @@ python -m lerobot.calibrate \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name ``` + + ```python from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader @@ -372,10 +423,12 @@ leader.connect(calibrate=False) leader.calibrate() leader.disconnect() ``` + + Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) > [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/stretch3/README.md b/src/lerobot/robots/stretch3/README.md index 982e72571..724732286 100644 --- a/src/lerobot/robots/stretch3/README.md +++ b/src/lerobot/robots/stretch3/README.md @@ -5,16 +5,17 @@ This tutorial explains how to use [Stretch 3](https://hello-robot.com/stretch-3- Familiarize yourself with Stretch by following its [tutorials](https://docs.hello-robot.com/0.3/getting_started/hello_robot/) (recommended). To use LeRobot on Stretch, 3 options are available: + - [tethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#tethered-setup) - [untethered setup](https://docs.hello-robot.com/0.3/getting_started/connecting_to_stretch/#untethered-setup) - ssh directly into Stretch (you will first need to install and configure openssh-server on stretch using one of the two above setups) - ## Install LeRobot On Stretch's CLI, follow these steps: 1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): + ```bash mkdir -p ~/miniconda3 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh @@ -24,6 +25,7 @@ rm ~/miniconda3/miniconda.sh ``` 2. Comment out these lines in `~/.profile` (this can mess up paths used by conda and ~/.local/bin should already be in your PATH) + ``` # set PATH so it includes user's private bin if it exists if [ -d "$HOME/.local/bin" ] ; then @@ -34,21 +36,25 @@ fi 3. Restart shell or `source ~/.bashrc` 4. Create and activate a fresh conda environment for lerobot + ```bash conda create -y -n lerobot python=3.10 && conda activate lerobot ``` 5. Clone LeRobot: + ```bash git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` 6. When using `miniconda`, install `ffmpeg` in your environment: + ```bash conda install ffmpeg -c conda-forge ``` 7. Install LeRobot with stretch dependencies: + ```bash cd ~/lerobot && pip install -e ".[stretch]" ``` @@ -56,6 +62,7 @@ cd ~/lerobot && pip install -e ".[stretch]" > **Note:** If you get this message, you can ignore it: `ERROR: pip's dependency resolver does not currently take into account all the packages that are installed.` 8. Run a [system check](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#system-check) to make sure your robot is ready: + ```bash stretch_system_check.py ``` @@ -63,6 +70,7 @@ stretch_system_check.py > **Note:** You may need to free the "robot process" after booting Stretch by running `stretch_free_robot_process.py`. For more info this Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#turning-off-gamepad-teleoperation). You should get something like this: + ```bash For use with S T R E T C H (R) from Hello Robot Inc. --------------------------------------------------------------------- @@ -89,11 +97,13 @@ Serial Number = stretch-se3-3054 **Calibrate (Optional)** Before operating Stretch, you need to [home](https://docs.hello-robot.com/0.3/getting_started/stretch_hardware_overview/#homing) it first. Be mindful about giving Stretch some space as this procedure will move the robot's arm and gripper. Now run this command: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=stretch \ --control.type=calibrate ``` + This is equivalent to running `stretch_robot_home.py` > **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first. @@ -104,28 +114,33 @@ Before trying teleoperation, you need to activate the gamepad controller by pres Now try out teleoperation (see above documentation to learn about the gamepad controls): > **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + ```bash python lerobot/scripts/control_robot.py \ --robot.type=stretch \ --control.type=teleoperate ``` + This is essentially the same as running `stretch_gamepad_teleop.py` **Record a dataset** Once you're familiar with the gamepad controls and after a bit of practice, you can try to record your first dataset with Stretch. If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): + ```bash huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Store your Hugging Face repository name in a variable to run these commands: + ```bash HF_USER=$(huggingface-cli whoami | head -n 1) echo $HF_USER ``` Record one episode: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=stretch \ @@ -145,6 +160,7 @@ python lerobot/scripts/control_robot.py \ **Replay an episode** Now try to replay this episode (make sure the robot's initial position is the same): + ```bash python lerobot/scripts/control_robot.py \ --robot.type=stretch \ diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index 445368e7a..4e90c99c7 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -4,12 +4,12 @@ This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.tro Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer. - ## Install LeRobot On your computer: 1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): + ```bash mkdir -p ~/miniconda3 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh @@ -21,29 +21,34 @@ rm ~/miniconda3/miniconda.sh 2. Restart shell or `source ~/.bashrc` 3. Create and activate a fresh conda environment for lerobot + ```bash conda create -y -n lerobot python=3.10 && conda activate lerobot ``` 4. Clone LeRobot: + ```bash git clone https://github.com/huggingface/lerobot.git ~/lerobot ``` 5. When using `miniconda`, install `ffmpeg` in your environment: + ```bash conda install ffmpeg -c conda-forge ``` 6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): + ```bash cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]" ``` ## Teleoperate -**/!\ FOR SAFETY, READ THIS /!\** +\*\*/!\ FOR SAFETY, READ THIS /!\*\* Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly: + 1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms, 2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics. @@ -59,6 +64,7 @@ python lerobot/scripts/control_robot.py \ ``` By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=aloha \ @@ -71,17 +77,20 @@ python lerobot/scripts/control_robot.py \ Once you're familiar with teleoperation, you can record your first dataset with Aloha. If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): + ```bash huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential ``` Store your Hugging Face repository name in a variable to run these commands: + ```bash HF_USER=$(huggingface-cli whoami | head -n 1) echo $HF_USER ``` Record 2 episodes and upload your dataset to the hub: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=aloha \ @@ -101,11 +110,13 @@ python lerobot/scripts/control_robot.py \ ## Visualize a dataset If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: + ```bash echo ${HF_USER}/aloha_test ``` If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: + ```bash python -m lerobot.scripts.visualize_dataset_html \ --repo-id ${HF_USER}/aloha_test @@ -113,10 +124,11 @@ python -m lerobot.scripts.visualize_dataset_html \ ## Replay an episode -**/!\ FOR SAFETY, READ THIS /!\** +\*\*/!\ FOR SAFETY, READ THIS /!\*\* Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above. Now try to replay the first episode on your robot: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=aloha \ @@ -130,6 +142,7 @@ python lerobot/scripts/control_robot.py \ ## Train a policy To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: + ```bash python -m lerobot.scripts.train \ --dataset.repo_id=${HF_USER}/aloha_test \ @@ -141,10 +154,11 @@ python -m lerobot.scripts.train \ ``` Let's explain it: + 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. +3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md) @@ -153,6 +167,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/ ## Evaluate your policy You can use the `record` function from [`lerobot/scripts/control_robot.py`](../src/lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: + ```bash python lerobot/scripts/control_robot.py \ --robot.type=aloha \ @@ -171,7 +186,8 @@ python lerobot/scripts/control_robot.py \ ``` As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`). + +1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`). 2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`). 3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`. diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index d85ac27b3..7c5aec48a 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -50,12 +50,12 @@ import json import logging import threading import time +from collections.abc import Callable from contextlib import nullcontext from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat -from typing import Callable import einops import gymnasium as gym diff --git a/src/lerobot/scripts/rl/crop_dataset_roi.py b/src/lerobot/scripts/rl/crop_dataset_roi.py index 0b71b5363..69904b740 100644 --- a/src/lerobot/scripts/rl/crop_dataset_roi.py +++ b/src/lerobot/scripts/rl/crop_dataset_roi.py @@ -18,7 +18,6 @@ import argparse import json from copy import deepcopy from pathlib import Path -from typing import Dict, Tuple import cv2 import torch @@ -162,10 +161,10 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset): def convert_lerobot_dataset_to_cropper_lerobot_dataset( original_dataset: LeRobotDataset, - crop_params_dict: Dict[str, Tuple[int, int, int, int]], + crop_params_dict: dict[str, tuple[int, int, int, int]], new_repo_id: str, new_dataset_root: str, - resize_size: Tuple[int, int] = (128, 128), + resize_size: tuple[int, int] = (128, 128), push_to_hub: bool = False, task: str = "", ) -> LeRobotDataset: diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 673043b6e..c8be6b7dd 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -39,8 +39,9 @@ Example: import logging import time from collections import deque +from collections.abc import Sequence from threading import Lock -from typing import Annotated, Any, Sequence +from typing import Annotated, Any import gymnasium as gym import numpy as np diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index cb88895cf..f9f3901ce 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -87,12 +87,10 @@ from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( get_step_checkpoint_dir, + load_training_state as utils_load_training_state, save_checkpoint, update_last_checkpoint, ) -from lerobot.utils.train_utils import ( - load_training_state as utils_load_training_state, -) from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( format_big_number, diff --git a/src/lerobot/scripts/server/configs.py b/src/lerobot/scripts/server/configs.py index 7058842ae..5be46485e 100644 --- a/src/lerobot/scripts/server/configs.py +++ b/src/lerobot/scripts/server/configs.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Callable import torch diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index 44d9cdf77..68166de6f 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -36,10 +36,11 @@ import logging import pickle # nosec import threading import time +from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue -from typing import Any, Callable, Optional +from typing import Any import draccus import grpc @@ -231,7 +232,7 @@ class RobotClient: def _aggregate_action_queues( self, incoming_actions: list[TimedAction], - aggregate_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + aggregate_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, ): """Finds the same timestep actions in the queue and aggregates them using the aggregate_fn""" if aggregate_fn is None: diff --git a/src/lerobot/scripts/visualize_dataset.py b/src/lerobot/scripts/visualize_dataset.py index 37db66ddf..51ead0dd1 100644 --- a/src/lerobot/scripts/visualize_dataset.py +++ b/src/lerobot/scripts/visualize_dataset.py @@ -65,8 +65,8 @@ import argparse import gc import logging import time +from collections.abc import Iterator from pathlib import Path -from typing import Iterator import numpy as np import rerun as rr diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index dfce0c88e..6f5137af9 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -18,7 +18,7 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque, Dict, Optional +from typing import Deque import serial @@ -60,7 +60,7 @@ class HomunculusArm(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: Dict[str, Deque[int]] = { + self._buffers: dict[str, Deque[int]] = { joint: deque(maxlen=n) for joint in ( "shoulder_pitch", @@ -73,7 +73,7 @@ class HomunculusArm(Teleoperator): ) } # running EMA value per joint – lazily initialised on first read - self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers) + self._ema: dict[str, float | None] = dict.fromkeys(self._buffers) self._state: dict[str, float] | None = None self.new_state_event = threading.Event() @@ -217,9 +217,9 @@ class HomunculusArm(Teleoperator): return normalized_values - def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, float]: + def _apply_ema(self, raw: dict[str, int]) -> dict[str, float]: """Update buffers & running EMA values; return smoothed dict.""" - smoothed: Dict[str, float] = {} + smoothed: dict[str, float] = {} for joint, value in raw.items(): # maintain raw history self._buffers[joint].append(value) diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index d367a2a7c..7b0ced9f6 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -18,7 +18,7 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque, Dict, Optional +from typing import Deque import serial @@ -97,9 +97,9 @@ class HomunculusGlove(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: Dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} + self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} # running EMA value per joint – lazily initialised on first read - self._ema: Dict[str, Optional[float]] = dict.fromkeys(self._buffers) + self._ema: dict[str, float | None] = dict.fromkeys(self._buffers) self._state: dict[str, float] | None = None self.new_state_event = threading.Event() @@ -248,9 +248,9 @@ class HomunculusGlove(Teleoperator): return normalized_values - def _apply_ema(self, raw: Dict[str, int]) -> Dict[str, int]: + def _apply_ema(self, raw: dict[str, int]) -> dict[str, int]: """Update buffers & running EMA values; return smoothed dict as integers.""" - smoothed: Dict[str, int] = {} + smoothed: dict[str, int] = {} for joint, value in raw.items(): # maintain raw history self._buffers[joint].append(value) diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 49f259c17..c360ee7bb 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -13,8 +13,9 @@ # limitations under the License. import abc +import builtins from pathlib import Path -from typing import Any, Type +from typing import Any import draccus @@ -37,7 +38,7 @@ class Teleoperator(abc.ABC): """ # Set these in ALL subclasses - config_class: Type[TeleoperatorConfig] + config_class: builtins.type[TeleoperatorConfig] name: str def __init__(self, config: TeleoperatorConfig): diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 64ad7196c..7b7aaa84a 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -1,7 +1,8 @@ --- # For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Doc / guide: https://huggingface.co/docs/hub/model-cards -{{ card_data }} +# prettier-ignore +{{card_data}} --- # Model Card for {{ model_name | default("Model ID", true) }} @@ -53,7 +54,7 @@ python -m lerobot.scripts.train \ --wandb.enable=true ``` -*Writes checkpoints to `outputs/train//checkpoints/`.* +_Writes checkpoints to `outputs/train//checkpoints/`._ ### Evaluate the policy/run inference @@ -71,4 +72,4 @@ Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a ## Model Details -* **License:** {{ license | default("\[More Information Needed]", true) }} +- **License:** {{ license | default("\[More Information Needed]", true) }} diff --git a/src/lerobot/utils/benchmark.py b/src/lerobot/utils/benchmark.py index 4b08e6f6d..d9e5e62bb 100644 --- a/src/lerobot/utils/benchmark.py +++ b/src/lerobot/utils/benchmark.py @@ -46,11 +46,13 @@ class TimeBenchmark(ContextDecorator): benchmark = TimeBenchmark() + def context_manager_example(): with benchmark: time.sleep(0.01) print(f"Block took {benchmark.result_ms:.2f} milliseconds") + threads = [] for _ in range(3): t1 = threading.Thread(target=context_manager_example) diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/utils/buffer.py index 7f8d989dd..d9ffa899c 100644 --- a/src/lerobot/utils/buffer.py +++ b/src/lerobot/utils/buffer.py @@ -15,8 +15,9 @@ # limitations under the License. import functools +from collections.abc import Callable, Sequence from contextlib import suppress -from typing import Callable, Sequence, TypedDict +from typing import TypedDict import torch import torch.nn.functional as F # noqa: N812 diff --git a/src/lerobot/utils/hub.py b/src/lerobot/utils/hub.py index df7435c0f..566701b31 100644 --- a/src/lerobot/utils/hub.py +++ b/src/lerobot/utils/hub.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import builtins from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Type, TypeVar +from typing import Any, TypeVar from huggingface_hub import HfApi from huggingface_hub.utils import validate_hf_hub_args @@ -85,7 +86,7 @@ class HubMixin: @classmethod @validate_hf_hub_args def from_pretrained( - cls: Type[T], + cls: builtins.type[T], pretrained_name_or_path: str | Path, *, force_download: bool = False, diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index 31fed1da6..da3ecf37f 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import random +from collections.abc import Generator from contextlib import contextmanager from pathlib import Path -from typing import Any, Generator +from typing import Any import numpy as np import torch diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 7a9717dce..6e13646b0 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -185,10 +185,10 @@ def print_cuda_memory_usage(): gc.collect() # Also clear the cache if you want to fully release the memory torch.cuda.empty_cache() - print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) - print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) - print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) - print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) + print(f"Current GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB") + print(f"Maximum GPU Memory Allocated: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB") + print(f"Current GPU Memory Reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB") + print(f"Maximum GPU Memory Reserved: {torch.cuda.max_memory_reserved(0) / 1024**2:.2f} MB") def capture_timestamp_utc(): diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py index e81057c95..3ec60a485 100644 --- a/tests/configs/test_plugin_loading.py +++ b/tests/configs/test_plugin_loading.py @@ -15,9 +15,9 @@ # limitations under the License. import sys +from collections.abc import Generator from dataclasses import dataclass from pathlib import Path -from typing import Generator import pytest diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py index 00403d146..84026fc34 100644 --- a/tests/mocks/mock_dynamixel.py +++ b/tests/mocks/mock_dynamixel.py @@ -15,7 +15,7 @@ # limitations under the License. import abc -from typing import Callable +from collections.abc import Callable import dynamixel_sdk as dxl import serial diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py index 041c09421..33cbc41d6 100644 --- a/tests/mocks/mock_feetech.py +++ b/tests/mocks/mock_feetech.py @@ -15,7 +15,7 @@ # limitations under the License. import abc -from typing import Callable +from collections.abc import Callable import scservo_sdk as scs import serial diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index d990b5b0f..e0dbe713a 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -16,7 +16,7 @@ import re import sys -from typing import Generator +from collections.abc import Generator from unittest.mock import MagicMock, patch import pytest diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index d6ea1db20..31e4a9018 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -16,7 +16,7 @@ import re import sys -from typing import Generator +from collections.abc import Generator from unittest.mock import MagicMock, patch import pytest diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 260276032..a53d7ba8c 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -15,7 +15,7 @@ # limitations under the License. import sys -from typing import Callable +from collections.abc import Callable import pytest import torch From 7e9f955b4023887439eba681f14bbb81b40c1357 Mon Sep 17 00:00:00 2001 From: Juan Pizarro Date: Thu, 17 Jul 2025 17:01:48 +0200 Subject: [PATCH 019/158] fix(hil-serl): drain queue on get_last_item_from_queue (#1524) * fix(hil-serl): drain queue on get_last_item_from_queue * parametrize queue tests * revert changes for Darwin * revert parametrize queue tests * add test_get_last_item_multiple_items_with_torch_queue * update test_get_last_item_multiple_items_with_torch_queue * update test_get_last_item_multiple_items_with_torch_queue --- src/lerobot/utils/queue.py | 21 +++++++++++++++++---- tests/utils/test_queue.py | 16 ++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/lerobot/utils/queue.py b/src/lerobot/utils/queue.py index ceb30e2bf..864d798ac 100644 --- a/src/lerobot/utils/queue.py +++ b/src/lerobot/utils/queue.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import platform +from contextlib import suppress from queue import Empty from typing import Any @@ -30,10 +32,21 @@ def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) -> item = None # Drain queue and keep only the most recent parameters - try: - while True: + if platform.system() == "Darwin": + # On Mac, avoid using `qsize` due to unreliable implementation. + # There is a comment on `qsize` code in the Python source: + # Raises NotImplementedError on Mac OSX because of broken sem_getvalue() + try: + while True: + item = queue.get_nowait() + except Empty: + pass + + return item + + # Details about using qsize in https://github.com/huggingface/lerobot/issues/1523 + while queue.qsize() > 0: + with suppress(Empty): item = queue.get_nowait() - except Empty: - pass return item diff --git a/tests/utils/test_queue.py b/tests/utils/test_queue.py index 0a0d21770..6e42acdb7 100644 --- a/tests/utils/test_queue.py +++ b/tests/utils/test_queue.py @@ -18,6 +18,8 @@ import threading import time from queue import Queue +from torch.multiprocessing import Queue as TorchMPQueue + from lerobot.utils.queue import get_last_item_from_queue @@ -46,6 +48,20 @@ def test_get_last_item_multiple_items(): assert queue.empty() +def test_get_last_item_multiple_items_with_torch_queue(): + """Test getting the last item when queue has multiple items.""" + queue = TorchMPQueue() + items = ["first", "second", "third", "fourth", "last"] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == "last" + assert queue.empty() + + def test_get_last_item_different_types(): """Test with different data types in the queue.""" queue = Queue() From 38d3737f09dc3cc8f777a57c9f5a89dd541b464c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 17 Jul 2025 18:07:07 +0200 Subject: [PATCH 020/158] feat(ci): add new & clean dockerfiles (#1525) --- .github/workflows/test-docker-build.yml | 4 +- docker/Dockerfile.internal | 60 +++++++++++++++++++++++++ docker/Dockerfile.user | 50 +++++++++++++++++++++ pyproject.toml | 31 ++++++++++--- 4 files changed, 139 insertions(+), 6 deletions(-) create mode 100644 docker/Dockerfile.internal create mode 100644 docker/Dockerfile.user diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml index 7a1e93274..c33813418 100644 --- a/.github/workflows/test-docker-build.yml +++ b/.github/workflows/test-docker-build.yml @@ -20,7 +20,9 @@ on: pull_request: paths: # Run only when DockerFile files are modified - - "docker/**" + - "docker/lerobot-cpu/**" + - "docker/lerobot-gpu/**" + - "docker/lerobot-gpu-dev/**" permissions: {} diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal new file mode 100644 index 000000000..051606449 --- /dev/null +++ b/docker/Dockerfile.internal @@ -0,0 +1,60 @@ +# Dockerfile.internal +# This Dockerfile is designed for HuggingFace internal CI environments +# that require GPU access. It starts from an NVIDIA CUDA base image. + +# docker build -f docker/Dockerfile.internal -t lerobot-ci . + +# Configure the base image for CI with GPU access +ARG CUDA_VERSION=12.9.1 +ARG OS_VERSION=24.04 +FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION} + +# Define Python version argument +ARG PYTHON_VERSION=3.10 + +# Configure environment variables +ENV DEBIAN_FRONTEND=noninteractive \ + MUJOCO_GL="egl" \ + PATH="/lerobot/.venv/bin:$PATH" + +# Install Python, system dependencies, and uv (as root) +RUN apt-get update && apt-get install -y --no-install-recommends \ + software-properties-common \ + build-essential git curl \ + libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ + libusb-1.0-0-dev \ + speech-dispatcher libgeos-dev \ + && add-apt-repository -y ppa:deadsnakes/ppa \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + python${PYTHON_VERSION} \ + python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION}-dev \ + && curl -LsSf https://astral.sh/uv/install.sh | sh \ + && mv /root/.local/bin/uv /usr/local/bin/uv \ + && useradd --create-home --shell /bin/bash user_lerobot \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Create application directory and set permissions +WORKDIR /lerobot +RUN chown -R user_lerobot:user_lerobot /lerobot + +# Switch to the non-root user +USER user_lerobot + +# Create the virtual environment +# We use a virtual environment inside the container—even though the container itself \ +# provides isolation—to ensure compatibility with the cluster and to prevent \ +# issues with MuJoCo and OpenGL drivers. +RUN uv venv --python python${PYTHON_VERSION} + +# Install Python dependencies for caching +COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md ./ +COPY --chown=user_lerobot:user_lerobot src/ src/ +RUN uv pip install --no-cache ".[all]" + +# Copy the rest of the application source code +COPY --chown=user_lerobot:user_lerobot . . + +# Set the default command +CMD ["/bin/bash"] diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user new file mode 100644 index 000000000..ce63f5530 --- /dev/null +++ b/docker/Dockerfile.user @@ -0,0 +1,50 @@ +# Dockerfile.user +# This Dockerfile is designed for a lerobot user who wants to +# experiment with the project. It starts from an Python Slim base image. + +# docker build -f docker/Dockerfile.user -t lerobot-user . +# docker run -it --rm lerobot-user + +# Configure the base image +ARG PYTHON_VERSION=3.10 +FROM python:${PYTHON_VERSION}-slim + +# Configure environment variables +ENV DEBIAN_FRONTEND=noninteractive \ + MUJOCO_GL="egl" \ + PATH="/lerobot/.venv/bin:$PATH" + +# Install system dependencies and uv (as root) +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential git curl \ + libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ + libusb-1.0-0-dev \ + speech-dispatcher libgeos-dev \ + && curl -LsSf https://astral.sh/uv/install.sh | sh \ + && mv /root/.local/bin/uv /usr/local/bin/uv \ + && useradd --create-home --shell /bin/bash user_lerobot \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Create application directory and set permissions +WORKDIR /lerobot +RUN chown -R user_lerobot:user_lerobot /lerobot + +# Switch to the non-root user +USER user_lerobot + +# Create the virtual environment +# We use a virtual environment inside the container—even though the container itself \ +# provides isolation—to closely resemble local development and allow users to \ +# run other Python projects in the same container without dependency conflicts. +RUN uv venv + +# Install Python dependencies for caching +COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md ./ +COPY --chown=user_lerobot:user_lerobot src/ src/ +RUN uv pip install --no-cache ".[all]" + +# Copy the rest of the application code +COPY --chown=user_lerobot:user_lerobot . . + +# Set the default command +CMD ["/bin/bash"] diff --git a/pyproject.toml b/pyproject.toml index e9539037b..e0d754f53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,11 +110,11 @@ intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] -stretch = [ - "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", - "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", - "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" -] # TODO: Currently not supported +# stretch = [ +# "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", +# "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", +# "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" +# ] # TODO: Currently not supported # Policies pi0 = ["lerobot[transformers-dep]"] @@ -135,6 +135,27 @@ aloha = ["gym-aloha>=0.1.1"] pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead xarm = ["gym-xarm>=0.1.1"] +# All +all = [ + "lerobot[dynamixel]", + "lerobot[gamepad]", + "lerobot[hopejr]", + "lerobot[lekiwi]", + "lerobot[kinematics]", + "lerobot[intelrealsense]", + "lerobot[pi0]", + "lerobot[smolvla]", + "lerobot[hilserl]", + "lerobot[async]", + "lerobot[docs]", + "lerobot[dev]", + "lerobot[test]", + "lerobot[video_benchmark]", + "lerobot[aloha]", + "lerobot[pusht]", + "lerobot[xarm]" +] + # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] where = ["src"] From 862a4439ea4ce4671b84b713f28e705d0b6e172b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 17 Jul 2025 18:09:16 +0200 Subject: [PATCH 021/158] chore(examples): remove outdated examples (#1526) --- examples/advanced/1_add_image_transforms.py | 67 ----------- .../advanced/2_calculate_validation_loss.py | 104 ------------------ 2 files changed, 171 deletions(-) delete mode 100644 examples/advanced/1_add_image_transforms.py delete mode 100644 examples/advanced/2_calculate_validation_loss.py diff --git a/examples/advanced/1_add_image_transforms.py b/examples/advanced/1_add_image_transforms.py deleted file mode 100644 index 3760feabb..000000000 --- a/examples/advanced/1_add_image_transforms.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data -augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and -transforms are applied to the observation images before they are returned in the dataset's __getitem__. -""" - -from pathlib import Path - -from torchvision.transforms import ToPILImage, v2 - -from lerobot.datasets.lerobot_dataset import LeRobotDataset - -dataset_repo_id = "lerobot/aloha_static_screw_driver" - -# Create a LeRobotDataset with no transformations -dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) -# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)` - -# Get the index of the first observation in the first episode -first_idx = dataset.episode_data_index["from"][0].item() - -# Get the frame corresponding to the first camera -frame = dataset[first_idx][dataset.meta.camera_keys[0]] - - -# Define the transformations -transforms = v2.Compose( - [ - v2.ColorJitter(brightness=(0.5, 1.5)), - v2.ColorJitter(contrast=(0.5, 1.5)), - v2.ColorJitter(hue=(-0.1, 0.1)), - v2.RandomAdjustSharpness(sharpness_factor=2, p=1), - ] -) - -# Create another LeRobotDataset with the defined transformations -transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms) - -# Get a frame from the transformed dataset -transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]] - -# Create a directory to store output images -output_dir = Path("outputs/image_transforms") -output_dir.mkdir(parents=True, exist_ok=True) - -# Save the original frame -to_pil = ToPILImage() -to_pil(frame).save(output_dir / "original_frame.png", quality=100) -print(f"Original frame saved to {output_dir / 'original_frame.png'}.") - -# Save the transformed frame -to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100) -print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.") diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py deleted file mode 100644 index 9eeb1a2d9..000000000 --- a/examples/advanced/2_calculate_validation_loss.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This script demonstrates how to slice a dataset and calculate the loss on a subset of the data. - -This technique can be useful for debugging and testing purposes, as well as identifying whether a policy -is learning effectively. - -Furthermore, relying on validation loss to evaluate performance is generally not considered a good practice, -especially in the context of imitation learning. The most reliable approach is to evaluate the policy directly -on the target environment, whether that be in simulation or the real world. -""" - -import math - -import torch - -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy - - -def main(): - device = torch.device("cuda") - - # Download the diffusion policy for pusht environment - pretrained_policy_path = "lerobot/diffusion_pusht" - # OR uncomment the following to evaluate a policy from the local outputs/train folder. - # pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") - - policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) - policy.eval() - policy.to(device) - - # Set up the dataset. - delta_timestamps = { - # Load the previous image and state at -0.1 seconds before current frame, - # then load current image and state corresponding to 0.0 second. - "observation.image": [-0.1, 0.0], - "observation.state": [-0.1, 0.0], - # Load the previous action (-0.1), the next action to be executed (0.0), - # and 14 future actions with a 0.1 seconds spacing. All these actions will be - # used to calculate the loss. - "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], - } - - # Load the last 10% of episodes of the dataset as a validation set. - # - Load dataset metadata - dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") - # - Calculate train and val episodes - total_episodes = dataset_metadata.total_episodes - episodes = list(range(dataset_metadata.total_episodes)) - num_train_episodes = math.floor(total_episodes * 90 / 100) - train_episodes = episodes[:num_train_episodes] - val_episodes = episodes[num_train_episodes:] - print(f"Number of episodes in full dataset: {total_episodes}") - print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}") - print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}") - # - Load train and val datasets - train_dataset = LeRobotDataset( - "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps - ) - val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps) - print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") - print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") - - # Create dataloader for evaluation. - val_dataloader = torch.utils.data.DataLoader( - val_dataset, - num_workers=4, - batch_size=64, - shuffle=False, - pin_memory=device != torch.device("cpu"), - drop_last=False, - ) - - # Run validation loop. - loss_cumsum = 0 - n_examples_evaluated = 0 - for batch in val_dataloader: - batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} - loss, _ = policy.forward(batch) - - loss_cumsum += loss.item() - n_examples_evaluated += batch["index"].shape[0] - - # Calculate the average loss over the validation set. - average_loss = loss_cumsum / n_examples_evaluated - - print(f"Average loss on validation set: {average_loss:.4f}") - - -if __name__ == "__main__": - main() From e6e1f085d4f66f34e9e6bd7c7f00893e9e425f9e Mon Sep 17 00:00:00 2001 From: Xingdong Zuo Date: Fri, 18 Jul 2025 19:18:52 +0900 Subject: [PATCH 022/158] Feat: Add Batched Video Encoding for Faster Dataset Recording (#1390) * LeRobotDataset video encoding: updated `save_episode` method and added `batch_encode_videos` method to handle video encoding based on `batch_encoding_size`, allowing for both immediate and batched encoding. * LeRobotDataset video cleanup: Enabled individual episode cleanup and check for remaining PNG files before removing the `images` directory. * LeRobotDataset - VideoEncodingManager: added proper handling of pending episodes (encoding, cleaning) on exit or recording failures. * LeRobotDatasetMetadata: removed `update_video_info` to only update video info at episode index 0 encoding. * Adjusted the `record` function to utilize the new encoding management logic. * Removed `encode_videos` method from `LeRobotDataset` and `encode_episode_videos` outputs as they are nowhere used. --------- Signed-off-by: Xingdong Zuo Co-authored-by: Xingdong Zuo Co-authored-by: Caroline Pascal --- src/lerobot/datasets/lerobot_dataset.py | 98 ++++++++++++++++++------- src/lerobot/datasets/video_utils.py | 64 ++++++++++++++++ src/lerobot/record.py | 67 +++++++++-------- 3 files changed, 172 insertions(+), 57 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 46feed2bf..72d1a722c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -260,8 +260,6 @@ class LeRobotDatasetMetadata: self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["total_videos"] += len(self.video_keys) - if len(self.video_keys) > 0: - self.update_video_info() write_info(self.info, self.root) @@ -342,6 +340,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + batch_encoding_size: int = 1, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -443,6 +442,8 @@ class LeRobotDataset(torch.utils.data.Dataset): True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. + Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. """ super().__init__() self.repo_id = repo_id @@ -454,6 +455,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() self.delta_indices = None + self.batch_encoding_size = batch_encoding_size + self.episodes_since_last_encoding = 0 # Unused attributes self.image_writer = None @@ -811,6 +814,10 @@ class LeRobotDataset(torch.utils.data.Dataset): """ This will save to disk the current episode in self.episode_buffer. + Video encoding is handled automatically based on batch_encoding_size: + - If batch_encoding_size == 1: Videos are encoded immediately after each episode + - If batch_encoding_size > 1: Videos are encoded in batches. + Args: episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to @@ -850,14 +857,28 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_episode_table(episode_buffer, episode_index) ep_stats = compute_episode_stats(episode_buffer, self.features) - if len(self.meta.video_keys) > 0: - video_paths = self.encode_episode_videos(episode_index) - for key in self.meta.video_keys: - episode_buffer[key] = video_paths[key] + has_video_keys = len(self.meta.video_keys) > 0 + use_batched_encoding = self.batch_encoding_size > 1 - # `meta.save_episode` be executed after encoding the videos + if has_video_keys and not use_batched_encoding: + self.encode_episode_videos(episode_index) + + # `meta.save_episode` should be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + # Check if we should trigger batch encoding + if has_video_keys and use_batched_encoding: + self.episodes_since_last_encoding += 1 + if self.episodes_since_last_encoding == self.batch_encoding_size: + start_ep = self.num_episodes - self.batch_encoding_size + end_ep = self.num_episodes + logging.info( + f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}" + ) + self.batch_encode_videos(start_ep, end_ep) + self.episodes_since_last_encoding = 0 + + # Episode data index and timestamp checking ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} check_timestamps_sync( @@ -868,16 +889,13 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s, ) - video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == self.num_episodes * len(self.meta.video_keys) - + # Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes parquet_files = list(self.root.rglob("*.parquet")) assert len(parquet_files) == self.num_episodes - - # delete images - img_dir = self.root / "images" - if img_dir.is_dir(): - shutil.rmtree(self.root / "images") + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len( + self.meta.video_keys + ) if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() @@ -894,6 +912,8 @@ class LeRobotDataset(torch.utils.data.Dataset): def clear_episode_buffer(self) -> None: episode_index = self.episode_buffer["episode_index"] + + # Clean up image files for the current episode buffer if self.image_writer is not None: for cam_key in self.meta.camera_keys: img_dir = self._get_image_file_path( @@ -930,25 +950,22 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.image_writer is not None: self.image_writer.wait_until_done() - def encode_videos(self) -> None: + def encode_episode_videos(self, episode_index: int) -> None: """ Use ffmpeg to convert frames stored as png into mp4 videos. Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. - """ - for ep_idx in range(self.meta.total_episodes): - self.encode_episode_videos(ep_idx) - def encode_episode_videos(self, episode_index: int) -> dict: + This method handles video encoding steps: + - Video encoding via ffmpeg + - Video info updating in metadata + - Raw image cleanup + + Args: + episode_index (int): Index of the episode to encode. """ - Use ffmpeg to convert frames stored as png into mp4 videos. - Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - since video encoding with ffmpeg is already using multithreading. - """ - video_paths = {} for key in self.meta.video_keys: video_path = self.root / self.meta.get_video_file_path(episode_index, key) - video_paths[key] = str(video_path) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue @@ -956,8 +973,32 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index=episode_index, image_key=key, frame_index=0 ).parent encode_video_frames(img_dir, video_path, self.fps, overwrite=True) + shutil.rmtree(img_dir) - return video_paths + # Update video info (only needed when first episode is encoded since it reads from episode 0) + if len(self.meta.video_keys) > 0 and episode_index == 0: + self.meta.update_video_info() + write_info(self.meta.info, self.meta.root) # ensure video info always written properly + + def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None: + """ + Batch encode videos for multiple episodes. + + Args: + start_episode: Starting episode index (inclusive) + end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode + """ + if end_episode is None: + end_episode = self.meta.total_episodes + + logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}") + + # Encode all episodes with cleanup enabled for individual episodes + for ep_idx in range(start_episode, end_episode): + logging.info(f"Encoding videos for episode {ep_idx}") + self.encode_episode_videos(ep_idx) + + logging.info("Batch video encoding completed") @classmethod def create( @@ -972,6 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_processes: int = 0, image_writer_threads: int = 0, video_backend: str | None = None, + batch_encoding_size: int = 1, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) @@ -988,6 +1030,8 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.revision = None obj.tolerance_s = tolerance_s obj.image_writer = None + obj.batch_encoding_size = batch_encoding_size + obj.episodes_since_last_encoding = 0 if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 3a77f36e4..b05edf6bd 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -16,6 +16,7 @@ import glob import importlib import logging +import shutil import warnings from dataclasses import dataclass, field from pathlib import Path @@ -451,3 +452,66 @@ def get_image_pixel_channels(image: Image): return 4 # RGBA else: raise ValueError("Unknown format") + + +class VideoEncodingManager: + """ + Context manager that ensures proper video encoding and data cleanup even if exceptions occur. + + This manager handles: + - Batch encoding for any remaining episodes when recording interrupted + - Cleaning up temporary image files from interrupted episodes + - Removing empty image directories + + Args: + dataset: The LeRobotDataset instance + """ + + def __init__(self, dataset): + self.dataset = dataset + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Handle any remaining episodes that haven't been batch encoded + if self.dataset.episodes_since_last_encoding > 0: + if exc_type is not None: + logging.info("Exception occurred. Encoding remaining episodes before exit...") + else: + logging.info("Recording stopped. Encoding remaining episodes...") + + start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding + end_ep = self.dataset.num_episodes + logging.info( + f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " + f"from episode {start_ep} to {end_ep - 1}" + ) + self.dataset.batch_encode_videos(start_ep, end_ep) + + # Clean up episode images if recording was interrupted + if exc_type is not None: + interrupted_episode_index = self.dataset.num_episodes + for key in self.dataset.meta.video_keys: + img_dir = self.dataset._get_image_file_path( + episode_index=interrupted_episode_index, image_key=key, frame_index=0 + ).parent + if img_dir.exists(): + logging.debug( + f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" + ) + shutil.rmtree(img_dir) + + # Clean up any remaining images directory if it's empty + img_dir = self.dataset.root / "images" + # Check for any remaining PNG files + png_files = list(img_dir.rglob("*.png")) + if len(png_files) == 0: + # Only remove the images directory if no PNG files remain + if img_dir.exists(): + shutil.rmtree(img_dir) + logging.debug("Cleaned up empty images directory") + else: + logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + + return False # Don't suppress the original exception diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 0b1af192e..d662efcab 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -73,6 +73,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.robots import ( # noqa: F401 @@ -145,6 +146,9 @@ class DatasetRecordConfig: # Too many threads might cause unstable teleoperation fps due to main thread being blocked. # Not enough threads might cause low camera fps. num_image_writer_threads_per_camera: int = 4 + # Number of episodes to record before batch encoding videos + # Set to 1 for immediate encoding (default behavior), or higher for batched encoding + video_encoding_batch_size: int = 1 def __post_init__(self): if self.single_task is None: @@ -298,6 +302,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: dataset = LeRobotDataset( cfg.dataset.repo_id, root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) if hasattr(robot, "cameras") and len(robot.cameras) > 0: @@ -318,6 +323,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: use_videos=cfg.dataset.video, image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) # Load pretrained policy @@ -329,46 +335,47 @@ def record(cfg: RecordConfig) -> LeRobotDataset: listener, events = init_keyboard_listener() - recorded_episodes = 0 - while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) - record_loop( - robot=robot, - events=events, - fps=cfg.dataset.fps, - teleop=teleop, - policy=policy, - dataset=dataset, - control_time_s=cfg.dataset.episode_time_s, - single_task=cfg.dataset.single_task, - display_data=cfg.display_data, - ) - - # Execute a few seconds without recording to give time to manually reset the environment - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) + with VideoEncodingManager(dataset): + recorded_episodes = 0 + while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) record_loop( robot=robot, events=events, fps=cfg.dataset.fps, teleop=teleop, - control_time_s=cfg.dataset.reset_time_s, + policy=policy, + dataset=dataset, + control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, ) - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Execute a few seconds without recording to give time to manually reset the environment + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + fps=cfg.dataset.fps, + teleop=teleop, + control_time_s=cfg.dataset.reset_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 log_say("Stop recording", cfg.play_sounds, blocking=True) From 89f59b070368f789a8b3aa558d9f0c88796cc324 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sat, 19 Jul 2025 20:09:12 +0200 Subject: [PATCH 023/158] refactor(ci): workflows improvements (#1535) * refactor(ci): consolidate documentation workflows * refactor(ci): improve quality workflow * refactor(ci): edit security workflow * refactor(ci): improve testing workflows * fix(ci): several fixes * chore(ci): renaming + permissions * chore(ci): remove now unused dockerfiles * chore(docs): add license headers to dockerfiles * chore(ci): add cache-binary false to setup-buildx actions * fix(ci): several fixes * dgb(ci): explicit env in the workflow * fix(ci): more explicit env vars for writing * fix(ci): nightly gpu tag --- .github/workflows/build-docker-images.yml | 135 ---------- .github/workflows/build_documentation.yml | 23 -- .github/workflows/build_pr_documentation.yml | 19 -- .github/workflows/documentation-upload-pr.yml | 40 +++ .github/workflows/documentation.yml | 70 +++++ .github/workflows/nightly-tests.yml | 93 ------- .github/workflows/nightly.yml | 160 ++++++++++++ .github/workflows/quality.yml | 64 ++--- .github/workflows/security.yml | 54 ++++ .github/workflows/test-docker-build.yml | 84 ------ .github/workflows/test.yml | 150 ----------- .github/workflows/tests.yml | 243 ++++++++++++++++++ .github/workflows/trufflehog.yml | 35 --- .github/workflows/upload_pr_documentation.yml | 16 -- docker/Dockerfile.internal | 44 +++- docker/Dockerfile.user | 36 ++- docker/lerobot-cpu/Dockerfile | 29 --- docker/lerobot-gpu-dev/Dockerfile | 68 ----- docker/lerobot-gpu/Dockerfile | 24 -- 19 files changed, 654 insertions(+), 733 deletions(-) delete mode 100644 .github/workflows/build-docker-images.yml delete mode 100644 .github/workflows/build_documentation.yml delete mode 100644 .github/workflows/build_pr_documentation.yml create mode 100644 .github/workflows/documentation-upload-pr.yml create mode 100644 .github/workflows/documentation.yml delete mode 100644 .github/workflows/nightly-tests.yml create mode 100644 .github/workflows/nightly.yml create mode 100644 .github/workflows/security.yml delete mode 100644 .github/workflows/test-docker-build.yml delete mode 100644 .github/workflows/test.yml create mode 100644 .github/workflows/tests.yml delete mode 100644 .github/workflows/trufflehog.yml delete mode 100644 .github/workflows/upload_pr_documentation.yml delete mode 100644 docker/lerobot-cpu/Dockerfile delete mode 100644 docker/lerobot-gpu-dev/Dockerfile delete mode 100644 docker/lerobot-gpu/Dockerfile diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml deleted file mode 100644 index 20974b85a..000000000 --- a/.github/workflows/build-docker-images.yml +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Inspired by -# https://github.com/huggingface/peft/blob/main/.github/workflows/build_docker_images.yml -name: Builds - -on: - workflow_dispatch: - workflow_call: - schedule: - - cron: "0 1 * * *" - -permissions: {} - -env: - PYTHON_VERSION: "3.10" - -jobs: - latest-cpu: - name: CPU - runs-on: - group: aws-general-8-plus - steps: - - name: Install Git LFS - run: | - sudo apt-get update - sudo apt-get install git-lfs - git lfs install - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - with: - cache-binary: false - - - name: Check out code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - lfs: true - persist-credentials: false - - - name: Login to DockerHub - uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - - - name: Build and Push CPU - uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 - with: - context: . - file: ./docker/lerobot-cpu/Dockerfile - push: true - tags: huggingface/lerobot-cpu - build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }} - - - latest-cuda: - name: GPU - runs-on: - group: aws-general-8-plus - steps: - - name: Install Git LFS - run: | - sudo apt-get update - sudo apt-get install git-lfs - git lfs install - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - with: - cache-binary: false - - - name: Check out code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - lfs: true - persist-credentials: false - - - name: Login to DockerHub - uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - - - name: Build and Push GPU - uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 - with: - context: . - file: ./docker/lerobot-gpu/Dockerfile - push: true - tags: huggingface/lerobot-gpu - build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }} - - - latest-cuda-dev: - name: GPU Dev - runs-on: - group: aws-general-8-plus - steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - with: - cache-binary: false - - - name: Check out code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - persist-credentials: false - - - name: Login to DockerHub - uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - - - name: Build and Push GPU dev - uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 - with: - context: . - file: ./docker/lerobot-gpu-dev/Dockerfile - push: true - tags: huggingface/lerobot-gpu:dev - build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }} diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml deleted file mode 100644 index 884e2e4b5..000000000 --- a/.github/workflows/build_documentation.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Build documentation - -on: - workflow_dispatch: - push: - paths: - - "docs/**" - branches: - - main - - doc-builder* - - v*-release - - -jobs: - build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers - uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main - with: - commit_sha: ${{ github.sha }} - package: lerobot - additional_args: --not_python_module - secrets: - token: ${{ secrets.HUGGINGFACE_PUSH }} - hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml deleted file mode 100644 index 51bab10d5..000000000 --- a/.github/workflows/build_pr_documentation.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Build PR Documentation - -on: - pull_request: - paths: - - "docs/**" - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main - with: - commit_sha: ${{ github.event.pull_request.head.sha }} - pr_number: ${{ github.event.number }} - package: lerobot - additional_args: --not_python_module diff --git a/.github/workflows/documentation-upload-pr.yml b/.github/workflows/documentation-upload-pr.yml new file mode 100644 index 000000000..22ba11cbb --- /dev/null +++ b/.github/workflows/documentation-upload-pr.yml @@ -0,0 +1,40 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow uploads the documentation preview built for a PR and comments the link on the PR. +name: Documentation PR Upload +permissions: + contents: read + pull-requests: write + +on: + # Triggered by the completion of the main 'Documentation' workflow. + workflow_run: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers + workflows: ["Documentation"] + types: + - completed + +jobs: + # This job uploads a preview of the documentation for a pull request. + upload_and_comment: + name: Upload Preview and Comment + if: > + github.event.workflow_run.event == 'pull_request' && + github.event.workflow_run.conclusion == 'success' + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: lerobot + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml new file mode 100644 index 000000000..96005af3f --- /dev/null +++ b/.github/workflows/documentation.yml @@ -0,0 +1,70 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles building documentation for both main branches and PRs. +name: Documentation + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + # Triggers the workflow on push events to main for the docs folder + push: + branches: + - main + paths: + - "docs/**" + + # Triggers the workflow on pull request events targeting main for the docs folder + pull_request: + branches: + - main + paths: + - "docs/**" + +# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + # This job builds and deploys the official documentation. + build_main_docs: + name: Build Main Docs + if: github.event_name == 'push' || github.event_name == 'workflow_dispatch' + permissions: + contents: read + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: lerobot + additional_args: --not_python_module + secrets: + token: ${{ secrets.HUGGINGFACE_PUSH }} + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + + # This job builds a preview of the documentation for a pull request. + # The result of this job triggers the 'Upload PR Documentation' workflow. + build_pr_docs: + name: Build PR Docs + if: github.event_name == 'pull_request' + permissions: + contents: read + pull-requests: write + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: lerobot + additional_args: --not_python_module diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml deleted file mode 100644 index 728016915..000000000 --- a/.github/workflows/nightly-tests.yml +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Inspired by -# https://github.com/huggingface/peft/blob/main/.github/workflows/nightly.yml -name: Nightly - -on: - workflow_dispatch: - schedule: - - cron: "0 2 * * *" - -permissions: {} - -# env: - # SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }} -jobs: - run_all_tests_cpu: - name: CPU - strategy: - fail-fast: false - runs-on: - group: aws-general-8-plus - container: - image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images] - options: --shm-size "16gb" - credentials: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - defaults: - run: - shell: bash - working-directory: /lerobot - steps: - - name: Tests - run: pytest -v --cov=./src/lerobot --disable-warnings tests - - - name: Tests end-to-end - run: make test-end-to-end - - - run_all_tests_single_gpu: - name: GPU - strategy: - fail-fast: false - runs-on: - group: aws-g6-4xlarge-plus - env: - CUDA_VISIBLE_DEVICES: "0" - TEST_TYPE: "single_gpu" - container: - image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images] - options: --gpus all --shm-size "16gb" - credentials: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} - defaults: - run: - shell: bash - working-directory: /lerobot - steps: - - name: Nvidia-smi - run: nvidia-smi - - - name: Test - run: pytest -v --cov=./src/lerobot --cov-report=xml --disable-warnings tests - # TODO(aliberts): Link with HF Codecov account - # - name: Upload coverage reports to Codecov with GitHub Action - # uses: codecov/codecov-action@v4 - # with: - # files: ./coverage.xml - # verbose: true - - name: Tests end-to-end - env: - DEVICE: cuda - run: make test-end-to-end - - # - name: Generate Report - # if: always() - # run: | - # pip install slack_sdk tabulate - # python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml new file mode 100644 index 000000000..149f55c89 --- /dev/null +++ b/.github/workflows/nightly.yml @@ -0,0 +1,160 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles nightly testing & docker images publishing. +name: Nightly +permissions: + contents: read + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + # Runs at 02:00 + schedule: + - cron: "0 2 * * *" + +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest + DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest + +# Ensures that only the latest commit is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build-docker-cpu-nightly: + # This job builds a CPU image for testing & distribution + name: Build CPU Docker for Nightly + runs-on: + group: aws-general-8-plus + outputs: + image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }} + steps: + - name: Install Git LFS + run: | + sudo apt-get update + sudo apt-get install git-lfs + git lfs install + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses] + with: + cache-binary: false + - name: Login to Docker Hub + uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + - name: Build and push Docker image CPU + uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] + with: + context: . + file: ./docker/Dockerfile.user + push: true + tags: ${{ env.DOCKER_IMAGE_NAME_CPU }} + + build-docker-gpu-nightly: + # This job builds a GPU image for testing & distribution + name: Build GPU Docker for Nightly + runs-on: + group: aws-general-8-plus + outputs: + image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }} + steps: + - name: Install Git LFS + run: | + sudo apt-get update + sudo apt-get install git-lfs + git lfs install + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses] + with: + cache-binary: false + - name: Login to Docker Hub + uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + - name: Build and push Docker image GPU + uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] + with: + context: . + file: ./docker/Dockerfile.internal + push: true + tags: ${{ env.DOCKER_IMAGE_NAME_GPU }} + + nightly-cpu-tests: + # This job runs the E2E tests + pytest with all extras in the CPU image + name: Nightly CPU Tests + needs: [build-docker-cpu-nightly] + runs-on: + group: aws-g6-4xlarge-plus + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + container: + image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Run pytest on CPU + run: pytest tests -vv --maxfail=10 + - name: Run end-to-end tests + run: make test-end-to-end + + nightly-gpu-tests: + # This job runs the E2E tests + pytest with all extras in the GPU image + name: Nightly GPU Tests + needs: [build-docker-gpu-nightly] + runs-on: + group: aws-g6-4xlarge-plus + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + container: + image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --gpus all --shm-size "16gb" + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Run pytest on GPU + run: pytest tests -vv --maxfail=10 + - name: Run end-to-end tests + run: make test-end-to-end diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 1c048c4fe..e9f73ed23 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,61 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. +# This workflow handles linting, formatting, and static analysis checks for the codebase. name: Quality +permissions: + contents: read on: + # Allows running this workflow manually from the Actions tab workflow_dispatch: - workflow_call: - pull_request: + + # Triggers the workflow on push events to main push: branches: - main -permissions: {} + # Triggers the workflow on pull request events targeting main + pull_request: + branches: + - main -env: - PYTHON_VERSION: "3.10" +# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true jobs: - style: - name: Style + # This job runs pre-commit hooks to check code style and formatting. + pre-commit-checks: + name: Run Pre-commit Hooks (Lint, Format & Static Analysis) runs-on: ubuntu-latest steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Checkout code + uses: actions/checkout@v4 with: persist-credentials: false - name: Set up Python - uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1 + uses: actions/setup-python@v5 with: - python-version: ${{ env.PYTHON_VERSION }} + python-version: '3.10' - - name: Get Ruff Version from pre-commit-config.yaml - id: get-ruff-version - run: | - RUFF_VERSION=$(awk '/repo: https:\/\/github.com\/astral-sh\/ruff-pre-commit/{flag=1;next}/rev:/{if(flag){print $2;exit}}' .pre-commit-config.yaml) - echo "ruff_version=${RUFF_VERSION}" >> $GITHUB_OUTPUT - - - name: Install Ruff - env: - RUFF_VERSION: ${{ steps.get-ruff-version.outputs.ruff_version }} - run: python -m pip install "ruff==${RUFF_VERSION}" - - - name: Ruff check - run: ruff check --output-format=github - - - name: Ruff format - run: ruff format --diff - - typos: - name: Typos - runs-on: ubuntu-latest - steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Run pre-commit hooks + uses: pre-commit/action@v3.0.1 # zizmor: ignore[unpinned-uses] with: - persist-credentials: false - - - name: typos-action - uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10 + extra_args: --all-files --show-diff-on-failure --color=always diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml new file mode 100644 index 000000000..04497307b --- /dev/null +++ b/.github/workflows/security.yml @@ -0,0 +1,54 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles secret scanning using TruffleHog to detect sensitive information in the codebase. +name: Security +permissions: + contents: read + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + # Triggers the workflow on push events to main + push: + branches: + - main + + # Triggers the workflow on pull request events targeting main + pull_request: + branches: + - main + +# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + # This job runs TruffleHog to scan the full history of the repository for secrets. + trufflehog: + name: Secret Leaks Scan + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses] + with: + fetch-depth: 0 + persist-credentials: false + + - name: Secret Scanning + uses: trufflesecurity/trufflehog@v3.90.0 # zizmor: ignore[unpinned-uses] + with: + extra_args: --only-verified diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml deleted file mode 100644 index c33813418..000000000 --- a/.github/workflows/test-docker-build.yml +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Inspired by -# https://github.com/huggingface/peft/blob/main/.github/workflows/test-docker-build.yml -name: Test Dockerfiles - -on: - pull_request: - paths: - # Run only when DockerFile files are modified - - "docker/lerobot-cpu/**" - - "docker/lerobot-gpu/**" - - "docker/lerobot-gpu-dev/**" - -permissions: {} - -env: - PYTHON_VERSION: "3.10" - -jobs: - get_changed_files: - name: Detect modified Dockerfiles - runs-on: ubuntu-latest - outputs: - matrix: ${{ steps.set-matrix.outputs.matrix }} - steps: - - name: Check out code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - persist-credentials: false - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42 - with: - files: docker/** - json: "true" - - - name: Run step if only the files listed above change # zizmor: ignore[template-injection] - if: steps.changed-files.outputs.any_changed == 'true' - id: set-matrix - run: | - echo "matrix=${{ steps.changed-files.outputs.all_changed_files}}" >> $GITHUB_OUTPUT - - build_modified_dockerfiles: - name: Build modified Docker images - needs: get_changed_files - runs-on: - group: aws-general-8-plus - if: needs.get_changed_files.outputs.matrix != '' - strategy: - fail-fast: false - matrix: - docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }} - steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - with: - cache-binary: false - - - name: Check out code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - persist-credentials: false - - - name: Build Docker image - uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 - with: - file: ${{ matrix.docker-file }} - context: . - push: False - build-args: PYTHON_VERSION=${{ env.PYTHON_VERSION }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml deleted file mode 100644 index d6ea1d404..000000000 --- a/.github/workflows/test.yml +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -name: Tests - -on: - pull_request: - paths: - - "src/**" - - "tests/**" - - "examples/**" - - ".github/**" - - "pyproject.toml" - - ".pre-commit-config.yaml" - - "Makefile" - - ".cache/**" - push: - branches: - - main - paths: - - "src/**" - - "tests/**" - - "examples/**" - - ".github/**" - - "pyproject.toml" - - ".pre-commit-config.yaml" - - "Makefile" - - ".cache/**" - -permissions: {} - -env: - UV_VERSION: "0.6.0" - -jobs: - pytest: - name: Pytest - runs-on: ubuntu-latest - env: - MUJOCO_GL: egl - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - lfs: true # Ensure LFS files are pulled - persist-credentials: false - - - name: Install apt dependencies - # portaudio19-dev is needed to install pyaudio - run: | - sudo apt-get update && \ - sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev - - - name: Install uv and python - uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 - with: - enable-cache: true - version: ${{ env.UV_VERSION }} - python-version: "3.10" - - - name: Install lerobot (all extras) - run: uv sync --all-extras - - - name: Test with pytest - run: | - uv run pytest tests -v --cov=./src/lerobot --durations=0 \ - -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \ - -W ignore::UserWarning:torch.utils.data.dataloader:558 \ - -W ignore::UserWarning:gymnasium.utils.env_checker:247 \ - && rm -rf tests/outputs outputs - - pytest-minimal: - name: Pytest (minimal install) - runs-on: ubuntu-latest - env: - MUJOCO_GL: egl - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - lfs: true # Ensure LFS files are pulled - persist-credentials: false - - - name: Install apt dependencies - run: sudo apt-get update && sudo apt-get install -y ffmpeg - - - name: Install uv and python - uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 - with: - enable-cache: true - version: ${{ env.UV_VERSION }} - python-version: "3.10" - - - name: Install lerobot - run: uv sync --extra "test" - - - name: Test with pytest - run: | - uv run pytest tests -v --cov=./src/lerobot --durations=0 \ - -W ignore::DeprecationWarning:imageio_ffmpeg._utils:7 \ - -W ignore::UserWarning:torch.utils.data.dataloader:558 \ - -W ignore::UserWarning:gymnasium.utils.env_checker:247 \ - && rm -rf tests/outputs outputs - - end-to-end: - name: End-to-end - runs-on: ubuntu-latest - env: - MUJOCO_GL: egl - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - lfs: true # Ensure LFS files are pulled - persist-credentials: false - - - name: Install apt dependencies - # portaudio19-dev is needed to install pyaudio - run: | - sudo apt-get update && \ - sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev - - - name: Install uv and python - uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 - with: - enable-cache: true - version: ${{ env.UV_VERSION }} - python-version: "3.10" - - - name: Install lerobot (all extras) - run: | - uv venv - uv sync --all-extras - - - name: venv - run: | - echo "PYTHON_PATH=${{ github.workspace }}/.venv/bin/python" >> $GITHUB_ENV - - - name: Test end-to-end - run: | - make test-end-to-end \ - && rm -rf outputs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 000000000..443c849dc --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,243 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles testing. +name: Tests + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + pull_request: + branches: + - main + paths: + - "src/**" + - "tests/**" + - ".github/workflows/**" + - "pyproject.toml" + - "Makefile" + pull_request_review: + types: [submitted] + push: + branches: + - main + paths: + - "src/**" + - "tests/**" + - ".github/workflows/**" + - "pyproject.toml" + - "Makefile" + +permissions: + contents: read + +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + DOCKER_IMAGE_NAME: huggingface/lerobot-gpu + +# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + # This job runs pytests with the default dependencies. + # It runs everytime we commit to a PR or push to main + fast-pytest-tests: + name: Fast Pytest Tests + if: | + github.event_name == 'pull_request' || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + runs-on: ubuntu-latest + env: + MUJOCO_GL: egl + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + lfs: true + + # TODO(Steven): Evaluate the need of these dependencies + - name: Install apt dependencies + run: | + sudo apt-get update && sudo apt-get install -y build-essential git \ + curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ + libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev + + - name: Setup uv and Python + uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] + with: + enable-cache: true + version: ${{ env.UV_VERSION }} + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install lerobot with test extras + run: uv sync --extra "test" + + - name: Run pytest + run: uv run pytest tests -vv --maxfail=10 + + full-tests-gate: + # This job evaluates the need to run the full tests suite. + name: Full Tests Gate + runs-on: ubuntu-latest + if: | + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved') || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + steps: + - name: Gate check + run: echo "Full tests will run." + + full-tests: + # This job runs the E2E tests + pytest with all extras + # It runs everytime a PR is approved or a push to main + name: Full Tests + needs: full-tests-gate + runs-on: ubuntu-latest + env: + MUJOCO_GL: egl + steps: + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + + - name: Install apt dependencies + run: | + sudo apt-get update && sudo apt-get install -y build-essential \ + git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ + speech-dispatcher libgeos-dev portaudio19-dev + + - name: Setup uv and Python + uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] + with: + enable-cache: true + version: ${{ env.UV_VERSION }} + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install lerobot with all extras + run: uv sync --all-extras + + - name: Run pytest (all extras) + run: uv run pytest tests -vv --maxfail=10 + + - name: Run end-to-end tests + run: uv run make test-end-to-end + + build-and-push-docker: + # This job builds a GPU enabled image for testing + # It runs everytime a PR is approved or a push to main + name: Build and Push Docker + needs: full-tests-gate + runs-on: + group: aws-general-8-plus + outputs: + image_tag: ${{ env.DOCKER_IMAGE_NAME }}:pr-${{ github.event.pull_request.number }} + steps: + - name: Install Git LFS + run: | + sudo apt-get update + sudo apt-get install git-lfs + git lfs install + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses] + with: + cache-binary: false + - name: Login to Docker Hub + uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + - name: Build and push Docker image + uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] + with: + context: . + file: ./docker/Dockerfile.internal + push: true + tags: ${{ env.DOCKER_IMAGE_NAME }}:pr-${{ github.event.pull_request.number }} + + gpu-tests: + # This job runs pytest with all extras in a GPU enabled host + # It runs everytime a test image is created + name: GPU Tests + needs: [build-and-push-docker] + runs-on: + group: aws-g6-4xlarge-plus + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + container: + image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --gpus all --shm-size "16gb" + credentials: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Run pytest on GPU + run: pytest tests -vv --maxfail=10 + - name: Run end-to-end tests + run: make test-end-to-end + + delete-pr-image: + # This job deletes the test image recently created + # It runs everytime after the gpu-tests have finished + name: Delete PR Image + needs: [gpu-tests, build-and-push-docker] + if: always() && github.event.review.state == 'approved' && needs.build-and-push-docker.result == 'success' + runs-on: ubuntu-latest + steps: + - name: Get Docker Hub Token and Delete Image + # zizmor: ignore[template-injection] + run: | + IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1) + IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2) + + echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" + + TOKEN=$(curl -s -H "Content-Type: application/json" \ + -X POST \ + -d '{"username": "${{ secrets.DOCKERHUB_USERNAME }}", "password": "${{ secrets.DOCKERHUB_PASSWORD }}"}' \ + https://hub.docker.com/v2/users/login/ | jq -r .token) + + if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then + echo "::error::Failed to get Docker Hub token." + exit 1 + fi + + HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ + -H "Authorization: JWT ${TOKEN}" \ + -X DELETE \ + https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/) + + if [ "$HTTP_RESPONSE" -eq 204 ]; then + echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" + else + echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE" + exit 1 + fi diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml deleted file mode 100644 index 704a3baaa..000000000 --- a/.github/workflows/trufflehog.yml +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -on: - push: - -name: Secret Leaks - -permissions: {} - -jobs: - trufflehog: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - persist-credentials: false - - - name: Secret Scanning - uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35 - with: - extra_args: --only-verified diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml deleted file mode 100644 index 32665930b..000000000 --- a/.github/workflows/upload_pr_documentation.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Upload PR Documentation - -on: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers - workflow_run: - workflows: [ "Build PR Documentation" ] - types: - - completed - -jobs: - build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers - uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main - with: - package_name: lerobot - secrets: - hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} - comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index 051606449..c799a006d 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -1,12 +1,26 @@ -# Dockerfile.internal +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # This Dockerfile is designed for HuggingFace internal CI environments # that require GPU access. It starts from an NVIDIA CUDA base image. # docker build -f docker/Dockerfile.internal -t lerobot-ci . # Configure the base image for CI with GPU access -ARG CUDA_VERSION=12.9.1 -ARG OS_VERSION=24.04 +# TODO(Steven): Bump these versions +ARG CUDA_VERSION=12.4.1 +ARG OS_VERSION=22.04 FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION} # Define Python version argument @@ -14,16 +28,17 @@ ARG PYTHON_VERSION=3.10 # Configure environment variables ENV DEBIAN_FRONTEND=noninteractive \ - MUJOCO_GL="egl" \ - PATH="/lerobot/.venv/bin:$PATH" + MUJOCO_GL=egl \ + PATH=/lerobot/.venv/bin:$PATH \ + CUDA_VISIBLE_DEVICES=0 \ + TEST_TYPE=single_gpu \ + DEVICE=cuda # Install Python, system dependencies, and uv (as root) RUN apt-get update && apt-get install -y --no-install-recommends \ - software-properties-common \ - build-essential git curl \ + software-properties-common build-essential git curl \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ - libusb-1.0-0-dev \ - speech-dispatcher libgeos-dev \ + libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ && add-apt-repository -y ppa:deadsnakes/ppa \ && apt-get update \ && apt-get install -y --no-install-recommends \ @@ -33,6 +48,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && curl -LsSf https://astral.sh/uv/install.sh | sh \ && mv /root/.local/bin/uv /usr/local/bin/uv \ && useradd --create-home --shell /bin/bash user_lerobot \ + && usermod -aG sudo user_lerobot \ && apt-get clean && rm -rf /var/lib/apt/lists/* # Create application directory and set permissions @@ -42,6 +58,13 @@ RUN chown -R user_lerobot:user_lerobot /lerobot # Switch to the non-root user USER user_lerobot +# Environment variables for the testing +ENV HOME=/home/user_lerobot \ + HF_HOME=/home/user_lerobot/.cache/huggingface \ + HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \ + TORCH_HOME=/home/user_lerobot/.cache/torch \ + TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton + # Create the virtual environment # We use a virtual environment inside the container—even though the container itself \ # provides isolation—to ensure compatibility with the cluster and to prevent \ @@ -49,11 +72,12 @@ USER user_lerobot RUN uv venv --python python${PYTHON_VERSION} # Install Python dependencies for caching -COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md ./ +COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ RUN uv pip install --no-cache ".[all]" # Copy the rest of the application source code +# Make sure to have the git-LFS files for testing COPY --chown=user_lerobot:user_lerobot . . # Set the default command diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index ce63f5530..4cfbb437a 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -1,4 +1,17 @@ -# Dockerfile.user +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # This Dockerfile is designed for a lerobot user who wants to # experiment with the project. It starts from an Python Slim base image. @@ -11,18 +24,17 @@ FROM python:${PYTHON_VERSION}-slim # Configure environment variables ENV DEBIAN_FRONTEND=noninteractive \ - MUJOCO_GL="egl" \ - PATH="/lerobot/.venv/bin:$PATH" + MUJOCO_GL=egl \ + PATH=/lerobot/.venv/bin:$PATH # Install system dependencies and uv (as root) RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential git curl \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ - libusb-1.0-0-dev \ - speech-dispatcher libgeos-dev \ + build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \ + libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ && curl -LsSf https://astral.sh/uv/install.sh | sh \ && mv /root/.local/bin/uv /usr/local/bin/uv \ && useradd --create-home --shell /bin/bash user_lerobot \ + && usermod -aG sudo user_lerobot \ && apt-get clean && rm -rf /var/lib/apt/lists/* # Create application directory and set permissions @@ -32,6 +44,13 @@ RUN chown -R user_lerobot:user_lerobot /lerobot # Switch to the non-root user USER user_lerobot +# Environment variables for the testing +ENV HOME=/home/user_lerobot \ + HF_HOME=/home/user_lerobot/.cache/huggingface \ + HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \ + TORCH_HOME=/home/user_lerobot/.cache/torch \ + TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton + # Create the virtual environment # We use a virtual environment inside the container—even though the container itself \ # provides isolation—to closely resemble local development and allow users to \ @@ -39,11 +58,12 @@ USER user_lerobot RUN uv venv # Install Python dependencies for caching -COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md ./ +COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ RUN uv pip install --no-cache ".[all]" # Copy the rest of the application code +# Make sure to have the git-LFS files for testing COPY --chown=user_lerobot:user_lerobot . . # Set the default command diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile deleted file mode 100644 index 85c31ac1a..000000000 --- a/docker/lerobot-cpu/Dockerfile +++ /dev/null @@ -1,29 +0,0 @@ -# Configure image -ARG PYTHON_VERSION=3.10 -FROM python:${PYTHON_VERSION}-slim - -# Configure environment variables -ARG PYTHON_VERSION -ENV DEBIAN_FRONTEND=noninteractive -ENV MUJOCO_GL="egl" -ENV PATH="/opt/venv/bin:$PATH" - -# Install dependencies and set up Python in a single layer -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential cmake git \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ - speech-dispatcher libgeos-dev \ - && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \ - && python -m venv /opt/venv \ - && apt-get clean && rm -rf /var/lib/apt/lists/* \ - && echo "source /opt/venv/bin/activate" >> /root/.bashrc - -# Clone repository and install LeRobot in a single layer -COPY . /lerobot -WORKDIR /lerobot -RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ - && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, smolvla]" \ - --extra-index-url https://download.pytorch.org/whl/cpu - -# Execute in bash shell rather than python -CMD ["/bin/bash"] diff --git a/docker/lerobot-gpu-dev/Dockerfile b/docker/lerobot-gpu-dev/Dockerfile deleted file mode 100644 index 4d25b2550..000000000 --- a/docker/lerobot-gpu-dev/Dockerfile +++ /dev/null @@ -1,68 +0,0 @@ -FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 - -# Configure image -ARG PYTHON_VERSION=3.10 -ARG DEBIAN_FRONTEND=noninteractive - -# Install apt dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential cmake \ - git git-lfs openssh-client \ - nano vim less util-linux tree \ - htop atop nvtop \ - sed gawk grep curl wget zip unzip \ - tcpdump sysstat screen tmux \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa \ - speech-dispatcher portaudio19-dev libgeos-dev \ - python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \ - && apt-get clean && rm -rf /var/lib/apt/lists/* - -# Install ffmpeg build dependencies. See: -# https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu -# TODO(aliberts): create image to build dependencies from source instead -RUN apt-get update && apt-get install -y --no-install-recommends \ - autoconf automake yasm \ - libass-dev \ - libfreetype6-dev \ - libgnutls28-dev \ - libunistring-dev \ - libmp3lame-dev \ - libtool \ - libvorbis-dev \ - meson \ - ninja-build \ - pkg-config \ - texinfo \ - yasm \ - zlib1g-dev \ - nasm \ - libx264-dev \ - libx265-dev libnuma-dev \ - libvpx-dev \ - libfdk-aac-dev \ - libopus-dev \ - libsvtav1-dev libsvtav1enc-dev libsvtav1dec-dev \ - libdav1d-dev - -# Install gh cli tool -RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \ - && mkdir -p -m 755 /etc/apt/keyrings \ - && wget -qO- https://cli.github.com/packages/githubcli-archive-keyring.gpg | tee /etc/apt/keyrings/githubcli-archive-keyring.gpg > /dev/null \ - && chmod go+r /etc/apt/keyrings/githubcli-archive-keyring.gpg \ - && echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \ - && apt update \ - && apt install gh -y \ - && apt clean && rm -rf /var/lib/apt/lists/* - -# Setup `python` -RUN ln -s /usr/bin/python3 /usr/bin/python - -# Install poetry -RUN curl -sSL https://install.python-poetry.org | python - -ENV PATH="/root/.local/bin:$PATH" -RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc -RUN poetry config virtualenvs.create false -RUN poetry config virtualenvs.in-project true - -# Set EGL as the rendering backend for MuJoCo -ENV MUJOCO_GL="egl" diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile deleted file mode 100644 index 746ea29b7..000000000 --- a/docker/lerobot-gpu/Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM nvidia/cuda:12.4.1-base-ubuntu22.04 - -# Configure environment variables -ARG PYTHON_VERSION=3.10 -ENV DEBIAN_FRONTEND=noninteractive -ENV MUJOCO_GL="egl" -ENV PATH="/opt/venv/bin:$PATH" - -# Install dependencies and set up Python in a single layer -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential cmake git \ - libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ - speech-dispatcher libgeos-dev \ - python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && ln -s /usr/bin/python${PYTHON_VERSION} /usr/bin/python \ - && python -m venv /opt/venv \ - && apt-get clean && rm -rf /var/lib/apt/lists/* \ - && echo "source /opt/venv/bin/activate" >> /root/.bashrc - -# Clone repository and install LeRobot in a single layer -COPY . /lerobot -WORKDIR /lerobot -RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ - && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel, smolvla]" From 9229f21b23085abbf64cd695ffaebc392fabb099 Mon Sep 17 00:00:00 2001 From: Jakob Frick Date: Sun, 20 Jul 2025 01:33:51 -0700 Subject: [PATCH 024/158] Advise placement of cable during assembly, clarify USB instructions (#1545) * Update so101.mdx Signed-off-by: Jakob Frick * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update so101.mdx Signed-off-by: Jakob Frick --------- Signed-off-by: Jakob Frick Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/lerobot/robots/so101_follower/so101.mdx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx index e84336e17..a20a3fa9f 100644 --- a/src/lerobot/robots/so101_follower/so101.mdx +++ b/src/lerobot/robots/so101_follower/so101.mdx @@ -34,6 +34,8 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. +It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly. + ### Joint 1 - Place the first motor into the base. @@ -157,7 +159,7 @@ Remove all support material from the 3D-printed parts. The easiest way to do thi ### 1. Find the USB ports associated with each arm -To find the port for each bus servo adapter, run this script: +To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted: ```bash python -m lerobot.find_port From e88b30e6ccfb142ccb9ccf0dbc406f92b2b8f00c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 20 Jul 2025 23:09:35 +0200 Subject: [PATCH 025/158] fix(ci): multiple fixes (#1549) * fix(ci): tag of image when pushing to main * fix(docs): remove symlink in docs folder * chore(docs): move .mdx files to docs/ folder * chore(docs): create symlink to docs files * chore(ci): de-couple fast and full test pipeline * fix(ci): skip GPU Tests for community PRs --- .github/workflows/fast_tests.yml | 87 +++ .../workflows/{tests.yml => full_tests.yml} | 107 +-- .github/workflows/nightly.yml | 8 +- docs/source/hope_jr.mdx | 278 +++++++- docs/source/koch.mdx | 284 +++++++- docs/source/lekiwi.mdx | 338 ++++++++- docs/source/so100.mdx | 641 +++++++++++++++++- docs/source/so101.mdx | 437 +++++++++++- src/lerobot/robots/hope_jr/hope_jr.mdx | 278 +------- src/lerobot/robots/koch_follower/koch.mdx | 284 +------- src/lerobot/robots/lekiwi/lekiwi.mdx | 338 +-------- src/lerobot/robots/so100_follower/so100.mdx | 641 +----------------- src/lerobot/robots/so101_follower/so101.mdx | 437 +----------- 13 files changed, 2105 insertions(+), 2053 deletions(-) create mode 100644 .github/workflows/fast_tests.yml rename .github/workflows/{tests.yml => full_tests.yml} (70%) mode change 120000 => 100644 docs/source/hope_jr.mdx mode change 120000 => 100644 docs/source/koch.mdx mode change 120000 => 100644 docs/source/lekiwi.mdx mode change 120000 => 100644 docs/source/so100.mdx mode change 120000 => 100644 docs/source/so101.mdx mode change 100644 => 120000 src/lerobot/robots/hope_jr/hope_jr.mdx mode change 100644 => 120000 src/lerobot/robots/koch_follower/koch.mdx mode change 100644 => 120000 src/lerobot/robots/lekiwi/lekiwi.mdx mode change 100644 => 120000 src/lerobot/robots/so100_follower/so100.mdx mode change 100644 => 120000 src/lerobot/robots/so101_follower/so101.mdx diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml new file mode 100644 index 000000000..ad4938970 --- /dev/null +++ b/.github/workflows/fast_tests.yml @@ -0,0 +1,87 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles fast testing. +name: Fast Tests + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + pull_request: + branches: + - main + paths: + - "src/**" + - "tests/**" + - ".github/workflows/**" + - "pyproject.toml" + - "Makefile" + push: + branches: + - main + paths: + - "src/**" + - "tests/**" + - ".github/workflows/**" + - "pyproject.toml" + - "Makefile" + +permissions: + contents: read + +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + DOCKER_IMAGE_NAME: huggingface/lerobot-gpu + +# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + # This job runs pytests with the default dependencies. + # It runs everytime we commit to a PR or push to main + fast-pytest-tests: + name: Fast Pytest Tests + runs-on: ubuntu-latest + env: + MUJOCO_GL: egl + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + lfs: true + + # TODO(Steven): Evaluate the need of these dependencies + - name: Install apt dependencies + run: | + sudo apt-get update && sudo apt-get install -y build-essential git \ + curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ + libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev + + - name: Setup uv and Python + uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] + with: + enable-cache: true + version: ${{ env.UV_VERSION }} + python-version: ${{ env.PYTHON_VERSION }} + + - name: Install lerobot with test extras + run: uv sync --extra "test" + + - name: Run pytest + run: uv run pytest tests -vv --maxfail=10 diff --git a/.github/workflows/tests.yml b/.github/workflows/full_tests.yml similarity index 70% rename from .github/workflows/tests.yml rename to .github/workflows/full_tests.yml index 443c849dc..f044dc484 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/full_tests.yml @@ -12,22 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This workflow handles testing. -name: Tests +# This workflow handles full testing. +name: Full Tests on: # Allows running this workflow manually from the Actions tab workflow_dispatch: - pull_request: - branches: - - main - paths: - - "src/**" - - "tests/**" - - ".github/workflows/**" - - "pyproject.toml" - - "Makefile" pull_request_review: types: [submitted] push: @@ -49,67 +40,22 @@ env: PYTHON_VERSION: "3.10" DOCKER_IMAGE_NAME: huggingface/lerobot-gpu -# Ensures that only the latest commit for a PR or branch is built, canceling older runs. +# Ensures that only the latest action is built, canceling older runs. concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: - # This job runs pytests with the default dependencies. - # It runs everytime we commit to a PR or push to main - fast-pytest-tests: - name: Fast Pytest Tests - if: | - github.event_name == 'pull_request' || - github.event_name == 'push' || - github.event_name == 'workflow_dispatch' - runs-on: ubuntu-latest - env: - MUJOCO_GL: egl - steps: - - uses: actions/checkout@v4 - with: - persist-credentials: false - lfs: true - # TODO(Steven): Evaluate the need of these dependencies - - name: Install apt dependencies - run: | - sudo apt-get update && sudo apt-get install -y build-essential git \ - curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ - libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev - - - name: Setup uv and Python - uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] - with: - enable-cache: true - version: ${{ env.UV_VERSION }} - python-version: ${{ env.PYTHON_VERSION }} - - - name: Install lerobot with test extras - run: uv sync --extra "test" - - - name: Run pytest - run: uv run pytest tests -vv --maxfail=10 - - full-tests-gate: - # This job evaluates the need to run the full tests suite. - name: Full Tests Gate + # This job runs the E2E tests + pytest with all extras + # It runs everytime a PR is approved or a push to main + full-tests: + name: Full Tests runs-on: ubuntu-latest if: | (github.event_name == 'pull_request_review' && github.event.review.state == 'approved') || github.event_name == 'push' || github.event_name == 'workflow_dispatch' - steps: - - name: Gate check - run: echo "Full tests will run." - - full-tests: - # This job runs the E2E tests + pytest with all extras - # It runs everytime a PR is approved or a push to main - name: Full Tests - needs: full-tests-gate - runs-on: ubuntu-latest env: MUJOCO_GL: egl steps: @@ -140,16 +86,35 @@ jobs: - name: Run end-to-end tests run: uv run make test-end-to-end + # This job builds a GPU enabled image for testing + # It runs everytime a PR is approved or a push to main + # TODO(Steven): For now we skip this job for community PRs build-and-push-docker: - # This job builds a GPU enabled image for testing - # It runs everytime a PR is approved or a push to main name: Build and Push Docker - needs: full-tests-gate runs-on: group: aws-general-8-plus + if: | + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' outputs: - image_tag: ${{ env.DOCKER_IMAGE_NAME }}:pr-${{ github.event.pull_request.number }} + image_tag: ${{ steps.set_tag.outputs.image_tag }} + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_REF: ${{ github.ref }} + GITHUB_PR_NUMBER: ${{ github.event.pull_request.number }} steps: + - name: Set Docker image tag + id: set_tag + run: | + if [[ "${GITHUB_EVENT_NAME}" == "push" ]]; then + TAG="${DOCKER_IMAGE_NAME}:latest" + elif [[ -n "${GITHUB_PR_NUMBER}" ]]; then + TAG="${DOCKER_IMAGE_NAME}:pr-${GITHUB_PR_NUMBER}" + else + TAG="${DOCKER_IMAGE_NAME}:pr-${GITHUB_REF##*/}" + fi + echo "image_tag=$TAG" >> $GITHUB_OUTPUT - name: Install Git LFS run: | sudo apt-get update @@ -174,11 +139,11 @@ jobs: context: . file: ./docker/Dockerfile.internal push: true - tags: ${{ env.DOCKER_IMAGE_NAME }}:pr-${{ github.event.pull_request.number }} + tags: ${{ steps.set_tag.outputs.image_tag }} + # This job runs pytest with all extras in a GPU enabled host + # It runs everytime a test image is created gpu-tests: - # This job runs pytest with all extras in a GPU enabled host - # It runs everytime a test image is created name: GPU Tests needs: [build-and-push-docker] runs-on: @@ -204,12 +169,12 @@ jobs: - name: Run end-to-end tests run: make test-end-to-end + # This job deletes the test image recently created + # It runs everytime after the gpu-tests have finished delete-pr-image: - # This job deletes the test image recently created - # It runs everytime after the gpu-tests have finished name: Delete PR Image needs: [gpu-tests, build-and-push-docker] - if: always() && github.event.review.state == 'approved' && needs.build-and-push-docker.result == 'success' + if: always() && ((github.event.review.state == 'approved') || (github.event_name == 'workflow_dispatch')) && needs.build-and-push-docker.result == 'success' runs-on: ubuntu-latest steps: - name: Get Docker Hub Token and Delete Image diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 149f55c89..66755d9df 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -38,8 +38,8 @@ concurrency: cancel-in-progress: true jobs: + # This job builds a CPU image for testing & distribution build-docker-cpu-nightly: - # This job builds a CPU image for testing & distribution name: Build CPU Docker for Nightly runs-on: group: aws-general-8-plus @@ -72,8 +72,8 @@ jobs: push: true tags: ${{ env.DOCKER_IMAGE_NAME_CPU }} + # This job builds a GPU image for testing & distribution build-docker-gpu-nightly: - # This job builds a GPU image for testing & distribution name: Build GPU Docker for Nightly runs-on: group: aws-general-8-plus @@ -106,8 +106,8 @@ jobs: push: true tags: ${{ env.DOCKER_IMAGE_NAME_GPU }} + # This job runs the E2E tests + pytest with all extras in the CPU image nightly-cpu-tests: - # This job runs the E2E tests + pytest with all extras in the CPU image name: Nightly CPU Tests needs: [build-docker-cpu-nightly] runs-on: @@ -132,8 +132,8 @@ jobs: - name: Run end-to-end tests run: make test-end-to-end + # This job runs the E2E tests + pytest with all extras in the GPU image nightly-gpu-tests: - # This job runs the E2E tests + pytest with all extras in the GPU image name: Nightly GPU Tests needs: [build-docker-gpu-nightly] runs-on: diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx deleted file mode 120000 index 402422634..000000000 --- a/docs/source/hope_jr.mdx +++ /dev/null @@ -1 +0,0 @@ -../../src/lerobot/robots/hope_jr/hope_jr.mdx \ No newline at end of file diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx new file mode 100644 index 000000000..72aa8f923 --- /dev/null +++ b/docs/source/hope_jr.mdx @@ -0,0 +1,277 @@ +# HopeJR + +## Prerequisites + +- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr) + +## Install LeRobot + +Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot. + +Install LeRobot with HopeJR dependencies: + +```bash +pip install -e ".[hopejr]" +``` + +## Device Configuration + +Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton: + +```bash +python -m lerobot.find_port +``` + +This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts. + +## Step 1: Calibration + +Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration` + +### 1.1 Calibrate Robot Hand + +```bash +python -m lerobot.calibrate \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=blue \ + --robot.side=right +``` + +When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows: + +**Thumb**: + +- **CMC**: base joint connecting thumb to hand +- **MCP**: knuckle joint +- **PIP**: first finger joint +- **DIP** : fingertip joint + +**Index, Middle, Ring, and Pinky fingers**: + +- **Radial flexor**: Moves base of finger towards the thumb +- **Ulnar flexor**: Moves base of finger towards the pinky +- **PIP/DIP**: Flexes the distal and proximal phalanx of the finger + +Each one of these will need to be calibrated individually via the GUI. +Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement. + +

+ Setting boundaries in the hand calibration GUI +

+ +Use the calibration interface to set the range boundaries for each joint as shown above. + +

+ Saving calibration values +

+ +Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors. + +### 1.2 Calibrate Teleoperator Glove + +```bash +python -m lerobot.calibrate \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=red \ + --teleop.side=right +``` + +Move each finger through its full range of motion, starting from the thumb. + +``` +Move thumb through its entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +NAME | MIN | POS | MAX +thumb_cmc | 1790 | 1831 | 1853 +thumb_mcp | 1497 | 1514 | 1528 +thumb_pip | 1466 | 1496 | 1515 +thumb_dip | 1463 | 1484 | 1514 +``` + +Continue with each finger: + +``` +Move middle through its entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +NAME | MIN | POS | MAX +middle_mcp_abduction | 1598 | 1718 | 1820 +middle_mcp_flexion | 1512 | 1658 | 2136 +middle_dip | 1484 | 1500 | 1547 +``` + +Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json` + +### 1.3 Calibrate Robot Arm + +```bash +python -m lerobot.calibrate \ + --robot.type=hope_jr_arm \ + --robot.port=/dev/tty.usbserial-1110 \ + --robot.id=white +``` + +This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows: + +- **Shoulder**: pitch, yaw, and roll +- **Elbow**: flex +- **Wrist**: pitch, yaw, and roll + +

+ Setting boundaries in the arm calibration GUI +

+ +Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration. + +### 1.4 Calibrate Teleoperator Exoskeleton + +```bash +python -m lerobot.calibrate \ + --teleop.type=homunculus_arm \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=black +``` + +The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion: + +``` +Move all joints through their entire range of motion. +Recording positions. Press ENTER to stop... + +------------------------------------------- +------------------------------------------- +NAME | MIN | POS | MAX +shoulder_pitch | 586 | 736 | 895 +shoulder_yaw | 1257 | 1374 | 1390 +shoulder_roll | 449 | 1034 | 2564 +elbow_flex | 3023 | 3117 | 3134 +wrist_roll | 3073 | 3096 | 3147 +wrist_yaw | 2143 | 2171 | 2185 +wrist_pitch | 1975 | 1993 | 2074 +Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json +``` + +## Step 2: Teleoperation + +Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions: + +### Hand + +```bash +python -m lerobot.teleoperate \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=blue \ + --robot.side=right \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=red \ + --teleop.side=right \ + --display_data=true \ + --fps=30 +``` + +### Arm + +```bash +python -m lerobot.teleoperate \ + --robot.type=hope_jr_arm \ + --robot.port=/dev/tty.usbserial-1110 \ + --robot.id=white \ + --teleop.type=homunculus_arm \ + --teleop.port=/dev/tty.usbmodem11201 \ + --teleop.id=black \ + --display_data=true \ + --fps=30 +``` + +## Step 3: Record, Replay, Train + +Record, Replay and Train with Hope-JR is still experimental. + +### Record + +This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings). + +```bash +python -m lerobot.record \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=homunculus_glove \ + --teleop.port=/dev/tty.usbmodem1201 \ + --teleop.id=right \ + --teleop.side=right \ + --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.single_task="Hand recording test with video data" \ + --dataset.num_episodes=1 \ + --dataset.episode_time_s=5 \ + --dataset.push_to_hub=true \ + --dataset.private=true \ + --display_data=true +``` + +### Replay + +```bash +python -m lerobot.replay \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --dataset.repo_id=nepyope/hand_record_test_with_camera \ + --dataset.episode=0 +``` + +### Train + +```bash +python -m lerobot.scripts.train \ + --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --policy.type=act \ + --output_dir=outputs/train/hopejr_hand \ + --job_name=hopejr \ + --policy.device=mps \ + --wandb.enable=true \ + --policy.repo_id=nepyope/hand_test_policy +``` + +### Evaluate + +This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino). + +```bash +python -m lerobot.record \ + --robot.type=hope_jr_hand \ + --robot.port=/dev/tty.usbmodem58760432281 \ + --robot.id=right \ + --robot.side=right \ + --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ + --display_data=false \ + --dataset.repo_id=nepyope/eval_hopejr \ + --dataset.single_task="Evaluate hopejr hand policy" \ + --dataset.num_episodes=10 \ + --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model +``` diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx deleted file mode 120000 index 5383518b3..000000000 --- a/docs/source/koch.mdx +++ /dev/null @@ -1 +0,0 @@ -../../src/lerobot/robots/koch_follower/koch.mdx \ No newline at end of file diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx new file mode 100644 index 000000000..d0b991e74 --- /dev/null +++ b/docs/source/koch.mdx @@ -0,0 +1,283 @@ +# Koch v1.1 + +In the steps below, we explain how to assemble the Koch v1.1 robot. + +## Order and assemble the parts + +Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below. + +For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk). + +> [!WARNING] +> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video: +> +> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors). +> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors). + +## Install LeRobot 🤗 + +To install LeRobot follow, our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Dynamixel SDK: + +```bash +pip install -e ".[dynamixel]" +``` + +## Configure the motors + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, run this script: + +```bash +python -m lerobot.find_port +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: + +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match. + +#### Follower + +Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + +For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=koch_follower \ + --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + + + +```python +from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig + +config = KochFollowerConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_follower_arm", +) +follower = KochFollower(config) +follower.setup_motors() +``` + + + + + +You should see the following instruction. + +``` +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + +If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). + +
+ +You should then see the following message: + +``` +'gripper' motor id set to 6 +``` + +Followed by the next instruction: + +``` +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader + +Do the same steps for the leader arm but modify the command or script accordingly. + + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=koch_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step +``` + + + + + +```python +from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig + +config = KochLeaderConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_leader_arm", +) +leader = KochLeader(config) +leader.setup_motors() +``` + + + + + +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=koch_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower + +config = KochFollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = KochFollower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + + + +We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). + +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=koch_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader + +config = KochLeaderConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_leader_arm", +) + +leader = KochLeader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx deleted file mode 120000 index afc43077e..000000000 --- a/docs/source/lekiwi.mdx +++ /dev/null @@ -1 +0,0 @@ -../../src/lerobot/robots/lekiwi/lekiwi.mdx \ No newline at end of file diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx new file mode 100644 index 000000000..bb70fd26b --- /dev/null +++ b/docs/source/lekiwi.mdx @@ -0,0 +1,337 @@ +# LeKiwi + +In the steps below, we explain how to assemble the LeKiwi mobile robot. + +## Source the parts + +Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. +And advise if it's your first time printing or if you don't own a 3D printer. + +### Wired version + +If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating. + +## Install software on Pi + +Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board. + +### Install OS + +For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu. + +### Setup SSH + +After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension. + +### Install LeRobot on Pi 🤗 + +On your Raspberry Pi install LeRobot using our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your Pi: + +```bash +pip install -e ".[lekiwi]" +``` + +## Install LeRobot locally + +If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi. + +Follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your laptop/pc: + +```bash +pip install -e ".[lekiwi]" +``` + +Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:. +Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. + +# Step-by-Step Assembly Instructions + +First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages: + +- [Assemble SO101](./so101#step-by-step-assembly-instructions) +- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md) + +### Find the USB ports associated with motor board + +To find the port for each bus servo adapter, run this script: + +```bash +python -m lerobot.find_port +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board. + + + + +On Linux, you might need to give access to the USB ports by running: + +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM0 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM0` corresponding to your board. + + + + +### Configure motors + +The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9. + +You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) + +```bash +python -m lerobot.setup_motors \ + --robot.type=lekiwi \ + --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step +``` + +Motor ID's for mobile robot + +### Troubleshoot communication + +If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue. + +#### 1. Verify IP Address Configuration + +Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line): + +```bash +hostname -I +``` + +#### 2. Check if Pi is reachable from laptop/pc + +Try pinging the Raspberry Pi from your laptop: + +```bach +ping +``` + +If the ping fails: + +- Ensure the Pi is powered on and connected to the same network. +- Check if SSH is enabled on the Pi. + +#### 3. Try SSH connection + +If you can't SSH into the Pi, it might not be properly connected. Use: + +```bash +ssh @ +``` + +If you get a connection error: + +- Ensure SSH is enabled on the Pi by running: + ```bash + sudo raspi-config + ``` + Then navigate to: **Interfacing Options -> SSH** and enable it. + +### Calibration + +Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +### Calibrate follower arm (on mobile base) + +Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm: + +```bash +python -m lerobot.calibrate \ + --robot.type=lekiwi \ + --robot.id=my_awesome_kiwi # <- Give the robot a unique name +``` + +We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). + +### Wired version + +If you have the **wired** LeKiwi version, please run all commands on your laptop. + +### Calibrate leader arm + +Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO100Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + + + +## Teleoperate LeKiwi + +> [!TIP] +> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal. + +To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command: + +```bash +python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi +``` + +Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`. + +```bash +python examples/lekiwi/teleoperate.py +``` + +You should see on your laptop something like this: `[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: + +| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | +| ---------- | ------------------ | ---------------------- | +| Fast | 0.4 | 90 | +| Medium | 0.25 | 60 | +| Slow | 0.1 | 30 | + +| Key | Action | +| --- | -------------- | +| W | Move forward | +| A | Move left | +| S | Move backward | +| D | Move right | +| Z | Turn left | +| X | Turn right | +| R | Increase speed | +| F | Decrease speed | + +> [!TIP] +> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). + +### Wired version + +If you have the **wired** LeKiwi version, please run all commands on your laptop. + +## Record a dataset + +Once you're familiar with teleoperation, you can record your first dataset. + +We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). + +Add your token to the CLI by running this command: + +```bash +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +``` + +Then store your Hugging Face repository name in a variable: + +```bash +HF_USER=$(huggingface-cli whoami | head -n 1) +echo $HF_USER +``` + +Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`. + +```bash +python examples/lekiwi/record.py +``` + +#### Dataset upload + +Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: + +```bash +echo https://huggingface.co/datasets/${HF_USER}/so101_test +``` + +Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). + +You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). + +#### Tips for gathering data + +Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images. + +In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions. + +Avoid adding too much variation too quickly, as it may hinder your results. + +If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. + +#### Troubleshooting: + +- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). + +## Replay an episode + +To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index. + +```bash +python examples/lekiwi/replay.py +``` + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +## Evaluate your policy + +To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model.. + +```bash +python examples/lekiwi/evaluate.py +``` + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx deleted file mode 120000 index 0a71dc307..000000000 --- a/docs/source/so100.mdx +++ /dev/null @@ -1 +0,0 @@ -../../src/lerobot/robots/so100_follower/so100.mdx \ No newline at end of file diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx new file mode 100644 index 000000000..d9ff922c5 --- /dev/null +++ b/docs/source/so100.mdx @@ -0,0 +1,640 @@ +# SO-100 + +In the steps below, we explain how to assemble the SO-100 robot. + +## Source the parts + +Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100.md). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. And advise if it's your first time printing or if you don't own a 3D printer. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK: + +```bash +pip install -e ".[feetech]" +``` + +## Configure the motors + +**Note:** +Unlike the SO-101, the motor connectors are not easily accessible once the arm is assembled, so the configuration step must be done beforehand. + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, run this script: + +```bash +python -m lerobot.find_port +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: + +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. + +#### Follower + +Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + +For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step +``` + + + + + +```python +from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig + +config = SO100FollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_follower_arm", +) +follower = SO100Follower(config) +follower.setup_motors() +``` + + + + + +You should see the following instruction + +``` +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + +If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). + +
+ +You should then see the following message: + +``` +'gripper' motor id set to 6 +``` + +Followed by the next instruction: + +``` +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader + +Do the same steps for the leader arm. + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + + +```python +from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_leader_arm", +) +leader = SO100Leader(config) +leader.setup_motors() +``` + + + + + +## Step-by-Step Assembly Instructions + +## Remove the gears of the 6 leader motors + +
+Video removing gears + +
+ +
+ +
+ +Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. + +### Clean Parts + +Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. + +### Additional Guidance + +
+Video assembling arms + +
+ +
+ +
+ +**Note:** +This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour. + +--- + +### First Motor + +**Step 2: Insert Wires** + +- Insert two wires into the first motor. + + + +**Step 3: Install in Base** + +- Place the first motor into the base. + + + +**Step 4: Secure Motor** + +- Fasten the motor with 4 screws. Two from the bottom and two from top. + +**Step 5: Attach Motor Holder** + +- Slide over the first motor holder and fasten it using two screws (one on each side). + + + +**Step 6: Attach Motor Horns** + +- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears. + + + +
+ + Video adding motor horn + + +
+ +**Step 7: Attach Shoulder Part** + +- Route one wire to the back of the robot and the other to the left or towards you (see photo). +- Attach the shoulder part. + + + +**Step 8: Secure Shoulder** + +- Tighten the shoulder part with 4 screws on top and 4 on the bottom + _(access bottom holes by turning the shoulder)._ + +--- + +### Second Motor Assembly + +**Step 9: Install Motor 2** + +- Slide the second motor in from the top and link the wire from motor 1 to motor 2. + + + +**Step 10: Attach Shoulder Holder** + +- Add the shoulder motor holder. +- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo). +- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor. + +
+ + + +
+ +**Step 11: Secure Motor 2** + +- Fasten the second motor with 4 screws. + +**Step 12: Attach Motor Horn** + +- Attach both motor horns to motor 2, again use the horn screw. + +**Step 13: Attach Base** + +- Install the base attachment using 2 screws. + + + +**Step 14: Attach Upper Arm** + +- Attach the upper arm with 4 screws on each side. + + + +--- + +### Third Motor Assembly + +**Step 15: Install Motor 3** + +- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws. + +**Step 16: Attach Motor Horn** + +- Attach both motor horns to motor 3 and secure one again with a horn screw. + + + +**Step 17: Attach Forearm** + +- Connect the forearm to motor 3 using 4 screws on each side. + + + +--- + +### Fourth Motor Assembly + +**Step 18: Install Motor 4** + +- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw. + +
+ + +
+ +**Step 19: Attach Motor Holder 4** + +- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo). + + + +**Step 20: Secure Motor 4 & Attach Horn** + +- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw. + + + +--- + +### Wrist Assembly + +**Step 21: Install Motor 5** + +- Insert motor 5 into the wrist holder and secure it with 2 front screws. + + + +**Step 22: Attach Wrist** + +- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper. +- Secure the wrist to motor 4 using 4 screws on both sides. + + + +**Step 23: Attach Wrist Horn** + +- Install only one motor horn on the wrist motor and secure it with a horn screw. + + + +--- + +### Follower Configuration + +**Step 24: Attach Gripper** + +- Attach the gripper to motor 5. + + + +**Step 25: Install Gripper Motor** + +- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side. + + + +**Step 26: Attach Gripper Horn & Claw** + +- Attach the motor horns and again use a horn screw. +- Install the gripper claw and secure it with 4 screws on both sides. + + + +**Step 27: Mount Controller** + +- Attach the motor controller to the back of the robot. + +
+ + +
+ +_Assembly complete – proceed to Leader arm assembly._ + +--- + +### Leader Configuration + +For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors. + +**Step 24: Attach Leader Holder** + +- Mount the leader holder onto the wrist and secure it with a screw. + + + +**Step 25: Attach Handle** + +- Attach the handle to motor 5 using 4 screws. + + + +**Step 26: Install Gripper Motor** + +- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire. + + + +**Step 27: Attach Trigger** + +- Attach the follower trigger with 4 screws. + + + +**Step 28: Mount Controller** + +- Attach the motor controller to the back of the robot. + +
+ + +
+ +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower + +config = SO100FollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = SO100Follower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + + + +We unified the calibration method for most robots. Thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video) + +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO100Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx deleted file mode 120000 index ab6d0ac61..000000000 --- a/docs/source/so101.mdx +++ /dev/null @@ -1 +0,0 @@ -../../src/lerobot/robots/so101_follower/so101.mdx \ No newline at end of file diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx new file mode 100644 index 000000000..a20a3fa9f --- /dev/null +++ b/docs/source/so101.mdx @@ -0,0 +1,436 @@ +# SO-101 + +In the steps below, we explain how to assemble our flagship robot, the SO-101. + +## Source the parts + +Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. +And advise if it's your first time printing or if you don't own a 3D printer. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK: + +```bash +pip install -e ".[feetech]" +``` + +## Step-by-Step Assembly Instructions + +The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below. + +| Leader-Arm Axis | Motor | Gear Ratio | +| ------------------- | :---: | :--------: | +| Base / Shoulder Pan | 1 | 1 / 191 | +| Shoulder Lift | 2 | 1 / 345 | +| Elbow Flex | 3 | 1 / 191 | +| Wrist Flex | 4 | 1 / 147 | +| Wrist Roll | 5 | 1 / 147 | +| Gripper | 6 | 1 / 147 | + +### Clean Parts + +Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. + +It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly. + +### Joint 1 + +- Place the first motor into the base. +- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom. +- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side). +- Install both motor horns, securing the top horn with a M3x6mm screw. +- Attach the shoulder part. +- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom +- Add the shoulder motor holder. + +
+ +
+ +### Joint 2 + +- Slide the second motor in from the top. +- Fasten the second motor with 4 M2x6mm screws. +- Attach both motor horns to motor 2, again use the M3x6mm horn screw. +- Attach the upper arm with 4 M3x6mm screws on each side. + +
+ +
+ +### Joint 3 + +- Insert motor 3 and fasten using 4 M2x6mm screws +- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw. +- Connect the forearm to motor 3 using 4 M3x6mm screws on each side. + +
+ +
+ +### Joint 4 + +- Slide over motor holder 4. +- Slide in motor 4. +- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw. + +
+ +
+ +### Joint 5 + +- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws. +- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw. +- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides. + +
+ +
+ +### Gripper / Handle + + + + +- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws. +- Insert the gripper motor and secure it with 2 M2x6mm screws on each side. +- Attach the motor horns and again use a M3x6mm horn screw. +- Install the gripper claw and secure it with 4 M3x6mm screws on both sides. + +
+ +
+ +
+ + +- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws. +- Attach the handle to motor 5 using 1 M2x6mm screw. +- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw. +- Attach the follower trigger with 4 M3x6mm screws. + +
+ +
+ +
+
+ +## Configure the motors + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted: + +```bash +python -m lerobot.find_port +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: + +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. + +The video below shows the sequence of steps for setting the motor ids. + +##### Setup motors video + +
+ +
+ +#### Follower + +Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step +``` + + + + + +```python +from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig + +config = SO101FollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_follower_arm", +) +follower = SO101Follower(config) +follower.setup_motors() +``` + + + + + +You should see the following instruction + +```bash +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + +If you get an error at that point, check your cables and make sure they are plugged in properly: + +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). + +
+ +You should then see the following message: + +```bash +'gripper' motor id set to 6 +``` + +Followed by the next instruction: + +```bash +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader + +Do the same steps for the leader arm. + + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + + + +```python +from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig + +config = SO101LeaderConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_leader_arm", +) +leader = SO101Leader(config) +leader.setup_motors() +``` + + + + + +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower + +config = SO101FollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = SO101Follower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + + + +The video below shows how to perform the calibration. First you need to move the robot to the position where all joints are in the middle of their ranges. Then after pressing enter you have to move each joint through its full range of motion. + +##### Calibration video + +
+ +
+ +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + + + +```python +from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader + +config = SO101LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO101Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/hope_jr/hope_jr.mdx b/src/lerobot/robots/hope_jr/hope_jr.mdx deleted file mode 100644 index 72aa8f923..000000000 --- a/src/lerobot/robots/hope_jr/hope_jr.mdx +++ /dev/null @@ -1,277 +0,0 @@ -# HopeJR - -## Prerequisites - -- [Hardware Setup](https://github.com/TheRobotStudio/HOPEJr) - -## Install LeRobot - -Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot. - -Install LeRobot with HopeJR dependencies: - -```bash -pip install -e ".[hopejr]" -``` - -## Device Configuration - -Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton: - -```bash -python -m lerobot.find_port -``` - -This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts. - -## Step 1: Calibration - -Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibration files will be saved in `~/.cache/huggingface/lerobot/calibration` - -### 1.1 Calibrate Robot Hand - -```bash -python -m lerobot.calibrate \ - --robot.type=hope_jr_hand \ - --robot.port=/dev/tty.usbmodem58760432281 \ - --robot.id=blue \ - --robot.side=right -``` - -When running the calibration script, a calibration GUI will pop up. Finger joints are named as follows: - -**Thumb**: - -- **CMC**: base joint connecting thumb to hand -- **MCP**: knuckle joint -- **PIP**: first finger joint -- **DIP** : fingertip joint - -**Index, Middle, Ring, and Pinky fingers**: - -- **Radial flexor**: Moves base of finger towards the thumb -- **Ulnar flexor**: Moves base of finger towards the pinky -- **PIP/DIP**: Flexes the distal and proximal phalanx of the finger - -Each one of these will need to be calibrated individually via the GUI. -Note that ulnar and radial flexors should have ranges of the same size (but with different offsets) in order to get symmetric movement. - -

- Setting boundaries in the hand calibration GUI -

- -Use the calibration interface to set the range boundaries for each joint as shown above. - -

- Saving calibration values -

- -Once you have set the appropriate boundaries for all joints, click "Save" to save the calibration values to the motors. - -### 1.2 Calibrate Teleoperator Glove - -```bash -python -m lerobot.calibrate \ - --teleop.type=homunculus_glove \ - --teleop.port=/dev/tty.usbmodem11201 \ - --teleop.id=red \ - --teleop.side=right -``` - -Move each finger through its full range of motion, starting from the thumb. - -``` -Move thumb through its entire range of motion. -Recording positions. Press ENTER to stop... - -------------------------------------------- -NAME | MIN | POS | MAX -thumb_cmc | 1790 | 1831 | 1853 -thumb_mcp | 1497 | 1514 | 1528 -thumb_pip | 1466 | 1496 | 1515 -thumb_dip | 1463 | 1484 | 1514 -``` - -Continue with each finger: - -``` -Move middle through its entire range of motion. -Recording positions. Press ENTER to stop... - -------------------------------------------- -NAME | MIN | POS | MAX -middle_mcp_abduction | 1598 | 1718 | 1820 -middle_mcp_flexion | 1512 | 1658 | 2136 -middle_dip | 1484 | 1500 | 1547 -``` - -Once calibration is complete, the system will save the calibration to `/Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_glove/red.json` - -### 1.3 Calibrate Robot Arm - -```bash -python -m lerobot.calibrate \ - --robot.type=hope_jr_arm \ - --robot.port=/dev/tty.usbserial-1110 \ - --robot.id=white -``` - -This will open a calibration GUI where you can set the range limits for each motor. The arm motions are organized as follows: - -- **Shoulder**: pitch, yaw, and roll -- **Elbow**: flex -- **Wrist**: pitch, yaw, and roll - -

- Setting boundaries in the arm calibration GUI -

- -Use the calibration interface to set the range boundaries for each joint. Move each joint through its full range of motion and adjust the minimum and maximum values accordingly. Once you have set the appropriate boundaries for all joints, save the calibration. - -### 1.4 Calibrate Teleoperator Exoskeleton - -```bash -python -m lerobot.calibrate \ - --teleop.type=homunculus_arm \ - --teleop.port=/dev/tty.usbmodem11201 \ - --teleop.id=black -``` - -The exoskeleton allows one to control the robot arm. During calibration, you'll be prompted to move all joints through their full range of motion: - -``` -Move all joints through their entire range of motion. -Recording positions. Press ENTER to stop... - -------------------------------------------- -------------------------------------------- -NAME | MIN | POS | MAX -shoulder_pitch | 586 | 736 | 895 -shoulder_yaw | 1257 | 1374 | 1390 -shoulder_roll | 449 | 1034 | 2564 -elbow_flex | 3023 | 3117 | 3134 -wrist_roll | 3073 | 3096 | 3147 -wrist_yaw | 2143 | 2171 | 2185 -wrist_pitch | 1975 | 1993 | 2074 -Calibration saved to /Users/your_username/.cache/huggingface/lerobot/calibration/teleoperators/homunculus_arm/black.json -``` - -## Step 2: Teleoperation - -Due to global variable conflicts in the Feetech middleware, teleoperation for arm and hand must run in separate shell sessions: - -### Hand - -```bash -python -m lerobot.teleoperate \ - --robot.type=hope_jr_hand \ - --robot.port=/dev/tty.usbmodem58760432281 \ - --robot.id=blue \ - --robot.side=right \ - --teleop.type=homunculus_glove \ - --teleop.port=/dev/tty.usbmodem11201 \ - --teleop.id=red \ - --teleop.side=right \ - --display_data=true \ - --fps=30 -``` - -### Arm - -```bash -python -m lerobot.teleoperate \ - --robot.type=hope_jr_arm \ - --robot.port=/dev/tty.usbserial-1110 \ - --robot.id=white \ - --teleop.type=homunculus_arm \ - --teleop.port=/dev/tty.usbmodem11201 \ - --teleop.id=black \ - --display_data=true \ - --fps=30 -``` - -## Step 3: Record, Replay, Train - -Record, Replay and Train with Hope-JR is still experimental. - -### Record - -This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings). - -```bash -python -m lerobot.record \ - --robot.type=hope_jr_hand \ - --robot.port=/dev/tty.usbmodem58760432281 \ - --robot.id=right \ - --robot.side=right \ - --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ - --teleop.type=homunculus_glove \ - --teleop.port=/dev/tty.usbmodem1201 \ - --teleop.id=right \ - --teleop.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ - --dataset.single_task="Hand recording test with video data" \ - --dataset.num_episodes=1 \ - --dataset.episode_time_s=5 \ - --dataset.push_to_hub=true \ - --dataset.private=true \ - --display_data=true -``` - -### Replay - -```bash -python -m lerobot.replay \ - --robot.type=hope_jr_hand \ - --robot.port=/dev/tty.usbmodem58760432281 \ - --robot.id=right \ - --robot.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_camera \ - --dataset.episode=0 -``` - -### Train - -```bash -python -m lerobot.scripts.train \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ - --policy.type=act \ - --output_dir=outputs/train/hopejr_hand \ - --job_name=hopejr \ - --policy.device=mps \ - --wandb.enable=true \ - --policy.repo_id=nepyope/hand_test_policy -``` - -### Evaluate - -This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino). - -```bash -python -m lerobot.record \ - --robot.type=hope_jr_hand \ - --robot.port=/dev/tty.usbmodem58760432281 \ - --robot.id=right \ - --robot.side=right \ - --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ - --display_data=false \ - --dataset.repo_id=nepyope/eval_hopejr \ - --dataset.single_task="Evaluate hopejr hand policy" \ - --dataset.num_episodes=10 \ - --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model -``` diff --git a/src/lerobot/robots/hope_jr/hope_jr.mdx b/src/lerobot/robots/hope_jr/hope_jr.mdx new file mode 120000 index 000000000..a076e4754 --- /dev/null +++ b/src/lerobot/robots/hope_jr/hope_jr.mdx @@ -0,0 +1 @@ +../../../../docs/source/hope_jr.mdx \ No newline at end of file diff --git a/src/lerobot/robots/koch_follower/koch.mdx b/src/lerobot/robots/koch_follower/koch.mdx deleted file mode 100644 index d0b991e74..000000000 --- a/src/lerobot/robots/koch_follower/koch.mdx +++ /dev/null @@ -1,283 +0,0 @@ -# Koch v1.1 - -In the steps below, we explain how to assemble the Koch v1.1 robot. - -## Order and assemble the parts - -Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below. - -For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk). - -> [!WARNING] -> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video: -> -> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors). -> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors). - -## Install LeRobot 🤗 - -To install LeRobot follow, our [Installation Guide](./installation) - -In addition to these instructions, you need to install the Dynamixel SDK: - -```bash -pip install -e ".[dynamixel]" -``` - -## Configure the motors - -### 1. Find the USB ports associated with each arm - -To find the port for each bus servo adapter, run this script: - -```bash -python -m lerobot.find_port -``` - - - - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the USB cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. - - - - -On Linux, you might need to give access to the USB ports by running: - -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/ttyACM0', '/dev/ttyACM1'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/ttyACM1 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. - - - - -### 2. Set the motors ids and baudrates - -Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. - -To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. - -If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match. - -#### Follower - -Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. - -For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. - - - - -```bash -python -m lerobot.setup_motors \ - --robot.type=koch_follower \ - --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step -``` - - - - - -```python -from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig - -config = KochFollowerConfig( - port="/dev/tty.usbmodem575E0031751", - id="my_awesome_follower_arm", -) -follower = KochFollower(config) -follower.setup_motors() -``` - - - - - -You should see the following instruction. - -``` -Connect the controller board to the 'gripper' motor only and press enter. -``` - -As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. - -
-Troubleshooting - -If you get an error at that point, check your cables and make sure they are plugged in properly: - -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
- -If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). - -
- -You should then see the following message: - -``` -'gripper' motor id set to 6 -``` - -Followed by the next instruction: - -``` -Connect the controller board to the 'wrist_roll' motor only and press enter. -``` - -You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. - -Repeat the operation for each motor as instructed. - -> [!TIP] -> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. - -When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. - -#### Leader - -Do the same steps for the leader arm but modify the command or script accordingly. - - - - -```bash -python -m lerobot.setup_motors \ - --teleop.type=koch_leader \ - --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step -``` - - - - - -```python -from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig - -config = KochLeaderConfig( - port="/dev/tty.usbmodem575E0031751", - id="my_awesome_leader_arm", -) -leader = KochLeader(config) -leader.setup_motors() -``` - - - - - -## Calibrate - -Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. -The calibration process is very important because it allows a neural network trained on one robot to work on another. - -#### Follower - -Run the following command or API example to calibrate the follower arm: - - - - -```bash -python -m lerobot.calibrate \ - --robot.type=koch_follower \ - --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --robot.id=my_awesome_follower_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.robots.koch_follower import KochFollowerConfig, KochFollower - -config = KochFollowerConfig( - port="/dev/tty.usbmodem585A0076891", - id="my_awesome_follower_arm", -) - -follower = KochFollower(config) -follower.connect(calibrate=False) -follower.calibrate() -follower.disconnect() -``` - - - - - -We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). - -#### Leader - -Do the same steps to calibrate the leader arm, run the following command or API example: - - - - -```bash -python -m lerobot.calibrate \ - --teleop.type=koch_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.teleoperators.koch_leader import KochLeaderConfig, KochLeader - -config = KochLeaderConfig( - port="/dev/tty.usbmodem575E0031751", - id="my_awesome_leader_arm", -) - -leader = KochLeader(config) -leader.connect(calibrate=False) -leader.calibrate() -leader.disconnect() -``` - - - - - -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/koch_follower/koch.mdx b/src/lerobot/robots/koch_follower/koch.mdx new file mode 120000 index 000000000..ef43feb06 --- /dev/null +++ b/src/lerobot/robots/koch_follower/koch.mdx @@ -0,0 +1 @@ +../../../../docs/source/koch.mdx \ No newline at end of file diff --git a/src/lerobot/robots/lekiwi/lekiwi.mdx b/src/lerobot/robots/lekiwi/lekiwi.mdx deleted file mode 100644 index bb70fd26b..000000000 --- a/src/lerobot/robots/lekiwi/lekiwi.mdx +++ /dev/null @@ -1,337 +0,0 @@ -# LeKiwi - -In the steps below, we explain how to assemble the LeKiwi mobile robot. - -## Source the parts - -Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. -And advise if it's your first time printing or if you don't own a 3D printer. - -### Wired version - -If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating. - -## Install software on Pi - -Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board. - -### Install OS - -For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu. - -### Setup SSH - -After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension. - -### Install LeRobot on Pi 🤗 - -On your Raspberry Pi install LeRobot using our [Installation Guide](./installation) - -In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your Pi: - -```bash -pip install -e ".[lekiwi]" -``` - -## Install LeRobot locally - -If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi. - -Follow our [Installation Guide](./installation) - -In addition to these instructions, you need to install the Feetech SDK & ZeroMQ on your laptop/pc: - -```bash -pip install -e ".[lekiwi]" -``` - -Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:. -Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. - -# Step-by-Step Assembly Instructions - -First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages: - -- [Assemble SO101](./so101#step-by-step-assembly-instructions) -- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md) - -### Find the USB ports associated with motor board - -To find the port for each bus servo adapter, run this script: - -```bash -python -m lerobot.find_port -``` - - - - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081'] -Remove the USB cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board. - - - - -On Linux, you might need to give access to the USB ports by running: - -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/ttyACM0'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/ttyACM0 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/ttyACM0` corresponding to your board. - - - - -### Configure motors - -The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9. - -You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) - -```bash -python -m lerobot.setup_motors \ - --robot.type=lekiwi \ - --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step -``` - -Motor ID's for mobile robot - -### Troubleshoot communication - -If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue. - -#### 1. Verify IP Address Configuration - -Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line): - -```bash -hostname -I -``` - -#### 2. Check if Pi is reachable from laptop/pc - -Try pinging the Raspberry Pi from your laptop: - -```bach -ping -``` - -If the ping fails: - -- Ensure the Pi is powered on and connected to the same network. -- Check if SSH is enabled on the Pi. - -#### 3. Try SSH connection - -If you can't SSH into the Pi, it might not be properly connected. Use: - -```bash -ssh @ -``` - -If you get a connection error: - -- Ensure SSH is enabled on the Pi by running: - ```bash - sudo raspi-config - ``` - Then navigate to: **Interfacing Options -> SSH** and enable it. - -### Calibration - -Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated. -The calibration process is very important because it allows a neural network trained on one robot to work on another. - -### Calibrate follower arm (on mobile base) - -Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm: - -```bash -python -m lerobot.calibrate \ - --robot.type=lekiwi \ - --robot.id=my_awesome_kiwi # <- Give the robot a unique name -``` - -We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). - -### Wired version - -If you have the **wired** LeKiwi version, please run all commands on your laptop. - -### Calibrate leader arm - -Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop: - - - - -```bash -python -m lerobot.calibrate \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader - -config = SO100LeaderConfig( - port="/dev/tty.usbmodem58760431551", - id="my_awesome_leader_arm", -) - -leader = SO100Leader(config) -leader.connect(calibrate=False) -leader.calibrate() -leader.disconnect() -``` - - - - - -## Teleoperate LeKiwi - -> [!TIP] -> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal. - -To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command: - -```bash -python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi -``` - -Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`. - -```bash -python examples/lekiwi/teleoperate.py -``` - -You should see on your laptop something like this: `[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: - -| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | -| ---------- | ------------------ | ---------------------- | -| Fast | 0.4 | 90 | -| Medium | 0.25 | 60 | -| Slow | 0.1 | 30 | - -| Key | Action | -| --- | -------------- | -| W | Move forward | -| A | Move left | -| S | Move backward | -| D | Move right | -| Z | Turn left | -| X | Turn right | -| R | Increase speed | -| F | Decrease speed | - -> [!TIP] -> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). - -### Wired version - -If you have the **wired** LeKiwi version, please run all commands on your laptop. - -## Record a dataset - -Once you're familiar with teleoperation, you can record your first dataset. - -We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). - -Add your token to the CLI by running this command: - -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` - -Then store your Hugging Face repository name in a variable: - -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` - -Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`. - -```bash -python examples/lekiwi/record.py -``` - -#### Dataset upload - -Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: - -```bash -echo https://huggingface.co/datasets/${HF_USER}/so101_test -``` - -Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). - -You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). - -#### Tips for gathering data - -Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images. - -In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions. - -Avoid adding too much variation too quickly, as it may hinder your results. - -If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. - -#### Troubleshooting: - -- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). - -## Replay an episode - -To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index. - -```bash -python examples/lekiwi/replay.py -``` - -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) - -## Evaluate your policy - -To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model.. - -```bash -python examples/lekiwi/evaluate.py -``` - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/lekiwi/lekiwi.mdx b/src/lerobot/robots/lekiwi/lekiwi.mdx new file mode 120000 index 000000000..f65158998 --- /dev/null +++ b/src/lerobot/robots/lekiwi/lekiwi.mdx @@ -0,0 +1 @@ +../../../../docs/source/lekiwi.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx deleted file mode 100644 index d9ff922c5..000000000 --- a/src/lerobot/robots/so100_follower/so100.mdx +++ /dev/null @@ -1,640 +0,0 @@ -# SO-100 - -In the steps below, we explain how to assemble the SO-100 robot. - -## Source the parts - -Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100.md). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. And advise if it's your first time printing or if you don't own a 3D printer. - -## Install LeRobot 🤗 - -To install LeRobot, follow our [Installation Guide](./installation) - -In addition to these instructions, you need to install the Feetech SDK: - -```bash -pip install -e ".[feetech]" -``` - -## Configure the motors - -**Note:** -Unlike the SO-101, the motor connectors are not easily accessible once the arm is assembled, so the configuration step must be done beforehand. - -### 1. Find the USB ports associated with each arm - -To find the port for each bus servo adapter, run this script: - -```bash -python -m lerobot.find_port -``` - - - - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the USB cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. - - - - -On Linux, you might need to give access to the USB ports by running: - -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/ttyACM0', '/dev/ttyACM1'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/ttyACM1 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. - - - - -### 2. Set the motors ids and baudrates - -Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. - -To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. - -If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. - -#### Follower - -Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. - -For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. - - - - -```bash -python -m lerobot.setup_motors \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step -``` - - - - - -```python -from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig - -config = SO100FollowerConfig( - port="/dev/tty.usbmodem585A0076841", - id="my_awesome_follower_arm", -) -follower = SO100Follower(config) -follower.setup_motors() -``` - - - - - -You should see the following instruction - -``` -Connect the controller board to the 'gripper' motor only and press enter. -``` - -As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. - -
-Troubleshooting - -If you get an error at that point, check your cables and make sure they are plugged in properly: - -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
- -If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). - -
- -You should then see the following message: - -``` -'gripper' motor id set to 6 -``` - -Followed by the next instruction: - -``` -Connect the controller board to the 'wrist_roll' motor only and press enter. -``` - -You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. - -Repeat the operation for each motor as instructed. - -> [!TIP] -> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. - -When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. - -#### Leader - -Do the same steps for the leader arm. - - - -```bash -python -m lerobot.setup_motors \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step -``` - - - - -```python -from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig - -config = SO100LeaderConfig( - port="/dev/tty.usbmodem585A0076841", - id="my_awesome_leader_arm", -) -leader = SO100Leader(config) -leader.setup_motors() -``` - - - - - -## Step-by-Step Assembly Instructions - -## Remove the gears of the 6 leader motors - -
-Video removing gears - -
- -
- -
- -Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. - -### Clean Parts - -Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. - -### Additional Guidance - -
-Video assembling arms - -
- -
- -
- -**Note:** -This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour. - ---- - -### First Motor - -**Step 2: Insert Wires** - -- Insert two wires into the first motor. - - - -**Step 3: Install in Base** - -- Place the first motor into the base. - - - -**Step 4: Secure Motor** - -- Fasten the motor with 4 screws. Two from the bottom and two from top. - -**Step 5: Attach Motor Holder** - -- Slide over the first motor holder and fasten it using two screws (one on each side). - - - -**Step 6: Attach Motor Horns** - -- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears. - - - -
- - Video adding motor horn - - -
- -**Step 7: Attach Shoulder Part** - -- Route one wire to the back of the robot and the other to the left or towards you (see photo). -- Attach the shoulder part. - - - -**Step 8: Secure Shoulder** - -- Tighten the shoulder part with 4 screws on top and 4 on the bottom - _(access bottom holes by turning the shoulder)._ - ---- - -### Second Motor Assembly - -**Step 9: Install Motor 2** - -- Slide the second motor in from the top and link the wire from motor 1 to motor 2. - - - -**Step 10: Attach Shoulder Holder** - -- Add the shoulder motor holder. -- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo). -- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor. - -
- - - -
- -**Step 11: Secure Motor 2** - -- Fasten the second motor with 4 screws. - -**Step 12: Attach Motor Horn** - -- Attach both motor horns to motor 2, again use the horn screw. - -**Step 13: Attach Base** - -- Install the base attachment using 2 screws. - - - -**Step 14: Attach Upper Arm** - -- Attach the upper arm with 4 screws on each side. - - - ---- - -### Third Motor Assembly - -**Step 15: Install Motor 3** - -- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws. - -**Step 16: Attach Motor Horn** - -- Attach both motor horns to motor 3 and secure one again with a horn screw. - - - -**Step 17: Attach Forearm** - -- Connect the forearm to motor 3 using 4 screws on each side. - - - ---- - -### Fourth Motor Assembly - -**Step 18: Install Motor 4** - -- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw. - -
- - -
- -**Step 19: Attach Motor Holder 4** - -- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo). - - - -**Step 20: Secure Motor 4 & Attach Horn** - -- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw. - - - ---- - -### Wrist Assembly - -**Step 21: Install Motor 5** - -- Insert motor 5 into the wrist holder and secure it with 2 front screws. - - - -**Step 22: Attach Wrist** - -- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper. -- Secure the wrist to motor 4 using 4 screws on both sides. - - - -**Step 23: Attach Wrist Horn** - -- Install only one motor horn on the wrist motor and secure it with a horn screw. - - - ---- - -### Follower Configuration - -**Step 24: Attach Gripper** - -- Attach the gripper to motor 5. - - - -**Step 25: Install Gripper Motor** - -- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side. - - - -**Step 26: Attach Gripper Horn & Claw** - -- Attach the motor horns and again use a horn screw. -- Install the gripper claw and secure it with 4 screws on both sides. - - - -**Step 27: Mount Controller** - -- Attach the motor controller to the back of the robot. - -
- - -
- -_Assembly complete – proceed to Leader arm assembly._ - ---- - -### Leader Configuration - -For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors. - -**Step 24: Attach Leader Holder** - -- Mount the leader holder onto the wrist and secure it with a screw. - - - -**Step 25: Attach Handle** - -- Attach the handle to motor 5 using 4 screws. - - - -**Step 26: Install Gripper Motor** - -- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire. - - - -**Step 27: Attach Trigger** - -- Attach the follower trigger with 4 screws. - - - -**Step 28: Mount Controller** - -- Attach the motor controller to the back of the robot. - -
- - -
- -## Calibrate - -Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. -The calibration process is very important because it allows a neural network trained on one robot to work on another. - -#### Follower - -Run the following command or API example to calibrate the follower arm: - - - - -```bash -python -m lerobot.calibrate \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --robot.id=my_awesome_follower_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower - -config = SO100FollowerConfig( - port="/dev/tty.usbmodem585A0076891", - id="my_awesome_follower_arm", -) - -follower = SO100Follower(config) -follower.connect(calibrate=False) -follower.calibrate() -follower.disconnect() -``` - - - - - -We unified the calibration method for most robots. Thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video) - -#### Leader - -Do the same steps to calibrate the leader arm, run the following command or API example: - - - - -```bash -python -m lerobot.calibrate \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader - -config = SO100LeaderConfig( - port="/dev/tty.usbmodem58760431551", - id="my_awesome_leader_arm", -) - -leader = SO100Leader(config) -leader.connect(calibrate=False) -leader.calibrate() -leader.disconnect() -``` - - - - - -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx new file mode 120000 index 000000000..ad1154e75 --- /dev/null +++ b/src/lerobot/robots/so100_follower/so100.mdx @@ -0,0 +1 @@ +../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx deleted file mode 100644 index a20a3fa9f..000000000 --- a/src/lerobot/robots/so101_follower/so101.mdx +++ /dev/null @@ -1,436 +0,0 @@ -# SO-101 - -In the steps below, we explain how to assemble our flagship robot, the SO-101. - -## Source the parts - -Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. -And advise if it's your first time printing or if you don't own a 3D printer. - -## Install LeRobot 🤗 - -To install LeRobot, follow our [Installation Guide](./installation) - -In addition to these instructions, you need to install the Feetech SDK: - -```bash -pip install -e ".[feetech]" -``` - -## Step-by-Step Assembly Instructions - -The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below. - -| Leader-Arm Axis | Motor | Gear Ratio | -| ------------------- | :---: | :--------: | -| Base / Shoulder Pan | 1 | 1 / 191 | -| Shoulder Lift | 2 | 1 / 345 | -| Elbow Flex | 3 | 1 / 191 | -| Wrist Flex | 4 | 1 / 147 | -| Wrist Roll | 5 | 1 / 147 | -| Gripper | 6 | 1 / 147 | - -### Clean Parts - -Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. - -It is advisable to install one 3-pin cable in the motor after placing them before continuing assembly. - -### Joint 1 - -- Place the first motor into the base. -- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom. -- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side). -- Install both motor horns, securing the top horn with a M3x6mm screw. -- Attach the shoulder part. -- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom -- Add the shoulder motor holder. - -
- -
- -### Joint 2 - -- Slide the second motor in from the top. -- Fasten the second motor with 4 M2x6mm screws. -- Attach both motor horns to motor 2, again use the M3x6mm horn screw. -- Attach the upper arm with 4 M3x6mm screws on each side. - -
- -
- -### Joint 3 - -- Insert motor 3 and fasten using 4 M2x6mm screws -- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw. -- Connect the forearm to motor 3 using 4 M3x6mm screws on each side. - -
- -
- -### Joint 4 - -- Slide over motor holder 4. -- Slide in motor 4. -- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw. - -
- -
- -### Joint 5 - -- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws. -- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw. -- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides. - -
- -
- -### Gripper / Handle - - - - -- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws. -- Insert the gripper motor and secure it with 2 M2x6mm screws on each side. -- Attach the motor horns and again use a M3x6mm horn screw. -- Install the gripper claw and secure it with 4 M3x6mm screws on both sides. - -
- -
- -
- - -- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws. -- Attach the handle to motor 5 using 1 M2x6mm screw. -- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw. -- Attach the follower trigger with 4 M3x6mm screws. - -
- -
- -
-
- -## Configure the motors - -### 1. Find the USB ports associated with each arm - -To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted: - -```bash -python -m lerobot.find_port -``` - - - - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the USB cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. - - - - -On Linux, you might need to give access to the USB ports by running: - -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -Example output: - -``` -Finding all available ports for the MotorBus. -['/dev/ttyACM0', '/dev/ttyACM1'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect corresponding leader or follower arm and press Enter...] - -The port of this MotorsBus is /dev/ttyACM1 -Reconnect the USB cable. -``` - -Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. - - - - -### 2. Set the motors ids and baudrates - -Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. - -To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. - -If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. - -The video below shows the sequence of steps for setting the motor ids. - -##### Setup motors video - -
- -
- -#### Follower - -Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. - - - - -```bash -python -m lerobot.setup_motors \ - --robot.type=so101_follower \ - --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step -``` - - - - - -```python -from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig - -config = SO101FollowerConfig( - port="/dev/tty.usbmodem585A0076841", - id="my_awesome_follower_arm", -) -follower = SO101Follower(config) -follower.setup_motors() -``` - - - - - -You should see the following instruction - -```bash -Connect the controller board to the 'gripper' motor only and press enter. -``` - -As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. - -
-Troubleshooting - -If you get an error at that point, check your cables and make sure they are plugged in properly: - -
    -
  • Power supply
  • -
  • USB cable between your computer and the controller board
  • -
  • The 3-pin cable from the controller board to the motor
  • -
- -If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). - -
- -You should then see the following message: - -```bash -'gripper' motor id set to 6 -``` - -Followed by the next instruction: - -```bash -Connect the controller board to the 'wrist_roll' motor only and press enter. -``` - -You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. - -Repeat the operation for each motor as instructed. - -> [!TIP] -> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. - -When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. - -#### Leader - -Do the same steps for the leader arm. - - - - -```bash -python -m lerobot.setup_motors \ - --teleop.type=so101_leader \ - --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step -``` - - - - - -```python -from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig - -config = SO101LeaderConfig( - port="/dev/tty.usbmodem585A0076841", - id="my_awesome_leader_arm", -) -leader = SO101Leader(config) -leader.setup_motors() -``` - - - - - -## Calibrate - -Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. -The calibration process is very important because it allows a neural network trained on one robot to work on another. - -#### Follower - -Run the following command or API example to calibrate the follower arm: - - - - -```bash -python -m lerobot.calibrate \ - --robot.type=so101_follower \ - --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --robot.id=my_awesome_follower_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower - -config = SO101FollowerConfig( - port="/dev/tty.usbmodem585A0076891", - id="my_awesome_follower_arm", -) - -follower = SO101Follower(config) -follower.connect(calibrate=False) -follower.calibrate() -follower.disconnect() -``` - - - - - -The video below shows how to perform the calibration. First you need to move the robot to the position where all joints are in the middle of their ranges. Then after pressing enter you have to move each joint through its full range of motion. - -##### Calibration video - -
- -
- -#### Leader - -Do the same steps to calibrate the leader arm, run the following command or API example: - - - - -```bash -python -m lerobot.calibrate \ - --teleop.type=so101_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot - --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name -``` - - - - - -```python -from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader - -config = SO101LeaderConfig( - port="/dev/tty.usbmodem58760431551", - id="my_awesome_leader_arm", -) - -leader = SO101Leader(config) -leader.connect(calibrate=False) -leader.calibrate() -leader.disconnect() -``` - - - - - -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx new file mode 120000 index 000000000..27b892660 --- /dev/null +++ b/src/lerobot/robots/so101_follower/so101.mdx @@ -0,0 +1 @@ +../../../../docs/source/so101.mdx \ No newline at end of file From 26cb4614c961e6da04e4b83b6178331f4150650d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Sun, 20 Jul 2025 23:41:19 +0200 Subject: [PATCH 026/158] fix: calibration workflow when using robot_id with existing calibration files (#1528) --- src/lerobot/robots/koch_follower/koch_follower.py | 12 ++++++++++++ src/lerobot/robots/lekiwi/lekiwi.py | 12 ++++++++++++ src/lerobot/robots/so100_follower/so100_follower.py | 13 +++++++++++++ src/lerobot/robots/so101_follower/so101_follower.py | 13 +++++++++++++ .../teleoperators/koch_leader/koch_leader.py | 12 ++++++++++++ .../teleoperators/so100_leader/so100_leader.py | 13 +++++++++++++ .../teleoperators/so101_leader/so101_leader.py | 13 +++++++++++++ 7 files changed, 88 insertions(+) diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 1cfc6cf08..b09b9b8e2 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -94,6 +94,9 @@ class KochFollower(Robot): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() for cam in self.cameras.values(): @@ -107,6 +110,15 @@ class KochFollower(Robot): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index ff1465d8b..7004cc0fe 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -114,6 +114,9 @@ class LeKiwi(Robot): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() for cam in self.cameras.values(): @@ -127,6 +130,15 @@ class LeKiwi(Robot): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return logger.info(f"\nRunning calibration of {self}") motors = self.arm_motors + self.base_motors diff --git a/src/lerobot/robots/so100_follower/so100_follower.py b/src/lerobot/robots/so100_follower/so100_follower.py index e5da6bc1a..ac52293ff 100644 --- a/src/lerobot/robots/so100_follower/so100_follower.py +++ b/src/lerobot/robots/so100_follower/so100_follower.py @@ -92,6 +92,9 @@ class SO100Follower(Robot): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() for cam in self.cameras.values(): @@ -105,6 +108,16 @@ class SO100Follower(Robot): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py index 3ae3c3967..3ef66d702 100644 --- a/src/lerobot/robots/so101_follower/so101_follower.py +++ b/src/lerobot/robots/so101_follower/so101_follower.py @@ -92,6 +92,9 @@ class SO101Follower(Robot): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() for cam in self.cameras.values(): @@ -105,6 +108,16 @@ class SO101Follower(Robot): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # self.calibration is not empty here + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py index 8eb076fae..e0318cca5 100644 --- a/src/lerobot/teleoperators/koch_leader/koch_leader.py +++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py @@ -75,6 +75,9 @@ class KochLeader(Teleoperator): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() self.configure() @@ -85,6 +88,15 @@ class KochLeader(Teleoperator): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: diff --git a/src/lerobot/teleoperators/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so100_leader/so100_leader.py index 18dad44d4..a8f6d29b5 100644 --- a/src/lerobot/teleoperators/so100_leader/so100_leader.py +++ b/src/lerobot/teleoperators/so100_leader/so100_leader.py @@ -72,6 +72,9 @@ class SO100Leader(Teleoperator): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() self.configure() @@ -82,6 +85,16 @@ class SO100Leader(Teleoperator): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: diff --git a/src/lerobot/teleoperators/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so101_leader/so101_leader.py index 2ce28d2e4..15a363e37 100644 --- a/src/lerobot/teleoperators/so101_leader/so101_leader.py +++ b/src/lerobot/teleoperators/so101_leader/so101_leader.py @@ -73,6 +73,9 @@ class SO101Leader(Teleoperator): self.bus.connect() if not self.is_calibrated and calibrate: + logger.info( + "Mismatch between calibration values in the motor and the calibration file or no calibration file found" + ) self.calibrate() self.configure() @@ -83,6 +86,16 @@ class SO101Leader(Teleoperator): return self.bus.is_calibrated def calibrate(self) -> None: + if self.calibration: + # Calibration file exists, ask user whether to use it or run new calibration + user_input = input( + f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " + ) + if user_input.strip().lower() != "c": + logger.info(f"Writing calibration file associated with the id {self.id} to the motors") + self.bus.write_calibration(self.calibration) + return + logger.info(f"\nRunning calibration of {self}") self.bus.disable_torque() for motor in self.bus.motors: From 17efa2ff8e71d3e9097ac32cb183e46919dea015 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 21 Jul 2025 10:57:35 +0200 Subject: [PATCH 027/158] Add disclaimer to pi0 from_pretrained (#1550) --- src/lerobot/policies/pi0/modeling_pi0.py | 11 +++++++++++ src/lerobot/policies/pi0fast/modeling_pi0fast.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index badfb4b8c..11feca964 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Install pi0 extra dependencies: ```bash @@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 0e53bd349..d3903066c 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): ```bash @@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + def get_optim_params(self) -> dict: return self.parameters() From f59baeab45111abdcdb2d23c967812cccf190364 Mon Sep 17 00:00:00 2001 From: Daniel Ritchie Date: Mon, 21 Jul 2025 09:16:50 -0600 Subject: [PATCH 028/158] bump version for breaking changes in 1417 (#1515) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e0d754f53..ec2598973 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.1.0" +version = "0.2.0" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From f6ec1d89a5a3d3f0d28b79bf9b58729d808b3375 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 21 Jul 2025 19:08:32 +0200 Subject: [PATCH 029/158] feat(ci): add release workflow (#1562) --- .github/workflows/full_tests.yml | 10 ++-- .github/workflows/nightly.yml | 16 +++--- .github/workflows/release.yml | 88 ++++++++++++++++++++++++++++++++ Makefile | 8 +-- docker/Dockerfile.internal | 2 +- 5 files changed, 106 insertions(+), 18 deletions(-) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index f044dc484..55d38883a 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -131,8 +131,8 @@ jobs: - name: Login to Docker Hub uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} - name: Build and push Docker image uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] with: @@ -157,8 +157,8 @@ jobs: image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" credentials: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} defaults: run: shell: bash @@ -187,7 +187,7 @@ jobs: TOKEN=$(curl -s -H "Content-Type: application/json" \ -X POST \ - -d '{"username": "${{ secrets.DOCKERHUB_USERNAME }}", "password": "${{ secrets.DOCKERHUB_PASSWORD }}"}' \ + -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \ https://hub.docker.com/v2/users/login/ | jq -r .token) if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 66755d9df..b42b92f6b 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -62,8 +62,8 @@ jobs: - name: Login to Docker Hub uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} - name: Build and push Docker image CPU uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] with: @@ -96,8 +96,8 @@ jobs: - name: Login to Docker Hub uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} - name: Build and push Docker image GPU uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] with: @@ -120,8 +120,8 @@ jobs: container: image: ${{ needs.build-docker-cpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] credentials: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} defaults: run: shell: bash @@ -147,8 +147,8 @@ jobs: image: ${{ needs.build-docker-gpu-nightly.outputs.image_tag }} # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" credentials: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} defaults: run: shell: bash diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..7d80ac5af --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: Create Release and Publish to PyPI + +on: + push: + tags: + - 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0 + +jobs: + # TODO(Steven): Publish draft/pre-release and to test pypi + # TODO(Steven): Tag documentation with the same version as the package + # TODO(Steven): Define entry points for main CLI scripts + build-and-publish: + name: Build and publish Python distributions + runs-on: ubuntu-latest + permissions: + contents: write + id-token: write + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Extract Version and Package Name + id: extract_info + # zizmor: ignore[template-injection] + run: | + # Extract version from tag (e.g., v0.1.0 -> 0.1.0) + VERSION=${{ github.ref_name }} + VERSION_NUMBER=${VERSION#v} + echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT + + # Extract package name from pyproject.toml + PACKAGE_NAME=$(grep -oP 'name = "\K[^"]+' pyproject.toml) + echo "package_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT + + - name: Check if version exists on PyPI + # zizmor: ignore[template-injection] + run: | + PACKAGE_NAME=${{ steps.extract_info.outputs.package_name }} + NEW_VERSION=${{ steps.extract_info.outputs.tag_version }} + + response=$(curl -s "https://pypi.org/pypi/$PACKAGE_NAME/$NEW_VERSION/json") + if echo "$response" | grep -q "message"; then + echo "Version $NEW_VERSION is available on PyPI. Proceeding with release." + else + echo "Error: Version $NEW_VERSION already exists on PyPI. Aborting." + exit 1 + fi + + - name: Install build dependencies + run: python -m pip install build + + - name: Build package + run: python -m build + + - name: Create GitHub Release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # zizmor: ignore[template-injection] + run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/* + + # TODO(Steven): Uncomment when ready to publish to PyPI + # - name: Publish to PyPI + # if: startsWith(github.ref, 'refs/tags/v') + # uses: pypa/gh-action-pypi-publish@v1.12.4 + # with: + # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/Makefile b/Makefile index ca1495fac..5bfbe76a2 100644 --- a/Makefile +++ b/Makefile @@ -26,11 +26,11 @@ export PATH := $(dir $(PYTHON_PATH)):$(PATH) DEVICE ?= cpu -build-cpu: - docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile . +build-user: + docker build -f docker/Dockerfile.user -t lerobot-user . -build-gpu: - docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile . +build-internal: + docker build -f docker/Dockerfile.internal -t lerobot-internal . test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-act-ete-train diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index c799a006d..8c77fe497 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -15,7 +15,7 @@ # This Dockerfile is designed for HuggingFace internal CI environments # that require GPU access. It starts from an NVIDIA CUDA base image. -# docker build -f docker/Dockerfile.internal -t lerobot-ci . +# docker build -f docker/Dockerfile.internal -t lerobot-internal . # Configure the base image for CI with GPU access # TODO(Steven): Bump these versions From 9b9f4757fb7952fbe891968dcd0b554907f528d2 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Mon, 21 Jul 2025 19:12:03 +0200 Subject: [PATCH 030/158] style(deprecated method): remove no longer used get_features_from_robot function (replaced by hw_to_dataset_features) (#1560) --- src/lerobot/datasets/utils.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ac0ab9799..078c5351d 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -41,7 +41,6 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.robots import Robot from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk @@ -440,16 +439,6 @@ def build_dataset_frame( return frame -def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: - camera_ft = {} - if robot.cameras: - camera_ft = { - key: {"dtype": "video" if use_videos else "image", **ft} - for key, ft in robot.camera_features.items() - } - return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} - - def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: # TODO(aliberts): Implement "type" in dataset features and simplify this policy_features = {} From 5d2aef61b83c22b29994bd348ba4df931a8282ae Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Tue, 22 Jul 2025 11:56:23 +0200 Subject: [PATCH 031/158] Pre-commits fixes (#1568) * Replace typos w/ mirror * Update ruff * Replace prettier mirror --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e509d6d88..f09017991 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,13 +39,13 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.13 + rev: v0.12.4 hooks: - id: ruff-format - id: ruff args: [--fix, --exit-non-zero-on-fix] - - repo: https://github.com/crate-ci/typos + - repo: https://github.com/adhtruong/mirrors-typos rev: v1.34.0 hooks: - id: typos @@ -58,8 +58,8 @@ repos: args: [--py310-plus] ##### Markdown Quality ##### - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + - repo: https://github.com/rbubley/mirrors-prettier + rev: v3.6.2 hooks: - id: prettier name: Format Markdown with Prettier From 835f0eddfabc2bbf48b68094c4a2551fb4d273b3 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 22 Jul 2025 14:31:30 +0200 Subject: [PATCH 032/158] bug(gamepad_utils) inverted axis between x and y (#1572) --- src/lerobot/teleoperators/gamepad/gamepad_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py index 9b62dc666..7ebed6b31 100644 --- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py +++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py @@ -295,8 +295,8 @@ class GamepadController(InputController): try: # Read joystick axes # Left stick X and Y (typically axes 0 and 1) - y_input = self.joystick.get_axis(0) # Left/Right - x_input = self.joystick.get_axis(1) # Up/Down (often inverted) + x_input = self.joystick.get_axis(0) # Left/Right + y_input = self.joystick.get_axis(1) # Up/Down (often inverted) # Right stick Y (typically axis 3 or 4) z_input = self.joystick.get_axis(3) # Up/Down for Z @@ -308,7 +308,7 @@ class GamepadController(InputController): # Calculate deltas (note: may need to invert axes depending on controller) delta_x = -x_input * self.x_step_size # Forward/backward - delta_y = -y_input * self.y_step_size # Left/right + delta_y = y_input * self.y_step_size # Left/right delta_z = -z_input * self.z_step_size # Up/down return delta_x, delta_y, delta_z From f5d6b5b3a790af16c44a52c462f1d28f90659e6b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 22 Jul 2025 15:14:01 +0200 Subject: [PATCH 033/158] test(cameras): skip depth test in rs camera for latest version (#1574) * test(cameras): increase timeout in depth read for testing * test(cameras): skip test_depth in realsense --------- Co-authored-by: Michel Aractingi --- tests/cameras/test_realsense.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py index 3957baf2d..4b3fbae82 100644 --- a/tests/cameras/test_realsense.py +++ b/tests/cameras/test_realsense.py @@ -104,12 +104,14 @@ def test_read(): assert isinstance(img, np.ndarray) +# TODO(Steven): Fix this test for the latest version of pyrealsense2. +@pytest.mark.skip("Skipping test: pyrealsense2 version > 2.55.1.6486") def test_read_depth(): config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True) camera = RealSenseCamera(config) camera.connect(warmup=False) - img = camera.read_depth(timeout_ms=1000) # NOTE(Steven): Reading depth takes longer + img = camera.read_depth(timeout_ms=2000) # NOTE(Steven): Reading depth takes longer in CI environments. assert isinstance(img, np.ndarray) From 989f3d05ba47f872d75c587e76838e9cc574857a Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Wed, 23 Jul 2025 16:30:01 +0700 Subject: [PATCH 034/158] [Async Inference] Merge Protos & refactoring (#1480) * Merge together proto files and refactor Async inference * Fixup for Async inference * Drop not reuqired changes * Fix tests * Drop old async files * Drop chunk_size param * Fix versions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix wrong fix Co-authored-by: Ben Zhang * Fixup --------- Co-authored-by: Michel Aractingi Co-authored-by: Ben Zhang Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- pyproject.toml | 6 +- src/lerobot/scripts/server/helpers.py | 86 ------ src/lerobot/scripts/server/policy_server.py | 34 +-- src/lerobot/scripts/server/robot_client.py | 25 +- src/lerobot/transport/async_inference.proto | 59 ---- src/lerobot/transport/async_inference_pb2.py | 45 --- .../transport/async_inference_pb2_grpc.py | 277 ------------------ src/lerobot/transport/services.proto | 28 ++ src/lerobot/transport/services_pb2.py | 28 +- src/lerobot/transport/services_pb2_grpc.py | 211 ++++++++++++- src/lerobot/transport/utils.py | 10 +- tests/async_inference/test_e2e.py | 8 +- 12 files changed, 299 insertions(+), 518 deletions(-) delete mode 100644 src/lerobot/transport/async_inference.proto delete mode 100644 src/lerobot/transport/async_inference_pb2.py delete mode 100644 src/lerobot/transport/async_inference_pb2_grpc.py diff --git a/pyproject.toml b/pyproject.toml index ec2598973..7a0ad1480 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency -grpcio-dep = ["grpcio==1.71.0"] +grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0"] @@ -119,14 +119,14 @@ intelrealsense = [ # Policies pi0 = ["lerobot[transformers-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "protobuf>=5.29.3", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] # Development docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "grpcio-tools==1.71.0"] +dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/scripts/server/helpers.py index 7fd56e693..d8051b76e 100644 --- a/src/lerobot/scripts/server/helpers.py +++ b/src/lerobot/scripts/server/helpers.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io import logging import logging.handlers import os import time from dataclasses import dataclass from pathlib import Path -from threading import Event -from typing import Any import torch @@ -31,8 +28,6 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot -from lerobot.transport import async_inference_pb2 -from lerobot.transport.utils import bytes_buffer_size from lerobot.utils.utils import init_logging Action = torch.Tensor @@ -303,84 +298,3 @@ def observations_similar( ) return _compare_observation_states(obs1_state, obs2_state, atol=atol) - - -def send_bytes_in_chunks( - buffer: bytes, - message_class: Any, - log_prefix: str = "", - silent: bool = True, - chunk_size: int = 3 * 1024 * 1024, -): - # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.send_bytes_in_chunks. Duplication can't be avoided if we - # don't use a unique class for messages sent (due to the different transfer states sent). Also, I'd want more control over the - # chunk size as I am using it to send image observations. - buffer = io.BytesIO(buffer) - size_in_bytes = bytes_buffer_size(buffer) - - sent_bytes = 0 - - logging_method = logging.info if not silent else logging.debug - - logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") - - while sent_bytes < size_in_bytes: - transfer_state = async_inference_pb2.TransferState.TRANSFER_MIDDLE - - if sent_bytes + chunk_size >= size_in_bytes: - transfer_state = async_inference_pb2.TransferState.TRANSFER_END - elif sent_bytes == 0: - transfer_state = async_inference_pb2.TransferState.TRANSFER_BEGIN - - size_to_read = min(chunk_size, size_in_bytes - sent_bytes) - chunk = buffer.read(size_to_read) - - yield message_class(transfer_state=transfer_state, data=chunk) - sent_bytes += size_to_read - logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") - - logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") - - -def receive_bytes_in_chunks( - iterator, continue_receiving: Event, logger: logging.Logger, log_prefix: str = "" -): # type: ignore - # NOTE(fracapuano): Partially copied from lerobot.common.transport.utils.receive_bytes_in_chunks. Duplication can't be avoided if we - # don't use a unique class for messages sent (due to the different transfer states sent). Also, on the server side the logic for receiving - # is opposite then the HIL-SERL design (my event showcases keeping on running instead of shutdown) - bytes_buffer = io.BytesIO() - step = 0 - - logger.info(f"{log_prefix} Starting receiver") - for item in iterator: - logger.debug(f"{log_prefix} Received item") - if not continue_receiving.is_set(): - logger.info(f"{log_prefix} Shutting down receiver") - return - - if item.transfer_state == async_inference_pb2.TransferState.TRANSFER_BEGIN: - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - bytes_buffer.write(item.data) - logger.debug(f"{log_prefix} Received data at step 0") - - elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_MIDDLE: - bytes_buffer.write(item.data) - step += 1 - logger.debug(f"{log_prefix} Received data at step {step}") - - elif item.transfer_state == async_inference_pb2.TransferState.TRANSFER_END: - bytes_buffer.write(item.data) - logger.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") - - complete_bytes = bytes_buffer.getvalue() - - bytes_buffer.seek(0) - bytes_buffer.truncate(0) - - logger.debug(f"{log_prefix} Queue updated") - return complete_bytes - - else: - logger.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") - raise ValueError(f"Received unknown transfer state {item.transfer_state}") diff --git a/src/lerobot/scripts/server/policy_server.py b/src/lerobot/scripts/server/policy_server.py index 13ba976e2..0ed446d3a 100644 --- a/src/lerobot/scripts/server/policy_server.py +++ b/src/lerobot/scripts/server/policy_server.py @@ -49,21 +49,21 @@ from lerobot.scripts.server.helpers import ( get_logger, observations_similar, raw_observation_to_observation, - receive_bytes_in_chunks, ) from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) +from lerobot.transport.utils import receive_bytes_in_chunks -class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): +class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): prefix = "policy_server" logger = get_logger(prefix) def __init__(self, config: PolicyServerConfig): self.config = config - self._running_event = threading.Event() + self.shutdown_event = threading.Event() # FPS measurement self.fps_tracker = FPSTracker(target_fps=config.fps) @@ -84,7 +84,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): @property def running(self): - return self._running_event.is_set() + return not self.shutdown_event.is_set() @property def policy_image_features(self): @@ -93,7 +93,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): def _reset_server(self) -> None: """Flushes server state when new client connects.""" # only running inference on the latest observation received by the server - self._running_event.clear() + self.shutdown_event.set() self.observation_queue = Queue(maxsize=1) with self._predicted_timesteps_lock: @@ -103,16 +103,16 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): client_id = context.peer() self.logger.info(f"Client {client_id} connected and ready") self._reset_server() - self._running_event.set() + self.shutdown_event.clear() - return async_inference_pb2.Empty() + return services_pb2.Empty() def SendPolicyInstructions(self, request, context): # noqa: N802 """Receive policy instructions from the robot client""" if not self.running: self.logger.warning("Server is not running. Ignoring policy instructions.") - return async_inference_pb2.Empty() + return services_pb2.Empty() client_id = context.peer() @@ -149,7 +149,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") - return async_inference_pb2.Empty() + return services_pb2.Empty() def SendObservations(self, request_iterator, context): # noqa: N802 """Receive observations from the robot client""" @@ -159,7 +159,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): receive_time = time.time() # comparing timestamps so need time.time() start_deserialize = time.perf_counter() received_bytes = receive_bytes_in_chunks( - request_iterator, self._running_event, self.logger + request_iterator, None, self.shutdown_event, self.logger ) # blocking call while looping over request_iterator timed_observation = pickle.loads(received_bytes) # nosec deserialize_time = time.perf_counter() - start_deserialize @@ -190,7 +190,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): ): self.logger.info(f"Observation #{obs_timestep} has been filtered out") - return async_inference_pb2.Empty() + return services_pb2.Empty() def GetActions(self, request, context): # noqa: N802 """Returns actions to the robot client. Actions are sent as a single @@ -218,7 +218,7 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): serialize_time = time.perf_counter() - start_time # Create and return the action chunk - actions = async_inference_pb2.Actions(data=actions_bytes) + actions = services_pb2.Actions(data=actions_bytes) self.logger.info( f"Action chunk #{obs.get_timestep()} generated | " @@ -239,12 +239,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer): return actions except Empty: # no observation added to queue in obs_queue_timeout - return async_inference_pb2.Empty() + return services_pb2.Empty() except Exception as e: self.logger.error(f"Error in StreamActions: {e}") - return async_inference_pb2.Empty() + return services_pb2.Empty() def _obs_sanity_checks(self, obs: TimedObservation, previous_obs: TimedObservation) -> bool: """Check if the observation is valid to be processed by the policy""" @@ -388,7 +388,7 @@ def serve(cfg: PolicyServerConfig): # Setup and start gRPC server server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) - async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) server.add_insecure_port(f"{cfg.host}:{cfg.port}") policy_server.logger.info(f"PolicyServer started on {cfg.host}:{cfg.port}") diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index 68166de6f..0599e068e 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -69,15 +69,14 @@ from lerobot.scripts.server.helpers import ( TimedObservation, get_logger, map_robot_keys_to_lerobot_features, - send_bytes_in_chunks, validate_robot_cameras_for_policy, visualize_action_queue_size, ) from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) -from lerobot.transport.utils import grpc_channel_options +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks class RobotClient: @@ -118,10 +117,10 @@ class RobotClient: self.channel = grpc.insecure_channel( self.server_address, grpc_channel_options(initial_backoff=f"{config.environment_dt:.4f}s") ) - self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel) + self.stub = services_pb2_grpc.AsyncInferenceStub(self.channel) self.logger.info(f"Initializing client to connect to server at {self.server_address}") - self._running_event = threading.Event() + self.shutdown_event = threading.Event() # Initialize client side variables self.latest_action_lock = threading.Lock() @@ -146,20 +145,20 @@ class RobotClient: @property def running(self): - return self._running_event.is_set() + return not self.shutdown_event.is_set() def start(self): """Start the robot client and connect to the policy server""" try: # client-server handshake start_time = time.perf_counter() - self.stub.Ready(async_inference_pb2.Empty()) + self.stub.Ready(services_pb2.Empty()) end_time = time.perf_counter() self.logger.debug(f"Connected to policy server in {end_time - start_time:.4f}s") # send policy instructions policy_config_bytes = pickle.dumps(self.policy_config) - policy_setup = async_inference_pb2.PolicySetup(data=policy_config_bytes) + policy_setup = services_pb2.PolicySetup(data=policy_config_bytes) self.logger.info("Sending policy instructions to policy server") self.logger.debug( @@ -170,7 +169,7 @@ class RobotClient: self.stub.SendPolicyInstructions(policy_setup) - self._running_event.set() + self.shutdown_event.clear() return True @@ -180,7 +179,7 @@ class RobotClient: def stop(self): """Stop the robot client""" - self._running_event.clear() + self.shutdown_event.set() self.robot.disconnect() self.logger.debug("Robot disconnected") @@ -208,7 +207,7 @@ class RobotClient: try: observation_iterator = send_bytes_in_chunks( observation_bytes, - async_inference_pb2.Observation, + services_pb2.Observation, log_prefix="[CLIENT] Observation", silent=True, ) @@ -283,7 +282,7 @@ class RobotClient: while self.running: try: # Use StreamActions to get a stream of actions from the server - actions_chunk = self.stub.GetActions(async_inference_pb2.Empty()) + actions_chunk = self.stub.GetActions(services_pb2.Empty()) if len(actions_chunk.data) == 0: continue # received `Empty` from server, wait for next call diff --git a/src/lerobot/transport/async_inference.proto b/src/lerobot/transport/async_inference.proto deleted file mode 100644 index 434f3142b..000000000 --- a/src/lerobot/transport/async_inference.proto +++ /dev/null @@ -1,59 +0,0 @@ -// fmt: off -// flake8: noqa -// !/usr/bin/env python - -// Copyright 2024 The HuggingFace Inc. team. -// All rights reserved. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at - -// http://www.apache.org/licenses/LICENSE-2.0 - -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -syntax = "proto3"; - -package async_inference; - -// AsyncInference: from Robot perspective -// Robot send observations to & executes action received from a remote Policy server -service AsyncInference { - // Robot -> Policy to share observations with a remote inference server - // Policy -> Robot to share actions predicted for given observations - rpc SendObservations(stream Observation) returns (Empty); - rpc GetActions(Empty) returns (Actions); - rpc SendPolicyInstructions(PolicySetup) returns (Empty); - rpc Ready(Empty) returns (Empty); - rpc Stop(Empty) returns (Empty); -} - -enum TransferState { - TRANSFER_UNKNOWN = 0; - TRANSFER_BEGIN = 1; - TRANSFER_MIDDLE = 2; - TRANSFER_END = 3; -} - -// Messages -message Observation { - // sent by Robot, to remote Policy - TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size - bytes data = 2; -} - -message Actions { - // sent by remote Policy, to Robot - bytes data = 1; -} - -message PolicySetup { - // sent by Robot to remote server, to init Policy - bytes data = 1; -} - -message Empty {} diff --git a/src/lerobot/transport/async_inference_pb2.py b/src/lerobot/transport/async_inference_pb2.py deleted file mode 100644 index 59c8eb488..000000000 --- a/src/lerobot/transport/async_inference_pb2.py +++ /dev/null @@ -1,45 +0,0 @@ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE -# source: async_inference.proto -# Protobuf Python Version: 5.29.0 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 5, - 29, - 0, - '', - 'async_inference.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x61sync_inference.proto\x12\x0f\x61sync_inference\"S\n\x0bObservation\x12\x36\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x1e.async_inference.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xdd\x02\n\x0e\x41syncInference\x12J\n\x10SendObservations\x12\x1c.async_inference.Observation\x1a\x16.async_inference.Empty(\x01\x12>\n\nGetActions\x12\x16.async_inference.Empty\x1a\x18.async_inference.Actions\x12N\n\x16SendPolicyInstructions\x12\x1c.async_inference.PolicySetup\x1a\x16.async_inference.Empty\x12\x37\n\x05Ready\x12\x16.async_inference.Empty\x1a\x16.async_inference.Empty\x12\x36\n\x04Stop\x12\x16.async_inference.Empty\x1a\x16.async_inference.Emptyb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'async_inference_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=190 - _globals['_TRANSFERSTATE']._serialized_end=286 - _globals['_OBSERVATION']._serialized_start=42 - _globals['_OBSERVATION']._serialized_end=125 - _globals['_ACTIONS']._serialized_start=127 - _globals['_ACTIONS']._serialized_end=150 - _globals['_POLICYSETUP']._serialized_start=152 - _globals['_POLICYSETUP']._serialized_end=179 - _globals['_EMPTY']._serialized_start=181 - _globals['_EMPTY']._serialized_end=188 - _globals['_ASYNCINFERENCE']._serialized_start=289 - _globals['_ASYNCINFERENCE']._serialized_end=638 -# @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/async_inference_pb2_grpc.py b/src/lerobot/transport/async_inference_pb2_grpc.py deleted file mode 100644 index 3042db0db..000000000 --- a/src/lerobot/transport/async_inference_pb2_grpc.py +++ /dev/null @@ -1,277 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from lerobot.transport import async_inference_pb2 as async__inference__pb2 - -GRPC_GENERATED_VERSION = '1.71.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in async_inference_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class AsyncInferenceStub: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.SendObservations = channel.stream_unary( - '/async_inference.AsyncInference/SendObservations', - request_serializer=async__inference__pb2.Observation.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.GetActions = channel.unary_unary( - '/async_inference.AsyncInference/GetActions', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Actions.FromString, - _registered_method=True) - self.SendPolicyInstructions = channel.unary_unary( - '/async_inference.AsyncInference/SendPolicyInstructions', - request_serializer=async__inference__pb2.PolicySetup.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.Ready = channel.unary_unary( - '/async_inference.AsyncInference/Ready', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - self.Stop = channel.unary_unary( - '/async_inference.AsyncInference/Stop', - request_serializer=async__inference__pb2.Empty.SerializeToString, - response_deserializer=async__inference__pb2.Empty.FromString, - _registered_method=True) - - -class AsyncInferenceServicer: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - def SendObservations(self, request_iterator, context): - """Robot -> Policy to share observations with a remote inference server - Policy -> Robot to share actions predicted for given observations - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetActions(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SendPolicyInstructions(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Ready(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Stop(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_AsyncInferenceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'SendObservations': grpc.stream_unary_rpc_method_handler( - servicer.SendObservations, - request_deserializer=async__inference__pb2.Observation.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'GetActions': grpc.unary_unary_rpc_method_handler( - servicer.GetActions, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Actions.SerializeToString, - ), - 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( - servicer.SendPolicyInstructions, - request_deserializer=async__inference__pb2.PolicySetup.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'Ready': grpc.unary_unary_rpc_method_handler( - servicer.Ready, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - 'Stop': grpc.unary_unary_rpc_method_handler( - servicer.Stop, - request_deserializer=async__inference__pb2.Empty.FromString, - response_serializer=async__inference__pb2.Empty.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'async_inference.AsyncInference', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('async_inference.AsyncInference', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class AsyncInference: - """AsyncInference: from Robot perspective - Robot send observations to & executes action received from a remote Policy server - """ - - @staticmethod - def SendObservations(request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_unary( - request_iterator, - target, - '/async_inference.AsyncInference/SendObservations', - async__inference__pb2.Observation.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetActions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/GetActions', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Actions.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SendPolicyInstructions(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/SendPolicyInstructions', - async__inference__pb2.PolicySetup.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Ready(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/Ready', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Stop(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/async_inference.AsyncInference/Stop', - async__inference__pb2.Empty.SerializeToString, - async__inference__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/src/lerobot/transport/services.proto b/src/lerobot/transport/services.proto index 70f39741f..ea0c12de6 100644 --- a/src/lerobot/transport/services.proto +++ b/src/lerobot/transport/services.proto @@ -33,6 +33,17 @@ service LearnerService { rpc Ready(Empty) returns (Empty); } +// AsyncInference: from Robot perspective +// Robot send observations to & executes action received from a remote Policy server +service AsyncInference { + // Robot -> Policy to share observations with a remote inference server + // Policy -> Robot to share actions predicted for given observations + rpc SendObservations(stream Observation) returns (Empty); + rpc GetActions(Empty) returns (Actions); + rpc SendPolicyInstructions(PolicySetup) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + enum TransferState { TRANSFER_UNKNOWN = 0; TRANSFER_BEGIN = 1; @@ -56,4 +67,21 @@ message InteractionMessage { bytes data = 2; } +// Messages +message Observation { + // sent by Robot, to remote Policy + TransferState transfer_state = 1; // Observations can be streamed exceeding 4MB of size + bytes data = 2; +} + +message Actions { + // sent by remote Policy, to Robot + bytes data = 1; +} + +message PolicySetup { + // sent by Robot to remote server, to init Policy + bytes data = 1; +} + message Empty {} diff --git a/src/lerobot/transport/services_pb2.py b/src/lerobot/transport/services_pb2.py index 9e66ae1e3..05f2d174f 100644 --- a/src/lerobot/transport/services_pb2.py +++ b/src/lerobot/transport/services_pb2.py @@ -1,7 +1,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # NO CHECKED-IN PROTOBUF GENCODE # source: lerobot/transport/services.proto -# Protobuf Python Version: 5.29.0 +# Protobuf Python Version: 6.31.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -10,8 +10,8 @@ from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder _runtime_version.ValidateProtobufRuntimeVersion( _runtime_version.Domain.PUBLIC, - 5, - 29, + 6, + 31, 0, '', 'lerobot/transport/services.proto' @@ -23,23 +23,31 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n lerobot/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"M\n\x0bObservation\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x17\n\x07\x41\x63tions\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x1b\n\x0bPolicySetup\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Empty2\xf5\x01\n\x0e\x41syncInference\x12>\n\x10SendObservations\x12\x16.transport.Observation\x1a\x10.transport.Empty(\x01\x12\x32\n\nGetActions\x12\x10.transport.Empty\x1a\x12.transport.Actions\x12\x42\n\x16SendPolicyInstructions\x12\x16.transport.PolicySetup\x1a\x10.transport.Empty\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.transport.services_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_TRANSFERSTATE']._serialized_start=298 - _globals['_TRANSFERSTATE']._serialized_end=394 + _globals['_TRANSFERSTATE']._serialized_start=431 + _globals['_TRANSFERSTATE']._serialized_end=527 _globals['_TRANSITION']._serialized_start=47 _globals['_TRANSITION']._serialized_end=123 _globals['_PARAMETERS']._serialized_start=125 _globals['_PARAMETERS']._serialized_end=201 _globals['_INTERACTIONMESSAGE']._serialized_start=203 _globals['_INTERACTIONMESSAGE']._serialized_end=287 - _globals['_EMPTY']._serialized_start=289 - _globals['_EMPTY']._serialized_end=296 - _globals['_LEARNERSERVICE']._serialized_start=397 - _globals['_LEARNERSERVICE']._serialized_end=654 + _globals['_OBSERVATION']._serialized_start=289 + _globals['_OBSERVATION']._serialized_end=366 + _globals['_ACTIONS']._serialized_start=368 + _globals['_ACTIONS']._serialized_end=391 + _globals['_POLICYSETUP']._serialized_start=393 + _globals['_POLICYSETUP']._serialized_end=420 + _globals['_EMPTY']._serialized_start=422 + _globals['_EMPTY']._serialized_end=429 + _globals['_LEARNERSERVICE']._serialized_start=530 + _globals['_LEARNERSERVICE']._serialized_end=787 + _globals['_ASYNCINFERENCE']._serialized_start=790 + _globals['_ASYNCINFERENCE']._serialized_end=1035 # @@protoc_insertion_point(module_scope) diff --git a/src/lerobot/transport/services_pb2_grpc.py b/src/lerobot/transport/services_pb2_grpc.py index 77801a340..35a01b675 100644 --- a/src/lerobot/transport/services_pb2_grpc.py +++ b/src/lerobot/transport/services_pb2_grpc.py @@ -5,7 +5,7 @@ import warnings from lerobot.transport import services_pb2 as lerobot_dot_transport_dot_services__pb2 -GRPC_GENERATED_VERSION = '1.71.0' +GRPC_GENERATED_VERSION = '1.73.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -231,3 +231,212 @@ class LearnerService: timeout, metadata, _registered_method=True) + + +class AsyncInferenceStub: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendObservations = channel.stream_unary( + '/transport.AsyncInference/SendObservations', + request_serializer=lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.GetActions = channel.unary_unary( + '/transport.AsyncInference/GetActions', + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Actions.FromString, + _registered_method=True) + self.SendPolicyInstructions = channel.unary_unary( + '/transport.AsyncInference/SendPolicyInstructions', + request_serializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/transport.AsyncInference/Ready', + request_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + + +class AsyncInferenceServicer: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + def SendObservations(self, request_iterator, context): + """Robot -> Policy to share observations with a remote inference server + Policy -> Robot to share actions predicted for given observations + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetActions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendPolicyInstructions(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AsyncInferenceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendObservations': grpc.stream_unary_rpc_method_handler( + servicer.SendObservations, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Observation.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'GetActions': grpc.unary_unary_rpc_method_handler( + servicer.GetActions, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Actions.SerializeToString, + ), + 'SendPolicyInstructions': grpc.unary_unary_rpc_method_handler( + servicer.SendPolicyInstructions, + request_deserializer=lerobot_dot_transport_dot_services__pb2.PolicySetup.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=lerobot_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'transport.AsyncInference', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('transport.AsyncInference', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AsyncInference: + """AsyncInference: from Robot perspective + Robot send observations to & executes action received from a remote Policy server + """ + + @staticmethod + def SendObservations(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.AsyncInference/SendObservations', + lerobot_dot_transport_dot_services__pb2.Observation.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetActions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/GetActions', + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Actions.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendPolicyInstructions(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/SendPolicyInstructions', + lerobot_dot_transport_dot_services__pb2.PolicySetup.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.AsyncInference/Ready', + lerobot_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/lerobot/transport/utils.py b/src/lerobot/transport/utils.py index bf1aab755..5c9f702fc 100644 --- a/src/lerobot/transport/utils.py +++ b/src/lerobot/transport/utils.py @@ -19,7 +19,8 @@ import io import json import logging import pickle # nosec B403: Safe usage for internal serialization only -from multiprocessing import Event, Queue +from multiprocessing import Event +from queue import Queue from typing import Any import torch @@ -66,7 +67,7 @@ def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "" logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") -def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore +def receive_bytes_in_chunks(iterator, queue: Queue | None, shutdown_event: Event, log_prefix: str = ""): bytes_buffer = io.BytesIO() step = 0 @@ -91,7 +92,10 @@ def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_p bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") - queue.put(bytes_buffer.getvalue()) + if queue is not None: + queue.put(bytes_buffer.getvalue()) + else: + return bytes_buffer.getvalue() bytes_buffer.seek(0) bytes_buffer.truncate(0) diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index d7b68e66b..1c0400e66 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -54,8 +54,8 @@ def test_async_inference_e2e(monkeypatch): from lerobot.scripts.server.policy_server import PolicyServer from lerobot.scripts.server.robot_client import RobotClient from lerobot.transport import ( - async_inference_pb2, # type: ignore - async_inference_pb2_grpc, # type: ignore + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore ) from tests.mocks.mock_robot import MockRobotConfig @@ -113,13 +113,13 @@ def test_async_inference_e2e(monkeypatch): # Bypass potentially heavy model loading inside SendPolicyInstructions def _fake_send_policy_instructions(self, request, context): # noqa: N802 - return async_inference_pb2.Empty() + return services_pb2.Empty() monkeypatch.setattr(PolicyServer, "SendPolicyInstructions", _fake_send_policy_instructions, raising=True) # Build gRPC server running a PolicyServer server = grpc.server(futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="policy_server")) - async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) + services_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server) # Use the host/port specified in the fixture's config server_address = f"{policy_server.config.host}:{policy_server.config.port}" From 4c8f0020551bc6ba30ac2d7f54906aeed55ab85d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 24 Jul 2025 17:09:12 +0200 Subject: [PATCH 035/158] fix(act): disable VAE during offline inference (#1588) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prevent VAE inference when running in offline mode. In the lerobot dataset, the presence of the 'action' field incorrectly triggers the VAE inference block. This leads to a RuntimeError due to mismatched tensor dimensions (3 vs 2) when concatenating cls_embed, robot_state_embed, and action_embed—since action_embed lacks the chunk_size dimension. Additionally, this aligns with the original paper, where variational inference is skipped during inference. --- src/lerobot/policies/act/modeling_act.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index 4a048e63d..cfd549b25 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -420,7 +420,7 @@ class ACT(nn.Module): batch_size = batch["observation.environment_state"].shape[0] # Prepare the latent for input to the transformer encoder. - if self.config.use_vae and "action" in batch: + if self.config.use_vae and "action" in batch and self.training: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size From d4f962fb34ae3bb2265fb241b50c9e3f6e85a798 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 25 Jul 2025 12:06:46 +0200 Subject: [PATCH 036/158] feat(ci): add entrypoints + add version checks + add minimal release testing + uncomment publishing to pypi (#1589) --- .github/workflows/full_tests.yml | 2 + .github/workflows/release.yml | 77 +++++++++++++++++++++++++------- pyproject.toml | 11 +++++ 3 files changed, 74 insertions(+), 16 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 55d38883a..d16fe5e72 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -206,3 +206,5 @@ jobs: echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE" exit 1 fi + +# TODO(Steven): Check dockerimages pull in ubuntu diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7d80ac5af..32c1c605a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,12 +20,12 @@ on: - 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0 jobs: - # TODO(Steven): Publish draft/pre-release and to test pypi - # TODO(Steven): Tag documentation with the same version as the package - # TODO(Steven): Define entry points for main CLI scripts + # This job builds the Python package and publishes it to PyPI build-and-publish: name: Build and publish Python distributions runs-on: ubuntu-latest + outputs: + version: ${{ steps.extract_info.outputs.tag_version }} permissions: contents: write id-token: write @@ -41,26 +41,34 @@ jobs: with: python-version: '3.10' - - name: Extract Version and Package Name + - name: Extract Version id: extract_info + # Extract version from tag (e.g., v0.1.0 -> 0.1.0) # zizmor: ignore[template-injection] run: | - # Extract version from tag (e.g., v0.1.0 -> 0.1.0) VERSION=${{ github.ref_name }} VERSION_NUMBER=${VERSION#v} echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT + - name: Check if version matches pyproject.toml + # zizmor: ignore[template-injection] + run: | + TAG_VERSION=${{ steps.extract_info.outputs.tag_version }} - # Extract package name from pyproject.toml - PACKAGE_NAME=$(grep -oP 'name = "\K[^"]+' pyproject.toml) - echo "package_name=$PACKAGE_NAME" >> $GITHUB_OUTPUT + PYPROJECT_VERSION=$(grep '^version = ' pyproject.toml | awk -F' = ' '{print $2}' | tr -d '"') + + if [[ "$TAG_VERSION" != "$PYPROJECT_VERSION" ]]; then + echo "Error: Tag version ($TAG_VERSION) does not match pyproject.toml version ($PYPROJECT_VERSION)." >&2 + exit 1 + else + echo "Tag version matches pyproject.toml version: $TAG_VERSION. Proceeding with release." + fi - name: Check if version exists on PyPI # zizmor: ignore[template-injection] run: | - PACKAGE_NAME=${{ steps.extract_info.outputs.package_name }} NEW_VERSION=${{ steps.extract_info.outputs.tag_version }} - response=$(curl -s "https://pypi.org/pypi/$PACKAGE_NAME/$NEW_VERSION/json") + response=$(curl -s "https://pypi.org/pypi/lerobot/$NEW_VERSION/json") if echo "$response" | grep -q "message"; then echo "Version $NEW_VERSION is available on PyPI. Proceeding with release." else @@ -80,9 +88,46 @@ jobs: # zizmor: ignore[template-injection] run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/* - # TODO(Steven): Uncomment when ready to publish to PyPI - # - name: Publish to PyPI - # if: startsWith(github.ref, 'refs/tags/v') - # uses: pypa/gh-action-pypi-publish@v1.12.4 - # with: - # password: ${{ secrets.PYPI_API_TOKEN }} + - name: Publish to PyPI + if: startsWith(github.ref, 'refs/tags/v') + uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing] + with: + password: ${{ secrets.PYPI_API_TOKEN }} + + # This job runs end-to-end tests on the release + test-release: + name: Test Release + needs: [build-and-publish] + runs-on: ubuntu-latest + permissions: + contents: read + env: + MUJOCO_GL: egl + steps: + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + - name: Install apt dependencies + run: | + sudo apt-get update && sudo apt-get install -y build-essential \ + git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ + speech-dispatcher libgeos-dev portaudio19-dev + - name: Setup uv and Python + uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] + with: + enable-cache: true + version: ${{ env.UV_VERSION }} + python-version: ${{ env.PYTHON_VERSION }} + - name: Install lerobot release + run: uv run pip install lerobot==${{ needs.build-and-publish.outputs.version }} # zizmor: ignore[template-injection] + + - name: Check lerobot version + run: uv run lerobot --version + + - name: Run end-to-end tests + run: uv run make test-end-to-end + + +# TODO(Steven): Publish draft/pre-release and to test pypi +# TODO(Steven): Tag documentation with the same version as the package diff --git a/pyproject.toml b/pyproject.toml index 7a0ad1480..a05d5c24d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,17 @@ all = [ "lerobot[xarm]" ] +[project.scripts] +lerobot-calibrate="lerobot.calibrate:main" +lerobot-find-cameras="lerobot.find_cameras:main" +lerobot-find-port="lerobot.find_port:main" +lerobot-record="lerobot.record:main" +lerobot-replay="lerobot.replay:main" +lerobot-setup-motors="lerobot.setup_motors:main" +lerobot-teleoperate="lerobot.teleoperate:main" +lerobot-eval="lerobot.scripts.eval:main" +lerobot-train="lerobot.scripts.train:main" + # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] where = ["src"] From b2a71c6fe4e04aacf1d4767b067085ca81747949 Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Fri, 25 Jul 2025 21:08:00 +0800 Subject: [PATCH 037/158] fix: Rename sync_cache_first to force_cache_sync in LeRobotDataset docstring (#1310) --- src/lerobot/datasets/lerobot_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 72d1a722c..617ac297f 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -433,7 +433,7 @@ class LeRobotDataset(torch.utils.data.Dataset): multiples of 1/fps. Defaults to 1e-4. revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a commit hash. Defaults to current codebase version tag. - sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files + force_cache_sync (bool, optional): Flag to sync and refresh local files first. If True and files are already present in the local cache, this will be faster. However, files loaded might not be in sync with the version on the hub, especially if you specified 'revision'. Defaults to False. From dacd1d7f5c719c3e56d7b7154a751bef6d5bd23c Mon Sep 17 00:00:00 2001 From: arulloomba1 <145633197+arulloomba1@users.noreply.github.com> Date: Fri, 25 Jul 2025 07:44:43 -0700 Subject: [PATCH 038/158] Fixing all broken links in integrate_hardware document (#1445) Signed-off-by: arulloomba1 <145633197+arulloomba1@users.noreply.github.com> --- docs/source/hilserl.mdx | 2 +- docs/source/il_robots.mdx | 2 +- docs/source/integrate_hardware.mdx | 18 +++++++++--------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index c647a58d5..f66d8cab7 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -477,7 +477,7 @@ Create a training configuration file (example available [here](https://huggingfa 1. Configure the policy settings (`type="sac"`, `device`, etc.) 2. Set `dataset` to your cropped dataset 3. Configure environment settings with crop parameters -4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/policies/sac/configuration_sac.py#L79). +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). 5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. **Starting the Learner** diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index b18adb8f4..ccca6d508 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -485,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \ ## Run inference and evaluate your policy -You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes: +You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes: diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 089126fcb..7b2e3833f 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -2,23 +2,23 @@ This tutorial will explain how to integrate your own robot design into the LeRobot ecosystem and have it access all of our tools (data collection, control pipelines, policy training and inference). -To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it. +To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it. ## Prerequisites - Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) - A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. -- LeRobot installed in your environment. Follow our [Installation Guide](./installation). +- LeRobot installed in your environment. Follow our [Installation Guide](./installation.mdx). ## Choose your motors If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus interfaces: -- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/feetech.py) – for controlling Feetech servos -- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos +- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/feetech.py) – for controlling Feetech servos +- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos -Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/motors_bus.py) abstract class to learn about its API. -For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/robots/so101_follower/so101_follower.py) +Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API. +For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py) Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): @@ -29,7 +29,7 @@ You're not alone—many community contributions use custom boards or firmware! For Feetech and Dynamixel, we currently support these servos: - Feetech: - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` - SCS series (protocol 1): `scs0009` - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150` -If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do. +If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do. In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for the examples. Replace it and adapt to your motors if necessary. @@ -65,7 +65,7 @@ class MyCoolRobotConfig(RobotConfig): ``` -Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera. +[Cameras tutorial](./cameras.mdx) to understand how to detect and add your camera. Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. @@ -331,7 +331,7 @@ def send_action(self, action: dict[str, Any]) -> dict[str, Any]: ## Adding a Teleoperator -For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor. +For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor. The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods. From f089ab3628dec08bf10dd954996c310700d20e5e Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Mon, 28 Jul 2025 11:09:18 +0200 Subject: [PATCH 039/158] fix(hf hub dependency): adding ceiling version on huggingface_hub (#1608) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a05d5c24d..5080cd890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dependencies = [ # Hugging Face dependencies "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency "diffusers>=0.27.2", - "huggingface-hub[hf-transfer,cli]>=0.27.1", + "huggingface-hub[hf-transfer,cli]>=0.27.1,<0.34.0", # Core dependencies "cmake>=3.29.0.1", From 615adfc48d60a8ecb9e1891c773405268770e414 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 28 Jul 2025 11:44:22 +0200 Subject: [PATCH 040/158] smolfix(vla): typing and fix offline inference when action in the batch (#1597) --- src/lerobot/policies/smolvla/modeling_smolvla.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index a31e1b078..d2f78068c 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -384,8 +384,13 @@ class SmolVLAPolicy(PreTrainedPolicy): return self.parameters() def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + # TODO: Check if this for loop is needed. + # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch + # In the case of offline inference, we have the action in the batch + # that why without the k != ACTION check, it will raise an error because we are trying to stack + # on an empty container. for k in batch: - if k in self._queues: + if k in self._queues and k != ACTION: batch[k] = torch.stack(list(self._queues[k]), dim=1) images, img_masks = self.prepare_images(batch) @@ -631,7 +636,7 @@ class VLAFlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: SmolVLAConfig): super().__init__() self.config = config From 98746c7cf9dd6fbab7434ea07ceeadcba989eb4b Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 28 Jul 2025 11:45:30 +0200 Subject: [PATCH 041/158] bump wandb version to be compatible with ne grpcio-deps (#1604) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5080cd890..7cd516920 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ dependencies = [ "packaging>=24.2", "pynput>=1.7.7", "pyserial>=3.5", - "wandb>=0.16.3", + "wandb>=0.20.0", "draccus==0.10.0", # TODO: Remove == "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency From b61a4ded9aaca054164451d72b20e7d8a6528dbf Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 28 Jul 2025 11:49:05 +0200 Subject: [PATCH 042/158] chore(pi0fast): TODO comment to warn the need for removal ignore_index (#1593) --- src/lerobot/policies/pi0fast/modeling_pi0fast.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index d3903066c..80e10bc02 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -488,6 +488,8 @@ class PI0FAST(nn.Module): param.data = param.data.to(dtype=torch_precision) self.set_requires_grad() self.image_keys = self.config.image_features.keys() + # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed + # AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index' self.ignore_index = self.pi0_paligemma.config.ignore_index self.padding_side = self.config.padding_side From 664e069c3f03fbc142a1ef16106c9b929bd6e5c5 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Mon, 28 Jul 2025 12:55:47 +0200 Subject: [PATCH 043/158] docs/style: updating docs and deprecated links (#1584) --- docs/source/hilserl.mdx | 4 ++-- docs/source/il_robots.mdx | 2 +- docs/source/lekiwi.mdx | 2 +- src/lerobot/__init__.py | 6 +++--- src/lerobot/robots/viperx/README.md | 2 +- tests/async_inference/test_robot_client.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index f66d8cab7..2f73d0964 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -28,7 +28,7 @@ This guide provides step-by-step instructions for training a robot policy using - A gamepad (recommended) or keyboard to control the robot - A Nvidia GPU - A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad) -- A URDF file for the robot for the kinematics package (check `lerobot/common/model/kinematics.py`) +- A URDF file for the robot for the kinematics package (check `lerobot/model/kinematics.py`) ## What kind of tasks can I train? @@ -477,7 +477,7 @@ Create a training configuration file (example available [here](https://huggingfa 1. Configure the policy settings (`type="sac"`, `device`, etc.) 2. Set `dataset` to your cropped dataset 3. Configure environment settings with crop parameters -4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79). 5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. **Starting the Learner** diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index ccca6d508..de80b1fcd 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -323,7 +323,7 @@ The `record` function provides a suite of tools for capturing and managing data ##### 2. Checkpointing and Resuming - Checkpoints are automatically created during recording. -- If an issue occurs, you can resume by re-running the same command with `--resume=true`. +- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset ! - To start recording from scratch, **manually delete** the dataset directory. ##### 3. Recording Parameters diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index bb70fd26b..a5bdb19cf 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -258,7 +258,7 @@ You should see on your laptop something like this: `[INFO] Connected to remote r | F | Decrease speed | > [!TIP] -> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). +> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiClientConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/lekiwi/config_lekiwi.py). ### Wired version diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 38d4e8644..9d3ed1893 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -170,7 +170,7 @@ available_datasets = sorted( # lists all available policies from `lerobot/policies` available_policies = ["act", "diffusion", "tdmpc", "vqbet"] -# lists all available robots from `lerobot/robot_devices/robots` +# lists all available robots from `lerobot/robots` available_robots = [ "koch", "koch_bimanual", @@ -179,13 +179,13 @@ available_robots = [ "so101", ] -# lists all available cameras from `lerobot/robot_devices/cameras` +# lists all available cameras from `lerobot/cameras` available_cameras = [ "opencv", "intelrealsense", ] -# lists all available motors from `lerobot/robot_devices/motors` +# lists all available motors from `lerobot/motors` available_motors = [ "dynamixel", "feetech", diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index 4e90c99c7..5cdb152a2 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -63,7 +63,7 @@ python lerobot/scripts/control_robot.py \ --control.type=teleoperate ``` -By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: +By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: ```bash python lerobot/scripts/control_robot.py \ diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index d1273ae63..51db2c3a7 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -13,7 +13,7 @@ # limitations under the License. """Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC). -We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that +We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that no real hardware is accessed. Only the queue-update mechanism is verified. """ From c3d5e494c0b530368332c0b0eb114c32bc3b8f2c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 28 Jul 2025 13:10:34 +0200 Subject: [PATCH 044/158] fix(policies): remove action from batch for offline evaluation (#1609) * fix(policies): remove action from batch for offline evaluation in diffusion, tdmpc, and vqbet policies * style(diffusion): correct comment capitalization for clarity in modeling_diffusion.py --- src/lerobot/policies/diffusion/modeling_diffusion.py | 6 +++++- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 5 +++++ src/lerobot/policies/vqbet/modeling_vqbet.py | 7 +++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 24b273967..941a3acb5 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy): "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + # NOTE: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) if len(self._queues[ACTION]) == 0: diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 664fe863d..7ba88e5e6 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) + if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index b271298a3..feb65bb4c 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + # NOTE: It's important that this happens after stacking the images into a single key. batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + self._queues = populate_queues(self._queues, batch) if not self.vqbet.action_head.vqvae_model.discretized.item(): From 4b88842d20c3872674a77a1cc06ca023b443bb9f Mon Sep 17 00:00:00 2001 From: Kleist Bond <61907235+KleistvonLiu@users.noreply.github.com> Date: Mon, 28 Jul 2025 21:17:30 +0800 Subject: [PATCH 045/158] fix bug about sampling time from beta distribution (#1605) * fix bug about sampling t from beta distribution * fix: address review comments --------- --- src/lerobot/policies/pi0/modeling_pi0.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 11feca964..e9e6014f8 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -515,9 +515,10 @@ class PI0FlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks From 7fe6adaf617c41bbb8a3a65d42a37f6e038184f4 Mon Sep 17 00:00:00 2001 From: Lumen Yang <45258158+LumenYoung@users.noreply.github.com> Date: Mon, 28 Jul 2025 15:22:37 +0200 Subject: [PATCH 046/158] fix(config): typing correction on config.py (#1320) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi --- src/lerobot/envs/configs.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index ef381e9e7..35797c6ed 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -44,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): @EnvConfig.register_subclass("aloha") @dataclass class AlohaEnv(EnvConfig): - task: str = "AlohaInsertion-v0" + task: str | None = "AlohaInsertion-v0" fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" @@ -82,7 +82,7 @@ class AlohaEnv(EnvConfig): @EnvConfig.register_subclass("pusht") @dataclass class PushtEnv(EnvConfig): - task: str = "PushT-v0" + task: str | None = "PushT-v0" fps: int = 10 episode_length: int = 300 obs_type: str = "pixels_agent_pos" @@ -124,7 +124,7 @@ class PushtEnv(EnvConfig): @EnvConfig.register_subclass("xarm") @dataclass class XarmEnv(EnvConfig): - task: str = "XarmLift-v0" + task: str | None = "XarmLift-v0" fps: int = 15 episode_length: int = 200 obs_type: str = "pixels_agent_pos" @@ -200,10 +200,10 @@ class HILSerlRobotEnvConfig(EnvConfig): wrapper: EnvTransformConfig | None = None fps: int = 10 name: str = "real_robot" - mode: str = None # Either "record", "replay", None + mode: str | None = None # Either "record", "replay", None repo_id: str | None = None dataset_root: str | None = None - task: str = "" + task: str | None = "" num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" @@ -213,6 +213,7 @@ class HILSerlRobotEnvConfig(EnvConfig): # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 + @property def gym_kwargs(self) -> dict: return {} @@ -222,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig): class HILEnvConfig(EnvConfig): """Configuration for the HIL environment.""" - type: str = "hil" name: str = "PandaPickCube" - task: str = "PandaPickCubeKeyboard-v0" + task: str | None = "PandaPickCubeKeyboard-v0" use_viewer: bool = True gripper_penalty: float = 0.0 use_gamepad: bool = True @@ -252,7 +252,7 @@ class HILEnvConfig(EnvConfig): robot_config: RobotConfig | None = None teleop_config: TeleoperatorConfig | None = None wrapper: EnvTransformConfig | None = None - mode: str = None # Either "record", "replay", None + mode: str | None = None # Either "record", "replay", None repo_id: str | None = None dataset_root: str | None = None num_episodes: int = 10 # only for record mode From b267cd40f7fba70c35c841cc9faa7e9788d4a4f6 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Mon, 28 Jul 2025 17:05:44 +0200 Subject: [PATCH 047/158] fix(tokenizers dependency): adding ceiling version on tokenizers (#1612) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7cd516920..2bce3ecbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dependencies = [ # Hugging Face dependencies "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency "diffusers>=0.27.2", - "huggingface-hub[hf-transfer,cli]>=0.27.1,<0.34.0", + "huggingface-hub[hf-transfer,cli]>=0.27.1,<0.34.0", # TODO: Bumb dependency # Core dependencies "cmake>=3.29.0.1", @@ -94,7 +94,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency +transformers-dep = ["transformers>=4.50.3,<4.52.0", "tokenizers<0.21.4"] # TODO: Bumb dependency, remove tokenizers dependency grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors From c7c3b477d6d39cf7046a7225eecc4e5debe67065 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 28 Jul 2025 17:28:55 +0200 Subject: [PATCH 048/158] Fix sample beta for smolvla as done for pi0, remove sample_beta func (#1611) --- src/lerobot/policies/pi0/modeling_pi0.py | 6 ------ src/lerobot/policies/smolvla/modeling_smolvla.py | 11 +++-------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index e9e6014f8..a34aa34f9 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -90,12 +90,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index d2f78068c..469645e84 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. @@ -690,9 +684,10 @@ class VLAFlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None From c14ab9e97be6a25ac8751c8a01c53d6296278405 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 29 Jul 2025 10:59:23 +0200 Subject: [PATCH 049/158] fix(dependencies): removing versions ceilings on tokenizers and huggingface_hub dependencies (#1618) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2bce3ecbe..a8680e39f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dependencies = [ # Hugging Face dependencies "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency "diffusers>=0.27.2", - "huggingface-hub[hf-transfer,cli]>=0.27.1,<0.34.0", # TODO: Bumb dependency + "huggingface-hub[hf-transfer,cli]>=0.34.2", # Core dependencies "cmake>=3.29.0.1", @@ -94,7 +94,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.50.3,<4.52.0", "tokenizers<0.21.4"] # TODO: Bumb dependency, remove tokenizers dependency +transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors From 5695432142c44f787ab6432f44faa8126932bda5 Mon Sep 17 00:00:00 2001 From: Abhay Deshpande Date: Tue, 29 Jul 2025 04:40:16 -0700 Subject: [PATCH 050/158] fix(DiffusionPolicy): Fix bug where training without image features would crash with exception, fix environment state docs (#1617) * Fix bug in diffusion config validation when not using image features * Fix DiffusionPolicy docstring about shape of env state --- .../policies/diffusion/configuration_diffusion.py | 13 +++++++------ .../policies/diffusion/modeling_diffusion.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index ce2de7052..54569434a 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig): ) # Check that all input images have the same shape. - first_image_key, first_image_ft = next(iter(self.image_features.items())) - for key, image_ft in self.image_features.items(): - if image_ft.shape != first_image_ft.shape: - raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." - ) + if len(self.image_features) > 0: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." + ) @property def observation_delta_indices(self) -> list: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 941a3acb5..85d4d5981 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -288,7 +288,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -315,7 +315,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) "action": (B, horizon, action_dim) "action_is_pad": (B, horizon) From 67196c9d5344cd932612cef79229f9d04134c91e Mon Sep 17 00:00:00 2001 From: Rayen Ghali Date: Tue, 29 Jul 2025 08:54:43 -0300 Subject: [PATCH 051/158] fix(180-degree rotation): Add `cv2.ROTATE_180` to rotation checks in both OpenCV and RealSense camera implementations --- src/lerobot/cameras/opencv/camera_opencv.py | 2 +- src/lerobot/cameras/realsense/camera_realsense.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 7ad9988cc..aad19819a 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -368,7 +368,7 @@ class OpenCVCamera(Camera): if requested_color_mode == ColorMode.RGB: processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: processed_image = cv2.rotate(processed_image, self.rotation) return processed_image diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 74b055fa4..918c5592e 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -434,7 +434,7 @@ class RealSenseCamera(Camera): if self.color_mode == ColorMode.BGR: processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: processed_image = cv2.rotate(processed_image, self.rotation) return processed_image From 71eff183ff1a286c9d3ad0962de3750a7ca00cb4 Mon Sep 17 00:00:00 2001 From: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com> Date: Wed, 30 Jul 2025 23:38:32 +0800 Subject: [PATCH 052/158] Fix pi0 checkpoint state map (#1415) Co-authored-by: Michel Aractingi --- src/lerobot/policies/pi0/modeling_pi0.py | 96 +++++++++++++++++++++++- src/lerobot/policies/pretrained.py | 33 +++++--- src/lerobot/policies/utils.py | 14 ++++ 3 files changed, 130 insertions(+), 13 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index a34aa34f9..e56946ac8 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertModel, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.utils import get_safe_dtype +from lerobot.policies.utils import log_model_loading_keys +from lerobot.utils.utils import get_safe_dtype, init_logging def create_sinusoidal_pos_embedding( @@ -252,6 +253,99 @@ class PI0Policy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def _transform_state_dict_keys(cls, state_dict: dict) -> dict: + """ + Transform state dict keys to match expected model structure. + + Transformations: + - model.paligemma_with_expert.paligemma.language_model.lm_head -> + model.paligemma_with_expert.paligemma.lm_head + - model.paligemma_with_expert.paligemma.language_model.model -> + model.paligemma_with_expert.paligemma.model.language_model + - model.paligemma_with_expert.paligemma.vision_tower -> + model.paligemma_with_expert.paligemma.model.vision_tower + - model.paligemma_with_expert.paligemma.multi_modal_projector -> + model.paligemma_with_expert.paligemma.model.multi_modal_projector + + Also handles tied weights between lm_head.weight and + embed_tokens.weight. + """ + import re + + transformed_dict = {} + + transformations = [ + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"), + ".paligemma_with_expert.paligemma.lm_head", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"), + ".paligemma_with_expert.paligemma.model.language_model", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"), + ".paligemma_with_expert.paligemma.model.vision_tower", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"), + ".paligemma_with_expert.paligemma.model.multi_modal_projector", + ), + ] + + for key, value in state_dict.items(): + new_key = key + for pattern, replacement in transformations: + new_key = pattern.sub(replacement, new_key) + transformed_dict[new_key] = value + + # Handle tied weights: lm_head.weight and embed_tokens.weight share memory + lm_head_key = None + embed_tokens_key = None + + for key in transformed_dict: + if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"): + lm_head_key = key + elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"): + embed_tokens_key = key + if lm_head_key and embed_tokens_key: + break + + if lm_head_key and not embed_tokens_key: + embed_tokens_key = lm_head_key.replace( + ".lm_head.weight", ".model.language_model.embed_tokens.weight" + ) + transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key] + elif embed_tokens_key and not lm_head_key: + lm_head_key = embed_tokens_key.replace( + ".model.language_model.embed_tokens.weight", ".lm_head.weight" + ) + transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key] + + return transformed_dict + + @classmethod + def _load_as_safetensor( + cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool + ) -> "PI0Policy": + """Override to apply key transformations before loading.""" + from safetensors.torch import load_file + + init_logging() + # Load the state dict from file safely + state_dict = load_file(model_file, device=map_location) + + # Apply key transformations + transformed_state_dict = cls._transform_state_dict_keys(state_dict) + + # Load the transformed state dict + msg = model.load_state_dict(transformed_state_dict, strict=strict) + + # Log message + log_model_loading_keys(msg.missing_keys, msg.unexpected_keys) + return model + def get_optim_params(self) -> dict: return self.parameters() diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index d745c901c..2f69309c1 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -30,6 +30,7 @@ from torch import Tensor, nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig +from lerobot.policies.utils import log_model_loading_keys from lerobot.utils.hub import HubMixin T = TypeVar("T", bound="PreTrainedPolicy") @@ -128,18 +129,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): @classmethod def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: - if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): - load_model_as_safetensor(model, model_file, strict=strict) - if map_location != "cpu": - logging.warning( - "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." - " This means that the model is loaded on 'cpu' first and then copied to the device." - " This leads to a slower loading time." - " Please update safetensors to version 0.4.3 or above for improved performance." - ) - model.to(map_location) - else: - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + # Create base kwargs + kwargs = {"strict": strict} + + # Add device parameter for newer versions that support it + if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"): + kwargs["device"] = map_location + + # Load the model with appropriate kwargs + missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs) + log_model_loading_keys(missing_keys, unexpected_keys) + + # For older versions, manually move to device if needed + if "device" not in kwargs and map_location != "cpu": + logging.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) return model @abc.abstractmethod diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 5659e8727..5a3994cdf 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections import deque import torch @@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple: with torch.inference_mode(): output = module(dummy_input) return tuple(output.shape) + + +def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None: + """Log missing and unexpected keys when loading a model. + + Args: + missing_keys (list[str]): Keys that were expected but not found. + unexpected_keys (list[str]): Keys that were found but not expected. + """ + if missing_keys: + logging.warning(f"Missing key(s) when loading model: {missing_keys}") + if unexpected_keys: + logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") From 945e1ff2669bb7b31cb7fe6033fe9679767c2442 Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Thu, 31 Jul 2025 11:08:12 +0200 Subject: [PATCH 053/158] fix colab typo (#1629) Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- docs/source/il_robots.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index de80b1fcd..8c075e5b2 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -462,9 +462,9 @@ If you do not want to push your model to the hub after training use `--policy.pu Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit` -#### Train using Collab +#### Train using Google Colab -If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). +If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). #### Upload policy checkpoints From 91ed6097bc0b9e86668c1cfd4dbd1cc05348d2cf Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 12:04:34 +0200 Subject: [PATCH 054/158] fix(ci): declare entrypoints + fix testing release (#1642) --- .github/workflows/release.yml | 9 +++++++-- src/lerobot/calibrate.py | 6 +++++- src/lerobot/find_cameras.py | 6 +++++- src/lerobot/find_port.py | 6 +++++- src/lerobot/record.py | 6 +++++- src/lerobot/replay.py | 6 +++++- src/lerobot/scripts/eval.py | 6 +++++- src/lerobot/scripts/train.py | 6 +++++- src/lerobot/setup_motors.py | 6 +++++- src/lerobot/teleoperate.py | 6 +++++- 10 files changed, 52 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 32c1c605a..63d60f5d8 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,6 +19,11 @@ on: tags: - 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0 +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + jobs: # This job builds the Python package and publishes it to PyPI build-and-publish: @@ -120,10 +125,10 @@ jobs: version: ${{ env.UV_VERSION }} python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot release - run: uv run pip install lerobot==${{ needs.build-and-publish.outputs.version }} # zizmor: ignore[template-injection] + run: uv run pip install "lerobot[all]==${{ needs.build-and-publish.outputs.version }}" # zizmor: ignore[template-injection] - name: Check lerobot version - run: uv run lerobot --version + run: uv run python -c "import lerobot; print(lerobot.__version__)" - name: Run end-to-end tests run: uv run make test-end-to-end diff --git a/src/lerobot/calibrate.py b/src/lerobot/calibrate.py index 1e8bf4751..0dda80ba2 100644 --- a/src/lerobot/calibrate.py +++ b/src/lerobot/calibrate.py @@ -82,5 +82,9 @@ def calibrate(cfg: CalibrateConfig): device.disconnect() -if __name__ == "__main__": +def main(): calibrate() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/find_cameras.py b/src/lerobot/find_cameras.py index be8f272ee..8f88d3107 100644 --- a/src/lerobot/find_cameras.py +++ b/src/lerobot/find_cameras.py @@ -286,7 +286,7 @@ def save_images_from_all_cameras( print(f"Image capture finished. Images saved to {output_dir}") -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser( description="Unified camera utility script for listing cameras and capturing images." ) @@ -313,3 +313,7 @@ if __name__ == "__main__": ) args = parser.parse_args() save_images_from_all_cameras(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/find_port.py b/src/lerobot/find_port.py index cf0282507..babe0288e 100644 --- a/src/lerobot/find_port.py +++ b/src/lerobot/find_port.py @@ -61,5 +61,9 @@ def find_port(): raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") -if __name__ == "__main__": +def main(): find_port() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/record.py b/src/lerobot/record.py index d662efcab..575fcb94d 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -393,5 +393,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: return dataset -if __name__ == "__main__": +def main(): record() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index afe54d90f..a9dceb741 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -112,5 +112,9 @@ def replay(cfg: ReplayConfig): robot.disconnect() -if __name__ == "__main__": +def main(): replay() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 7c5aec48a..6a6c02a24 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -501,6 +501,10 @@ def eval_main(cfg: EvalPipelineConfig): logging.info("End of eval") -if __name__ == "__main__": +def main(): init_logging() eval_main() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index f09d231a8..235352cd8 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -286,6 +286,10 @@ def train(cfg: TrainPipelineConfig): policy.push_model_to_hub(cfg) -if __name__ == "__main__": +def main(): init_logging() train() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/setup_motors.py b/src/lerobot/setup_motors.py index c54582a1d..76cdca56d 100644 --- a/src/lerobot/setup_motors.py +++ b/src/lerobot/setup_motors.py @@ -80,5 +80,9 @@ def setup_motors(cfg: SetupConfig): device.setup_motors() -if __name__ == "__main__": +def main(): setup_motors() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 9836f1393..3c72caf79 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -153,5 +153,9 @@ def teleoperate(cfg: TeleoperateConfig): robot.disconnect() -if __name__ == "__main__": +def main(): teleoperate() + + +if __name__ == "__main__": + main() From 1baaa77a86eee9f5a51b2295c67b93f89bbc11a1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 17:14:15 +0200 Subject: [PATCH 055/158] feat(ci): release workflow publish to pypi test + lock files (#1643) * chore(ci): add some release stuff * chore(ci): add requirements-macos * chore(ci): added lockfiles for future reference * feat(ci): add draft & prerelease option to release workflow tag --- .github/workflows/release.yml | 27 +- docs-requirements.txt | 3 + docs/README.md | 2 +- pyproject.toml | 2 - requirements-macos.txt | 625 ++++++++++++++++++++++++++++++++ requirements-ubuntu.txt | 650 ++++++++++++++++++++++++++++++++++ requirements.in | 9 + 7 files changed, 1310 insertions(+), 8 deletions(-) create mode 100644 docs-requirements.txt create mode 100644 requirements-macos.txt create mode 100644 requirements-ubuntu.txt create mode 100644 requirements.in diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 63d60f5d8..b4eff4589 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -91,13 +91,29 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # zizmor: ignore[template-injection] - run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/* + run: | + gh release create ${{ github.ref_name }} \ + --release-name "Release ${{ github.ref_name }}" \ + --generate-notes \ + --draft=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ + --prerelease=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ + ./dist/* - - name: Publish to PyPI - if: startsWith(github.ref, 'refs/tags/v') + - name: Publish to TestPyPI for pre-releases + # True for tags like 'v0.2.0-rc1' + if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-') uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing] with: - password: ${{ secrets.PYPI_API_TOKEN }} + repository-url: https://test.pypi.org/legacy/ + verbose: true + print-hash: true + + - name: Publish to PyPI + if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') + uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing] + with: + verbose: true + print-hash: true # This job runs end-to-end tests on the release test-release: @@ -134,5 +150,6 @@ jobs: run: uv run make test-end-to-end -# TODO(Steven): Publish draft/pre-release and to test pypi +# TODO(Steven): Publish draft/pre-release and to test pypi weekly +# TODO(Steven): Separate build and publish job # TODO(Steven): Tag documentation with the same version as the package diff --git a/docs-requirements.txt b/docs-requirements.txt new file mode 100644 index 000000000..e286ad2bb --- /dev/null +++ b/docs-requirements.txt @@ -0,0 +1,3 @@ +# docs-requirements.txt +hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main +watchdog>=6.0.0 diff --git a/docs/README.md b/docs/README.md index 967de7b84..476eb8dce 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,7 +20,7 @@ To generate the documentation, you first have to build it. Several packages are you can install them with the following command, at the root of the code repository: ```bash -pip install -e ".[docs]" +pip install -e . -r docs-requirements.txt ``` You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download) diff --git a/pyproject.toml b/pyproject.toml index a8680e39f..c369b612c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] # Development -docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] @@ -147,7 +146,6 @@ all = [ "lerobot[smolvla]", "lerobot[hilserl]", "lerobot[async]", - "lerobot[docs]", "lerobot[dev]", "lerobot[test]", "lerobot[video_benchmark]", diff --git a/requirements-macos.txt b/requirements-macos.txt new file mode 100644 index 000000000..07e263da5 --- /dev/null +++ b/requirements-macos.txt @@ -0,0 +1,625 @@ +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=requirements-macos.txt requirements.in +# +-e .[all] + # via -[all] +absl-py==2.3.1 + # via + # dm-control + # dm-env + # dm-tree + # labmaze + # mujoco +accelerate==1.9.0 + # via lerobot +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via + # aiohttp + # dm-tree + # jsonlines + # rerun-sdk +av==15.0.0 + # via lerobot +blinker==1.9.0 + # via flask +certifi==2025.7.14 + # via + # requests + # sentry-sdk +cffi==1.17.1 + # via pymunk +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.2 + # via requests +click==8.2.1 + # via + # flask + # wandb +cloudpickle==3.1.1 + # via gymnasium +cmake==4.0.3 + # via lerobot +cmeel==0.57.3 + # via + # cmeel-assimp + # cmeel-boost + # cmeel-console-bridge + # cmeel-octomap + # cmeel-qhull + # cmeel-tinyxml2 + # cmeel-urdfdom + # cmeel-zlib + # coal-library + # eigenpy + # eiquadprog + # pin + # placo + # rhoban-cmeel-jsoncpp +cmeel-assimp==5.4.3.1 + # via coal-library +cmeel-boost==1.87.0.1 + # via + # coal-library + # eigenpy + # eiquadprog + # pin +cmeel-console-bridge==1.0.2.3 + # via cmeel-urdfdom +cmeel-octomap==1.10.0 + # via coal-library +cmeel-qhull==8.0.2.1 + # via coal-library +cmeel-tinyxml2==10.0.0 + # via cmeel-urdfdom +cmeel-urdfdom==4.0.1 + # via pin +cmeel-zlib==1.3.1 + # via cmeel-assimp +coal-library==3.0.1 + # via pin +contourpy==1.3.2 + # via matplotlib +coverage[toml]==7.10.1 + # via pytest-cov +cycler==0.12.1 + # via matplotlib +datasets==3.6.0 + # via lerobot +debugpy==1.8.15 + # via lerobot +decorator==5.2.1 + # via ipython +deepdiff==8.5.0 + # via lerobot +diffusers==0.34.0 + # via lerobot +dill==0.3.8 + # via + # datasets + # multiprocess +distlib==0.4.0 + # via virtualenv +dm-control==1.0.14 + # via gym-aloha +dm-env==1.6 + # via dm-control +dm-tree==0.1.9 + # via + # dm-control + # dm-env +docopt==0.6.2 + # via num2words +draccus==0.10.0 + # via lerobot +dynamixel-sdk==3.7.31 + # via lerobot +eigenpy==3.10.3 + # via coal-library +einops==0.8.1 + # via lerobot +eiquadprog==1.2.9 + # via placo +exceptiongroup==1.3.0 + # via + # ipython + # pytest +executing==2.2.0 + # via stack-data +farama-notifications==0.0.4 + # via gymnasium +feetech-servo-sdk==1.0.0 + # via lerobot +filelock==3.18.0 + # via + # datasets + # diffusers + # huggingface-hub + # torch + # transformers + # virtualenv +flask==3.1.1 + # via lerobot +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via wandb +glfw==2.9.0 + # via + # dm-control + # mujoco +grpcio==1.73.1 + # via + # grpcio-tools + # lerobot +grpcio-tools==1.73.1 + # via lerobot +gym-aloha==0.1.1 + # via lerobot +gym-hil==0.1.10 + # via lerobot +gym-pusht==0.1.5 + # via lerobot +gym-xarm==0.1.1 + # via lerobot +gymnasium==0.29.1 + # via + # gym-aloha + # gym-hil + # gym-pusht + # gym-xarm + # gymnasium-robotics + # lerobot + # pettingzoo +gymnasium-robotics==1.2.4 + # via gym-xarm +hf-transfer==0.1.9 + # via huggingface-hub +hf-xet==1.1.5 + # via huggingface-hub +hidapi==0.14.0.post4 + # via + # gym-hil + # lerobot +huggingface-hub[cli,hf-transfer]==0.34.3 + # via + # accelerate + # datasets + # diffusers + # lerobot + # tokenizers + # transformers +identify==2.6.12 + # via pre-commit +idna==3.10 + # via + # requests + # yarl +imageio[ffmpeg]==2.37.0 + # via + # gym-aloha + # gym-hil + # gymnasium-robotics + # lerobot + # scikit-image +imageio-ffmpeg==0.6.0 + # via imageio +importlib-metadata==8.7.0 + # via diffusers +iniconfig==2.1.0 + # via pytest +inquirerpy==0.3.4 + # via huggingface-hub +ipython==8.37.0 + # via meshcat +ischedule==1.2.7 + # via placo +itsdangerous==2.2.0 + # via flask +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # gymnasium-robotics + # torch +jsonlines==4.0.0 + # via lerobot +kiwisolver==1.4.8 + # via matplotlib +labmaze==1.0.6 + # via dm-control +lazy-loader==0.4 + # via scikit-image +lxml==6.0.0 + # via dm-control +markupsafe==3.0.2 + # via + # flask + # jinja2 + # werkzeug +matplotlib==3.10.5 + # via lerobot +matplotlib-inline==0.1.7 + # via ipython +mergedeep==1.3.4 + # via draccus +meshcat==0.3.2 + # via placo +mock-serial==0.0.1 + # via lerobot +mpmath==1.3.0 + # via sympy +mujoco==2.3.7 + # via + # dm-control + # gym-aloha + # gym-hil + # gym-xarm + # gymnasium-robotics +multidict==6.6.3 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +mypy-extensions==1.1.0 + # via typing-inspect +networkx==3.4.2 + # via + # scikit-image + # torch +nodeenv==1.9.1 + # via pre-commit +num2words==0.5.14 + # via lerobot +numpy==2.2.6 + # via + # accelerate + # cmeel-boost + # contourpy + # datasets + # diffusers + # dm-control + # dm-env + # dm-tree + # gymnasium + # gymnasium-robotics + # imageio + # labmaze + # matplotlib + # meshcat + # mujoco + # opencv-python + # opencv-python-headless + # pandas + # pettingzoo + # rerun-sdk + # scikit-image + # scipy + # shapely + # tifffile + # torchvision + # transformers +opencv-python==4.12.0.88 + # via gym-pusht +opencv-python-headless==4.12.0.88 + # via lerobot +orderly-set==5.5.0 + # via deepdiff +packaging==25.0 + # via + # accelerate + # datasets + # huggingface-hub + # lazy-loader + # lerobot + # matplotlib + # pytest + # scikit-image + # transformers + # wandb +pandas==2.3.1 + # via + # datasets + # lerobot +parso==0.8.4 + # via jedi +pettingzoo==1.24.3 + # via gymnasium-robotics +pexpect==4.9.0 + # via ipython +pfzy==0.3.4 + # via inquirerpy +pillow==11.3.0 + # via + # diffusers + # imageio + # matplotlib + # meshcat + # rerun-sdk + # scikit-image + # torchvision +pin==3.4.0 + # via placo +placo==0.9.14 + # via lerobot +platformdirs==4.3.8 + # via + # virtualenv + # wandb +pluggy==1.6.0 + # via + # pytest + # pytest-cov +pre-commit==4.2.0 + # via lerobot +prompt-toolkit==3.0.51 + # via + # inquirerpy + # ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.0 + # via + # dm-control + # grpcio-tools + # lerobot + # wandb +psutil==7.0.0 + # via + # accelerate + # imageio +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==21.0.0 + # via + # datasets + # rerun-sdk +pycparser==2.22 + # via cffi +pydantic==2.11.7 + # via wandb +pydantic-core==2.33.2 + # via pydantic +pygame==2.6.1 + # via + # gym-hil + # gym-pusht + # lerobot +pygments==2.19.2 + # via + # ipython + # pytest +pymunk==6.11.1 + # via + # gym-pusht + # lerobot +pyngrok==7.2.12 + # via meshcat +pynput==1.8.1 + # via + # gym-hil + # lerobot +pyobjc-core==11.1 + # via + # pyobjc-framework-applicationservices + # pyobjc-framework-cocoa + # pyobjc-framework-coretext + # pyobjc-framework-quartz +pyobjc-framework-applicationservices==11.1 + # via pynput +pyobjc-framework-cocoa==11.1 + # via + # pyobjc-framework-applicationservices + # pyobjc-framework-coretext + # pyobjc-framework-quartz +pyobjc-framework-coretext==11.1 + # via pyobjc-framework-applicationservices +pyobjc-framework-quartz==11.1 + # via + # pynput + # pyobjc-framework-applicationservices + # pyobjc-framework-coretext +pyopengl==3.1.9 + # via + # dm-control + # mujoco +pyparsing==3.2.3 + # via + # dm-control + # matplotlib +pyrealsense2-macosx==2.54.2 + # via lerobot +pyserial==3.5 + # via + # dynamixel-sdk + # feetech-servo-sdk + # lerobot +pytest==8.4.1 + # via + # lerobot + # pytest-cov + # pytest-timeout +pytest-cov==6.2.1 + # via lerobot +pytest-timeout==2.4.0 + # via lerobot +python-dateutil==2.9.0.post0 + # via + # matplotlib + # pandas +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # datasets + # draccus + # huggingface-hub + # pre-commit + # pyngrok + # pyyaml-include + # transformers + # wandb +pyyaml-include==1.4.1 + # via draccus +pyzmq==27.0.0 + # via + # lerobot + # meshcat +regex==2025.7.34 + # via + # diffusers + # transformers +requests==2.32.4 + # via + # datasets + # diffusers + # dm-control + # huggingface-hub + # transformers + # wandb +rerun-sdk==0.22.1 + # via lerobot +rhoban-cmeel-jsoncpp==1.9.4.9 + # via placo +safetensors==0.5.3 + # via + # accelerate + # diffusers + # lerobot + # transformers +scikit-image==0.25.2 + # via + # gym-pusht + # lerobot +scipy==1.15.3 + # via + # dm-control + # scikit-image +sentry-sdk==2.34.1 + # via wandb +shapely==2.1.1 + # via gym-pusht +six==1.17.0 + # via + # pynput + # python-dateutil +smmap==5.0.2 + # via gitdb +stack-data==0.6.3 + # via ipython +sympy==1.14.0 + # via torch +termcolor==3.1.0 + # via lerobot +tifffile==2025.5.10 + # via scikit-image +tokenizers==0.21.4 + # via transformers +toml==0.10.2 + # via draccus +tomli==2.2.1 + # via + # cmeel + # coverage + # pytest +torch==2.7.1 + # via + # accelerate + # lerobot + # torchvision +torchcodec==0.5 + # via lerobot +torchvision==0.22.1 + # via lerobot +tornado==6.5.1 + # via meshcat +tqdm==4.67.1 + # via + # datasets + # dm-control + # huggingface-hub + # transformers +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +transformers==4.51.3 + # via lerobot +typing-extensions==4.14.1 + # via + # aiosignal + # exceptiongroup + # gymnasium + # huggingface-hub + # ipython + # multidict + # pydantic + # pydantic-core + # rerun-sdk + # torch + # typing-inspect + # typing-inspection + # wandb +typing-inspect==0.9.0 + # via draccus +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +u-msgpack-python==2.8.0 + # via meshcat +urllib3==2.5.0 + # via + # requests + # sentry-sdk +virtualenv==20.32.0 + # via pre-commit +wandb==0.21.0 + # via lerobot +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.1.3 + # via flask +wrapt==1.17.2 + # via dm-tree +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-ubuntu.txt b/requirements-ubuntu.txt new file mode 100644 index 000000000..af7258d67 --- /dev/null +++ b/requirements-ubuntu.txt @@ -0,0 +1,650 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=requirements-ubuntu.txt requirements.in +# +-e .[all] + # via -[all] +absl-py==2.3.1 + # via + # dm-control + # dm-env + # dm-tree + # labmaze + # mujoco +accelerate==1.9.0 + # via lerobot +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via + # aiohttp + # dm-tree + # jsonlines + # rerun-sdk +av==15.0.0 + # via lerobot +blinker==1.9.0 + # via flask +certifi==2025.7.14 + # via + # requests + # sentry-sdk +cffi==1.17.1 + # via pymunk +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.2 + # via requests +click==8.2.1 + # via + # flask + # wandb +cloudpickle==3.1.1 + # via gymnasium +cmake==4.0.3 + # via lerobot +cmeel==0.57.3 + # via + # cmeel-assimp + # cmeel-boost + # cmeel-console-bridge + # cmeel-octomap + # cmeel-qhull + # cmeel-tinyxml2 + # cmeel-urdfdom + # cmeel-zlib + # coal-library + # eigenpy + # eiquadprog + # pin + # placo + # rhoban-cmeel-jsoncpp +cmeel-assimp==5.4.3.1 + # via coal-library +cmeel-boost==1.87.0.1 + # via + # coal-library + # eigenpy + # eiquadprog + # pin +cmeel-console-bridge==1.0.2.3 + # via cmeel-urdfdom +cmeel-octomap==1.10.0 + # via coal-library +cmeel-qhull==8.0.2.1 + # via coal-library +cmeel-tinyxml2==10.0.0 + # via cmeel-urdfdom +cmeel-urdfdom==4.0.1 + # via pin +cmeel-zlib==1.3.1 + # via cmeel-assimp +coal-library==3.0.1 + # via pin +contourpy==1.3.2 + # via matplotlib +coverage[toml]==7.10.1 + # via pytest-cov +cycler==0.12.1 + # via matplotlib +datasets==3.6.0 + # via lerobot +debugpy==1.8.15 + # via lerobot +decorator==5.2.1 + # via ipython +deepdiff==8.5.0 + # via lerobot +diffusers==0.34.0 + # via lerobot +dill==0.3.8 + # via + # datasets + # multiprocess +distlib==0.4.0 + # via virtualenv +dm-control==1.0.14 + # via gym-aloha +dm-env==1.6 + # via dm-control +dm-tree==0.1.9 + # via + # dm-control + # dm-env +docopt==0.6.2 + # via num2words +draccus==0.10.0 + # via lerobot +dynamixel-sdk==3.7.31 + # via lerobot +eigenpy==3.10.3 + # via coal-library +einops==0.8.1 + # via lerobot +eiquadprog==1.2.9 + # via placo +evdev==1.9.2 + # via pynput +exceptiongroup==1.3.0 + # via + # ipython + # pytest +executing==2.2.0 + # via stack-data +farama-notifications==0.0.4 + # via gymnasium +feetech-servo-sdk==1.0.0 + # via lerobot +filelock==3.18.0 + # via + # datasets + # diffusers + # huggingface-hub + # torch + # transformers + # virtualenv +flask==3.1.1 + # via lerobot +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via wandb +glfw==2.9.0 + # via + # dm-control + # mujoco +grpcio==1.73.1 + # via + # grpcio-tools + # lerobot +grpcio-tools==1.73.1 + # via lerobot +gym-aloha==0.1.1 + # via lerobot +gym-hil==0.1.10 + # via lerobot +gym-pusht==0.1.5 + # via lerobot +gym-xarm==0.1.1 + # via lerobot +gymnasium==0.29.1 + # via + # gym-aloha + # gym-hil + # gym-pusht + # gym-xarm + # gymnasium-robotics + # lerobot + # pettingzoo +gymnasium-robotics==1.2.4 + # via gym-xarm +hf-transfer==0.1.9 + # via huggingface-hub +hf-xet==1.1.5 + # via huggingface-hub +hidapi==0.14.0.post4 + # via + # gym-hil + # lerobot +huggingface-hub[cli,hf-transfer]==0.34.3 + # via + # accelerate + # datasets + # diffusers + # lerobot + # tokenizers + # transformers +identify==2.6.12 + # via pre-commit +idna==3.10 + # via + # requests + # yarl +imageio[ffmpeg]==2.37.0 + # via + # gym-aloha + # gym-hil + # gymnasium-robotics + # lerobot + # scikit-image +imageio-ffmpeg==0.6.0 + # via imageio +importlib-metadata==8.7.0 + # via diffusers +iniconfig==2.1.0 + # via pytest +inquirerpy==0.3.4 + # via huggingface-hub +ipython==8.37.0 + # via meshcat +ischedule==1.2.7 + # via placo +itsdangerous==2.2.0 + # via flask +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # gymnasium-robotics + # torch +jsonlines==4.0.0 + # via lerobot +kiwisolver==1.4.8 + # via matplotlib +labmaze==1.0.6 + # via dm-control +lazy-loader==0.4 + # via scikit-image +lxml==6.0.0 + # via dm-control +markupsafe==3.0.2 + # via + # flask + # jinja2 + # werkzeug +matplotlib==3.10.5 + # via lerobot +matplotlib-inline==0.1.7 + # via ipython +mergedeep==1.3.4 + # via draccus +meshcat==0.3.2 + # via placo +mock-serial==0.0.1 + # via lerobot +mpmath==1.3.0 + # via sympy +mujoco==2.3.7 + # via + # dm-control + # gym-aloha + # gym-hil + # gym-xarm + # gymnasium-robotics +multidict==6.6.3 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +mypy-extensions==1.1.0 + # via typing-inspect +networkx==3.4.2 + # via + # scikit-image + # torch +nodeenv==1.9.1 + # via pre-commit +num2words==0.5.14 + # via lerobot +numpy==2.2.6 + # via + # accelerate + # cmeel-boost + # contourpy + # datasets + # diffusers + # dm-control + # dm-env + # dm-tree + # gymnasium + # gymnasium-robotics + # imageio + # labmaze + # matplotlib + # meshcat + # mujoco + # opencv-python + # opencv-python-headless + # pandas + # pettingzoo + # rerun-sdk + # scikit-image + # scipy + # shapely + # tifffile + # torchvision + # transformers +nvidia-cublas-cu12==12.6.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.6.80 + # via torch +nvidia-cuda-nvrtc-cu12==12.6.77 + # via torch +nvidia-cuda-runtime-cu12==12.6.77 + # via torch +nvidia-cudnn-cu12==9.5.1.17 + # via torch +nvidia-cufft-cu12==11.3.0.4 + # via torch +nvidia-cufile-cu12==1.11.1.6 + # via torch +nvidia-curand-cu12==10.3.7.77 + # via torch +nvidia-cusolver-cu12==11.7.1.2 + # via torch +nvidia-cusparse-cu12==12.5.4.2 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.6.3 + # via torch +nvidia-nccl-cu12==2.26.2 + # via torch +nvidia-nvjitlink-cu12==12.6.85 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.6.77 + # via torch +opencv-python==4.12.0.88 + # via gym-pusht +opencv-python-headless==4.12.0.88 + # via lerobot +orderly-set==5.5.0 + # via deepdiff +packaging==25.0 + # via + # accelerate + # datasets + # huggingface-hub + # lazy-loader + # lerobot + # matplotlib + # pytest + # scikit-image + # transformers + # wandb +pandas==2.3.1 + # via + # datasets + # lerobot +parso==0.8.4 + # via jedi +pettingzoo==1.24.3 + # via gymnasium-robotics +pexpect==4.9.0 + # via ipython +pfzy==0.3.4 + # via inquirerpy +pillow==11.3.0 + # via + # diffusers + # imageio + # matplotlib + # meshcat + # rerun-sdk + # scikit-image + # torchvision +pin==3.4.0 + # via placo +placo==0.9.14 + # via lerobot +platformdirs==4.3.8 + # via + # virtualenv + # wandb +pluggy==1.6.0 + # via + # pytest + # pytest-cov +pre-commit==4.2.0 + # via lerobot +prompt-toolkit==3.0.51 + # via + # inquirerpy + # ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.0 + # via + # dm-control + # grpcio-tools + # lerobot + # wandb +psutil==7.0.0 + # via + # accelerate + # imageio +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==21.0.0 + # via + # datasets + # rerun-sdk +pycparser==2.22 + # via cffi +pydantic==2.11.7 + # via wandb +pydantic-core==2.33.2 + # via pydantic +pygame==2.6.1 + # via + # gym-hil + # gym-pusht + # lerobot +pygments==2.19.2 + # via + # ipython + # pytest +pymunk==6.11.1 + # via + # gym-pusht + # lerobot +pyngrok==7.2.12 + # via meshcat +pynput==1.8.1 + # via + # gym-hil + # lerobot +pyopengl==3.1.9 + # via + # dm-control + # mujoco +pyparsing==3.2.3 + # via + # dm-control + # matplotlib +pyrealsense2==2.56.5.9235 + # via lerobot +pyserial==3.5 + # via + # dynamixel-sdk + # feetech-servo-sdk + # lerobot +pytest==8.4.1 + # via + # lerobot + # pytest-cov + # pytest-timeout +pytest-cov==6.2.1 + # via lerobot +pytest-timeout==2.4.0 + # via lerobot +python-dateutil==2.9.0.post0 + # via + # matplotlib + # pandas +python-xlib==0.33 + # via pynput +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # datasets + # draccus + # huggingface-hub + # pre-commit + # pyngrok + # pyyaml-include + # transformers + # wandb +pyyaml-include==1.4.1 + # via draccus +pyzmq==27.0.0 + # via + # lerobot + # meshcat +regex==2025.7.34 + # via + # diffusers + # transformers +requests==2.32.4 + # via + # datasets + # diffusers + # dm-control + # huggingface-hub + # transformers + # wandb +rerun-sdk==0.22.1 + # via lerobot +rhoban-cmeel-jsoncpp==1.9.4.9 + # via placo +safetensors==0.5.3 + # via + # accelerate + # diffusers + # lerobot + # transformers +scikit-image==0.25.2 + # via + # gym-pusht + # lerobot +scipy==1.15.3 + # via + # dm-control + # scikit-image +sentry-sdk==2.34.1 + # via wandb +shapely==2.1.1 + # via gym-pusht +six==1.17.0 + # via + # pynput + # python-dateutil + # python-xlib +smmap==5.0.2 + # via gitdb +stack-data==0.6.3 + # via ipython +sympy==1.14.0 + # via torch +termcolor==3.1.0 + # via lerobot +tifffile==2025.5.10 + # via scikit-image +tokenizers==0.21.4 + # via transformers +toml==0.10.2 + # via draccus +tomli==2.2.1 + # via + # cmeel + # coverage + # pytest +torch==2.7.1 + # via + # accelerate + # lerobot + # torchvision +torchcodec==0.5 + # via lerobot +torchvision==0.22.1 + # via lerobot +tornado==6.5.1 + # via meshcat +tqdm==4.67.1 + # via + # datasets + # dm-control + # huggingface-hub + # transformers +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +transformers==4.51.3 + # via lerobot +triton==3.3.1 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # exceptiongroup + # gymnasium + # huggingface-hub + # ipython + # multidict + # pydantic + # pydantic-core + # rerun-sdk + # torch + # typing-inspect + # typing-inspection + # wandb +typing-inspect==0.9.0 + # via draccus +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +u-msgpack-python==2.8.0 + # via meshcat +urllib3==2.5.0 + # via + # requests + # sentry-sdk +virtualenv==20.32.0 + # via pre-commit +wandb==0.21.0 + # via lerobot +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.1.3 + # via flask +wrapt==1.17.2 + # via dm-tree +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.in b/requirements.in new file mode 100644 index 000000000..272f7f540 --- /dev/null +++ b/requirements.in @@ -0,0 +1,9 @@ +# requirements.in + +# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64). +# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64 + +# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64). +# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux + +-e .[all] From 2f8d98b05e622c2b8dbbf931b601c44127f4124c Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Fri, 1 Aug 2025 17:39:39 +0200 Subject: [PATCH 056/158] Update readme (#1570) * Cleanup badges * Remove comment * Remove profiling section * Move acknowledgment * Move citations * Fix badge display * Move build your robot section * Fix nightly badge * Revert be13b3f * Update README.md Co-authored-by: HUANG TZU-CHUN Signed-off-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * chore(docs): optimize readme for PyPI rendering * chore(docs): move policy readme to docs folder + symlink in policy dirs * fix(docs): max width og lerobot logo + url in citation block --------- Signed-off-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: HUANG TZU-CHUN Co-authored-by: Steven Palma --- README.md | 201 ++++------------------- docs/source/policy_act_README.md | 14 ++ docs/source/policy_diffusion_README.md | 14 ++ docs/source/policy_smolvla_README.md | 14 ++ docs/source/policy_tdmpc_README.md | 14 ++ docs/source/policy_vqbet_README.md | 14 ++ src/lerobot/policies/act/README.md | 1 + src/lerobot/policies/diffusion/README.md | 1 + src/lerobot/policies/smolvla/README.md | 1 + src/lerobot/policies/tdmpc/README.md | 1 + src/lerobot/policies/vqbet/README.md | 1 + 11 files changed, 108 insertions(+), 168 deletions(-) create mode 100644 docs/source/policy_act_README.md create mode 100644 docs/source/policy_diffusion_README.md create mode 100644 docs/source/policy_smolvla_README.md create mode 100644 docs/source/policy_tdmpc_README.md create mode 100644 docs/source/policy_vqbet_README.md create mode 120000 src/lerobot/policies/act/README.md create mode 120000 src/lerobot/policies/diffusion/README.md create mode 120000 src/lerobot/policies/smolvla/README.md create mode 120000 src/lerobot/policies/tdmpc/README.md create mode 120000 src/lerobot/policies/vqbet/README.md diff --git a/README.md b/README.md index 1d7cbcad4..13cc95f90 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,21 @@

- - - - LeRobot, Hugging Face Robotics Library - + LeRobot, Hugging Face Robotics Library

-[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot) +[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain) [![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE) [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) -[![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples) -[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1%20adopted-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) [![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb) + +

@@ -29,10 +25,10 @@
HopeJR robot

Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!

@@ -51,20 +47,12 @@

-
- SO-101 follower arm - SO-101 leader arm -
+ + + + + +
SO-101 follower armSO-101 leader arm

Meet the updated SO100, the SO-101 – Just €114 per arm!

Train it in minutes with a few simple moves on your laptop.

@@ -76,7 +64,7 @@

Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!

Check out the LeKiwi tutorial and bring your robot to life on wheels.

- LeKiwi mobile robot + LeKiwi mobile robot

@@ -99,9 +87,9 @@ - - - + + + @@ -110,24 +98,9 @@
ACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT envACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT env
ACT policy on ALOHA env
-### Acknowledgment - -- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). -- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). -- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). -- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). -- Thanks to Antonio Loquercio and Ashish Kumar for their early support. -- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). - ## Installation -Download our source code: - -```bash -git clone https://github.com/huggingface/lerobot.git -cd lerobot -``` - +LeRobot works with Python 3.10+ and PyTorch 2.2+. Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html): ```bash @@ -154,7 +127,7 @@ conda install ffmpeg -c conda-forge Install 🤗 LeRobot: ```bash -pip install -e . +pip install lerobot ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: @@ -182,7 +155,7 @@ wandb login ### Visualize datasets -Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. +Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: @@ -212,7 +185,7 @@ Our script can also visualize datasets stored on a distant server. See `python - A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model. -A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. +A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor. @@ -256,7 +229,7 @@ Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work ### Evaluate a pretrained policy -Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. +Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht): @@ -280,13 +253,13 @@ See `python -m lerobot.scripts.eval --help` for more instructions. ### Train your own policy -Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line. +Check out [example 3](https://github.com/huggingface/lerobot/blob/main/examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md) that shows how to use our training script from command line. To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`. -A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs. +A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs. -![](media/wandb.png) +\WandB logs example Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions. @@ -305,26 +278,6 @@ reproduces SOTA results for Diffusion Policy on the PushT task. If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). - - ### Add a pretrained policy Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)). @@ -341,34 +294,16 @@ To upload these to the hub, run the following: huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model ``` -See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy. +See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy. -### Improve your code with profiling +### Acknowledgment -An example of a code snippet to profile the evaluation of a policy: - - -```python -from torch.profiler import profile, record_function, ProfilerActivity - -def trace_handler(prof): - prof.export_chrome_trace(f"tmp/trace_schedule_{prof.step_num}.json") - -with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=torch.profiler.schedule( - wait=2, - warmup=2, - active=3, - ), - on_trace_ready=trace_handler -) as prof: - with record_function("eval_policy"): - for i in range(num_episodes): - prof.step() - # insert code to profile, potentially whole body of eval_policy function -``` - +- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). +- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). +- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). +- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). +- Thanks to Antonio Loquercio and Ashish Kumar for their early support. +- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). ## Citation @@ -383,76 +318,6 @@ If you want, you can cite this work with: } ``` -Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below: - -- [SmolVLA](https://arxiv.org/abs/2506.01844) - -```bibtex -@article{shukor2025smolvla, - title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, - author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi}, - journal={arXiv preprint arXiv:2506.01844}, - year={2025} -} -``` - -- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) - -```bibtex -@article{chi2024diffusionpolicy, - author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, - title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, - journal = {The International Journal of Robotics Research}, - year = {2024}, -} -``` - -- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha) - -```bibtex -@article{zhao2023learning, - title={Learning fine-grained bimanual manipulation with low-cost hardware}, - author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea}, - journal={arXiv preprint arXiv:2304.13705}, - year={2023} -} -``` - -- [TDMPC](https://www.nicklashansen.com/td-mpc/) - -```bibtex -@inproceedings{Hansen2022tdmpc, - title={Temporal Difference Learning for Model Predictive Control}, - author={Nicklas Hansen and Xiaolong Wang and Hao Su}, - booktitle={ICML}, - year={2022} -} -``` - -- [VQ-BeT](https://sjlee.cc/vq-bet/) - -```bibtex -@article{lee2024behavior, - title={Behavior generation with latent actions}, - author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, - journal={arXiv preprint arXiv:2403.03181}, - year={2024} -} -``` - -- [HIL-SERL](https://hil-serl.github.io/) - -```bibtex -@Article{luo2024hilserl, -title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, -author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine}, -year={2024}, -eprint={2410.21845}, -archivePrefix={arXiv}, -primaryClass={cs.RO} -} -``` - ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/docs/source/policy_act_README.md b/docs/source/policy_act_README.md new file mode 100644 index 000000000..371a9136f --- /dev/null +++ b/docs/source/policy_act_README.md @@ -0,0 +1,14 @@ +## Paper + +https://tonyzhaozh.github.io/aloha + +## Citation + +```bibtex +@article{zhao2023learning, + title={Learning fine-grained bimanual manipulation with low-cost hardware}, + author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea}, + journal={arXiv preprint arXiv:2304.13705}, + year={2023} +} +``` diff --git a/docs/source/policy_diffusion_README.md b/docs/source/policy_diffusion_README.md new file mode 100644 index 000000000..9ec934add --- /dev/null +++ b/docs/source/policy_diffusion_README.md @@ -0,0 +1,14 @@ +## Paper + +https://diffusion-policy.cs.columbia.edu + +## Citation + +```bibtex +@article{chi2024diffusionpolicy, + author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, + title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, + journal = {The International Journal of Robotics Research}, + year = {2024}, +} +``` diff --git a/docs/source/policy_smolvla_README.md b/docs/source/policy_smolvla_README.md new file mode 100644 index 000000000..ee567ee83 --- /dev/null +++ b/docs/source/policy_smolvla_README.md @@ -0,0 +1,14 @@ +## Paper + +https://arxiv.org/abs/2506.01844 + +## Citation + +```bibtex +@article{shukor2025smolvla, + title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, + author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi}, + journal={arXiv preprint arXiv:2506.01844}, + year={2025} +} +``` diff --git a/docs/source/policy_tdmpc_README.md b/docs/source/policy_tdmpc_README.md new file mode 100644 index 000000000..804f166c8 --- /dev/null +++ b/docs/source/policy_tdmpc_README.md @@ -0,0 +1,14 @@ +## Paper + +https://www.nicklashansen.com/td-mpc/ + +## Citation + +```bibtex +@inproceedings{Hansen2022tdmpc, + title={Temporal Difference Learning for Model Predictive Control}, + author={Nicklas Hansen and Xiaolong Wang and Hao Su}, + booktitle={ICML}, + year={2022} +} +``` diff --git a/docs/source/policy_vqbet_README.md b/docs/source/policy_vqbet_README.md new file mode 100644 index 000000000..02f95b7c2 --- /dev/null +++ b/docs/source/policy_vqbet_README.md @@ -0,0 +1,14 @@ +## Paper + +https://sjlee.cc/vq-bet/ + +## Citation + +```bibtex +@article{lee2024behavior, + title={Behavior generation with latent actions}, + author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, + journal={arXiv preprint arXiv:2403.03181}, + year={2024} +} +``` diff --git a/src/lerobot/policies/act/README.md b/src/lerobot/policies/act/README.md new file mode 120000 index 000000000..046020098 --- /dev/null +++ b/src/lerobot/policies/act/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_act_README.md \ No newline at end of file diff --git a/src/lerobot/policies/diffusion/README.md b/src/lerobot/policies/diffusion/README.md new file mode 120000 index 000000000..d332d79c8 --- /dev/null +++ b/src/lerobot/policies/diffusion/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_diffusion_README.md \ No newline at end of file diff --git a/src/lerobot/policies/smolvla/README.md b/src/lerobot/policies/smolvla/README.md new file mode 120000 index 000000000..f8de40269 --- /dev/null +++ b/src/lerobot/policies/smolvla/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_smolvla_README.md \ No newline at end of file diff --git a/src/lerobot/policies/tdmpc/README.md b/src/lerobot/policies/tdmpc/README.md new file mode 120000 index 000000000..413ea87b8 --- /dev/null +++ b/src/lerobot/policies/tdmpc/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_tdmpc_README.md \ No newline at end of file diff --git a/src/lerobot/policies/vqbet/README.md b/src/lerobot/policies/vqbet/README.md new file mode 120000 index 000000000..a4ae9291a --- /dev/null +++ b/src/lerobot/policies/vqbet/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_vqbet_README.md \ No newline at end of file From 11525cedeb5d2e906b81d2d57b64bf16a2f2a351 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 18:05:20 +0200 Subject: [PATCH 057/158] fix(ci): change steps based on wheter it is a -rc tag (#1646) --- .github/workflows/release.yml | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b4eff4589..b0e16e0f7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -55,6 +55,7 @@ jobs: VERSION_NUMBER=${VERSION#v} echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT - name: Check if version matches pyproject.toml + if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') # zizmor: ignore[template-injection] run: | TAG_VERSION=${{ steps.extract_info.outputs.tag_version }} @@ -141,8 +142,19 @@ jobs: version: ${{ env.UV_VERSION }} python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot release - run: uv run pip install "lerobot[all]==${{ needs.build-and-publish.outputs.version }}" # zizmor: ignore[template-injection] - + # zizmor: ignore[template-injection] + run: | + VERSION="${{ needs.build-and-publish.outputs.version }}" + if [[ "$VERSION" == *-* ]]; then + echo "Installing pre-release version $VERSION from TestPyPI..." + uv run pip install \ + --index-url https://test.pypi.org/simple/ \ + --extra-index-url https://pypi.org/simple \ + "lerobot[all]==$VERSION" + else + echo "Installing release version $VERSION from PyPI..." + uv run pip install "lerobot[all]==$VERSION" + fi - name: Check lerobot version run: uv run python -c "import lerobot; print(lerobot.__version__)" From dcb305ffb2e8ddada10dfcd66b646fc18f1a76d7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 18:11:08 +0200 Subject: [PATCH 058/158] fix(ci): change release-name to title (#1647) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b0e16e0f7..d287d2ef6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -94,7 +94,7 @@ jobs: # zizmor: ignore[template-injection] run: | gh release create ${{ github.ref_name }} \ - --release-name "Release ${{ github.ref_name }}" \ + --title "Release ${{ github.ref_name }}" \ --generate-notes \ --draft=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ --prerelease=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ From 60dc8e3a5db6fdb7e2471315559aecea375ac603 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 18:21:37 +0200 Subject: [PATCH 059/158] fix(ci): use base tag for testpy to mimic the pyproject.toml version (#1648) --- .github/workflows/release.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d287d2ef6..9e63a2e2a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -146,11 +146,12 @@ jobs: run: | VERSION="${{ needs.build-and-publish.outputs.version }}" if [[ "$VERSION" == *-* ]]; then - echo "Installing pre-release version $VERSION from TestPyPI..." + BASE_VERSION="${VERSION%%-*}" + echo "Installing pre-release version $BASE_VERSION from TestPyPI..." uv run pip install \ --index-url https://test.pypi.org/simple/ \ --extra-index-url https://pypi.org/simple \ - "lerobot[all]==$VERSION" + "lerobot[all]==$BASE_VERSION" else echo "Installing release version $VERSION from PyPI..." uv run pip install "lerobot[all]==$VERSION" From 3e24ecaf54f2d2bfe6fd446433f7896bc2496028 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 18:30:33 +0200 Subject: [PATCH 060/158] chore(ci): Bump to v0.3.0 (#1649) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c369b612c..d913245b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.2.0" +version = "0.3.0" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From 240a3892ae8ea9f71baf6962385420d31d844743 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 20:52:10 +0200 Subject: [PATCH 061/158] fix(ci): remove uv run + bump minor (#1651) --- .github/workflows/release.yml | 5 +++-- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9e63a2e2a..b6ac0fcb9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -148,13 +148,14 @@ jobs: if [[ "$VERSION" == *-* ]]; then BASE_VERSION="${VERSION%%-*}" echo "Installing pre-release version $BASE_VERSION from TestPyPI..." - uv run pip install \ + uv pip install \ --index-url https://test.pypi.org/simple/ \ --extra-index-url https://pypi.org/simple \ + --index-strategy unsafe-best-match \ "lerobot[all]==$BASE_VERSION" else echo "Installing release version $VERSION from PyPI..." - uv run pip install "lerobot[all]==$VERSION" + uv pip install "lerobot[all]==$VERSION" fi - name: Check lerobot version run: uv run python -c "import lerobot; print(lerobot.__version__)" diff --git a/pyproject.toml b/pyproject.toml index d913245b1..8984f67bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.3.0" +version = "0.3.1" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From f771e3eaf1b30aae65768513b2e67d351f5f6c2f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 1 Aug 2025 21:04:47 +0200 Subject: [PATCH 062/158] fix(ci): create venv for release testing (#1652) --- .github/workflows/release.yml | 2 ++ pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b6ac0fcb9..67aa5186b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -141,6 +141,8 @@ jobs: enable-cache: true version: ${{ env.UV_VERSION }} python-version: ${{ env.PYTHON_VERSION }} + - name: Create uv virtual environment + run: uv venv - name: Install lerobot release # zizmor: ignore[template-injection] run: | diff --git a/pyproject.toml b/pyproject.toml index 8984f67bb..7e737ddc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.3.1" +version = "0.3.2" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From 8c577525c199642047062225f490eca71906a7c5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 4 Aug 2025 11:00:22 +0200 Subject: [PATCH 063/158] chore: Bump to 4.0.0 (#1653) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7e737ddc9..a1db99c24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.3.2" +version = "0.4.0" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From 90d3a99aa130da938d6a17cc1f7c5d47831cef3b Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Mon, 4 Aug 2025 21:49:51 +0200 Subject: [PATCH 064/158] Fix policy construction (#1665) * add: test to check proper construction with multiple features with STATE/ACTION type * fix: robot and action state should match policy's expectations * fix minor Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- src/lerobot/configs/policies.py | 9 +++--- tests/policies/test_policies.py | 50 +++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index c5b2fa09e..f5fa727cf 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.hub import HubMixin @@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def robot_state_feature(self) -> PolicyFeature | None: - for _, ft in self.input_features.items(): - if ft.type is FeatureType.STATE: + for ft_name, ft in self.input_features.items(): + if ft.type is FeatureType.STATE and ft_name == OBS_STATE: return ft return None @@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def action_feature(self) -> PolicyFeature | None: - for _, ft in self.output_features.items(): - if ft.type is FeatureType.ACTION: + for ft_name, ft in self.output_features.items(): + if ft.type is FeatureType.ACTION and ft_name == ACTION: return ft return None diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ed37fedd6..da7573d7c 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -27,11 +27,13 @@ from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import preprocess_observation from lerobot.optim.factory import make_optimizer_and_scheduler +from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.factory import ( get_policy_class, @@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim): unnormalize(output_batch) +@pytest.mark.parametrize("multikey", [True, False]) +def test_multikey_construction(multikey: bool): + """ + Asserts that multiple keys with type State/Action are correctly processed by the policy constructor, + preventing erroneous creation of the policy object. + """ + input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(10,), + ), + } + output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ), + } + + if multikey: + """Simulates the complete state/action is constructed from more granular multiple + keys, of the same type as the overall state/action""" + input_features = {} + input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + + output_features = {} + output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) + output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) + output_features["action"] = PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ) + + config = ACTConfig(input_features=input_features, output_features=output_features) + + state_condition = config.robot_state_feature == input_features[OBS_STATE] + action_condition = config.action_feature == output_features[ACTION] + + assert state_condition, ( + f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}" + ) + assert action_condition, ( + f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}" + ) + + @pytest.mark.parametrize( "ds_repo_id, policy_name, policy_kwargs, file_name_extra", [ From e0096feb6ac4b67b13b94e58dc3d9c5d41b55f4a Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Tue, 5 Aug 2025 18:33:55 +0800 Subject: [PATCH 065/158] fix(docs): Update links in il_robots.mdx and il_sim.mdx to use absolute URLs (#1313) * Update links to use absolute URLs. * Update dataset upload example link to use HF_USER variable and match the correct syntax. --- docs/source/il_robots.mdx | 6 +++--- docs/source/il_sim.mdx | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 8c075e5b2..ec5491b2a 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -294,7 +294,7 @@ dataset.push_to_hub() #### Dataset upload -Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: +Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. `https://huggingface.co/datasets/${HF_USER}/so101_test`) that you can obtain by running: ```bash echo https://huggingface.co/datasets/${HF_USER}/so101_test @@ -428,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash python -m lerobot.scripts.train \ @@ -444,7 +444,7 @@ python -m lerobot.scripts.train \ Let's explain the command: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. 3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. 4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 193b09b1b..761e24e0f 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -96,7 +96,7 @@ If you uploaded your dataset to the hub you can [visualize your dataset online]( ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash python -m lerobot.scripts.train \ @@ -111,7 +111,7 @@ python -m lerobot.scripts.train \ Let's explain the command: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. 3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. 4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. From 06bebd97b37ee4ce595f1d70c8fc752eb59117a3 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 5 Aug 2025 23:47:49 +0200 Subject: [PATCH 066/158] fix(typo): fixing typo in LeRobot authors names (#1673) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 13cc95f90..265dec5c8 100644 --- a/README.md +++ b/README.md @@ -311,7 +311,7 @@ If you want, you can cite this work with: ```bibtex @misc{cadene2024lerobot, - author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, + author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch}, howpublished = "\url{https://github.com/huggingface/lerobot}", year = {2024} From 6daa579ce1b3ff0bd58a2d1c3638d1de8b3191e9 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 6 Aug 2025 15:06:36 +0200 Subject: [PATCH 067/158] docs: update installation instructions (#1686) --- README.md | 46 +++++++++++++++++++++++++++++++-- docs/source/installation.mdx | 50 ++++++++++++++++++++++++++++-------- 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 265dec5c8..7255ed3ef 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,9 @@ ## Installation LeRobot works with Python 3.10+ and PyTorch 2.2+. + +### Environment Setup + Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html): ```bash @@ -124,10 +127,21 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -Install 🤗 LeRobot: +### Install LeRobot 🤗 + +#### From Source + +First, clone the repository and navigate into the directory: ```bash -pip install lerobot +git clone https://github.com/huggingface/lerobot.git +cd lerobot +``` + +Then, install the library in editable mode. This is useful if you plan to contribute to the code. + +```bash +pip install -e . ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: @@ -145,6 +159,34 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use: pip install -e ".[aloha, pusht]" ``` +### Installation from PyPI + +**Core Library:** +Install the base package with: + +```bash +pip install lerobot +``` + +_This installs only the default dependencies._ + +**Extra Features:** +To install additional functionality, use one of the following: + +```bash +pip install 'lerobot[all]' # All available features +pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) +pip install 'lerobot[feetech]' # Feetech motor support +``` + +_Replace `[...]` with your desired features._ + +**Available Tags:** +For a full list of optional dependencies, see: +https://pypi.org/project/lerobot/ + +### Weights & Biases + To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with ```bash diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 13c3600b4..93354c2ee 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,15 +1,6 @@ # Installation -## Install LeRobot - -Currently only available from source. - -Download our source code: - -```bash -git clone https://github.com/huggingface/lerobot.git -cd lerobot -``` +## Environment Setup Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) @@ -40,12 +31,49 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -Install 🤗 LeRobot: +## Install LeRobot 🤗 + +### From Source + +First, clone the repository and navigate into the directory: + +```bash +git clone https://github.com/huggingface/lerobot.git +cd lerobot +``` + +Then, install the library in editable mode. This is useful if you plan to contribute to the code. ```bash pip install -e . ``` +### Installation from PyPI + +**Core Library:** +Install the base package with: + +```bash +pip install lerobot +``` + +_This installs only the default dependencies._ + +**Extra Features:** +To install additional functionality, use one of the following: + +```bash +pip install 'lerobot[all]' # All available features +pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) +pip install 'lerobot[feetech]' # Feetech motor support +``` + +_Replace `[...]` with your desired features._ + +**Available Tags:** +For a full list of optional dependencies, see: +https://pypi.org/project/lerobot/ + ### Troubleshooting If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. From 88f7bf01c1a09cada3db6932e14052ed759f5c48 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 6 Aug 2025 16:11:04 +0200 Subject: [PATCH 068/158] feat(pipeline): universal processor for LeRobot (#1431) * Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * docs(pipeline): Clarify transition handling and hook behavior - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps. * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * refactor(pipeline): Remove model card generation and streamline processor methods - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability. * refactor(observation): Streamline observation preprocessing and remove unused processor methods - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script. * refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names. - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. * refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes --------- Signed-off-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma --- src/lerobot/processor/__init__.py | 54 + src/lerobot/processor/device_processor.py | 82 + src/lerobot/processor/normalize_processor.py | 331 +++ .../processor/observation_processor.py | 157 ++ src/lerobot/processor/pipeline.py | 1264 +++++++++++ src/lerobot/processor/rename_processor.py | 51 + tests/conftest.py | 17 + tests/processor/test_batch_conversion.py | 282 +++ tests/processor/test_normalize_processor.py | 628 ++++++ tests/processor/test_observation_processor.py | 486 +++++ tests/processor/test_pipeline.py | 1919 +++++++++++++++++ tests/processor/test_rename_processor.py | 467 ++++ 12 files changed, 5738 insertions(+) create mode 100644 src/lerobot/processor/__init__.py create mode 100644 src/lerobot/processor/device_processor.py create mode 100644 src/lerobot/processor/normalize_processor.py create mode 100644 src/lerobot/processor/observation_processor.py create mode 100644 src/lerobot/processor/pipeline.py create mode 100644 src/lerobot/processor/rename_processor.py create mode 100644 tests/processor/test_batch_conversion.py create mode 100644 tests/processor/test_normalize_processor.py create mode 100644 tests/processor/test_observation_processor.py create mode 100644 tests/processor/test_pipeline.py create mode 100644 tests/processor/test_rename_processor.py diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py new file mode 100644 index 000000000..8dd244c27 --- /dev/null +++ b/src/lerobot/processor/__init__.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .device_processor import DeviceProcessor +from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor +from .observation_processor import VanillaObservationProcessor +from .pipeline import ( + ActionProcessor, + DoneProcessor, + EnvTransition, + IdentityProcessor, + InfoProcessor, + ObservationProcessor, + ProcessorStep, + ProcessorStepRegistry, + RewardProcessor, + RobotProcessor, + TransitionKey, + TruncatedProcessor, +) +from .rename_processor import RenameProcessor + +__all__ = [ + "ActionProcessor", + "DeviceProcessor", + "DoneProcessor", + "EnvTransition", + "IdentityProcessor", + "InfoProcessor", + "NormalizerProcessor", + "UnnormalizerProcessor", + "ObservationProcessor", + "ProcessorStep", + "ProcessorStepRegistry", + "RenameProcessor", + "RewardProcessor", + "RobotProcessor", + "TransitionKey", + "TruncatedProcessor", + "VanillaObservationProcessor", +] diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py new file mode 100644 index 000000000..0f00bb470 --- /dev/null +++ b/src/lerobot/processor/device_processor.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.utils.utils import get_safe_torch_device + + +@dataclass +class DeviceProcessor: + """Processes transitions by moving tensors to the specified device. + + This processor ensures that all tensors in the transition are moved to the + specified device (CPU or GPU) before they are returned. + """ + + device: torch.device = "cpu" + + def __post_init__(self): + self.device = get_safe_torch_device(self.device) + self.non_blocking = "cuda" in str(self.device) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Create a copy of the transition + new_transition = transition.copy() + + # Process observation tensors + observation = transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_observation = { + k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v + for k, v in observation.items() + } + new_transition[TransitionKey.OBSERVATION] = new_observation + + # Process action tensor + action = transition.get(TransitionKey.ACTION) + if action is not None and isinstance(action, torch.Tensor): + new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + + # Process reward tensor + reward = transition.get(TransitionKey.REWARD) + if reward is not None and isinstance(reward, torch.Tensor): + new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + + # Process done tensor + done = transition.get(TransitionKey.DONE) + if done is not None and isinstance(done, torch.Tensor): + new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + + # Process truncated tensor + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is not None and isinstance(truncated, torch.Tensor): + new_transition[TransitionKey.TRUNCATED] = truncated.to( + self.device, non_blocking=self.non_blocking + ) + + return new_transition + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {"device": self.device} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py new file mode 100644 index 000000000..14628727f --- /dev/null +++ b/src/lerobot/processor/normalize_processor.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import torch +from torch import Tensor + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: + """Convert numpy arrays and other types to torch tensors.""" + tensor_stats: dict[str, dict[str, Tensor]] = {} + for key, sub in stats.items(): + tensor_stats[key] = {} + for stat_name, value in sub.items(): + if isinstance(value, np.ndarray): + tensor_val = torch.from_numpy(value.astype(np.float32)) + elif isinstance(value, torch.Tensor): + tensor_val = value.to(dtype=torch.float32) + elif isinstance(value, (int, float, list, tuple)): + tensor_val = torch.tensor(value, dtype=torch.float32) + else: + raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") + tensor_stats[key][stat_name] = tensor_val + return tensor_stats + + +@dataclass +@ProcessorStepRegistry.register(name="normalizer_processor") +class NormalizerProcessor: + """Normalizes observations and actions in a single processor step. + + This processor handles normalization of both observation and action tensors + using either mean/std normalization or min/max scaling to a [-1, 1] range. + + For each tensor key in the stats dictionary, the processor will: + - Use mean/std normalization if those statistics are provided: (x - mean) / std + - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 + + The processor can be configured to normalize only specific keys by setting + the normalize_keys parameter. + """ + + # Features and normalisation map are mandatory to match the design of normalize.py + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + + # Pre-computed statistics coming from dataset.meta.stats for instance. + stats: dict[str, dict[str, Any]] | None = None + + # Explicit subset of keys to normalise. If ``None`` every key (except + # "action") found in ``stats`` will be normalised. Using a ``set`` makes + # membership checks O(1). + normalize_keys: set[str] | None = None + + eps: float = 1e-8 + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + *, + normalize_keys: set[str] | None = None, + eps: float = 1e-8, + ) -> NormalizerProcessor: + """Factory helper that pulls statistics from a :class:`LeRobotDataset`. + + The features and norm_map parameters are mandatory to match the design + pattern used in normalize.py. + """ + + return cls( + features=features, + norm_map=norm_map, + stats=dataset.meta.stats, + normalize_keys=normalize_keys, + eps=eps, + ) + + def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + + # Convert statistics once so we avoid repeated numpy→Tensor conversions + # during runtime. + self.stats = self.stats or {} + self._tensor_stats = _convert_stats_to_tensors(self.stats) + + # Ensure *normalize_keys* is a set for fast look-ups and compare by + # value later when returning the configuration. + if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): + self.normalize_keys = set(self.normalize_keys) + + def _normalize_obs(self, observation): + if observation is None: + return None + + # Decide which keys should be normalised for this call. + if self.normalize_keys is not None: + keys_to_norm = self.normalize_keys + else: + # Use feature map to skip action keys. + keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} + + processed = dict(observation) + for key in keys_to_norm: + if key not in processed or key not in self._tensor_stats: + continue + + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} + + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = (tensor - mean) / (std + self.eps) + elif "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + return processed + + def _normalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return (tensor - mean) / (std + self.eps) + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._normalize_action(transition.get(TransitionKey.ACTION)) + + # Create a new transition with normalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition + + def get_config(self) -> dict[str, Any]: + config = { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + if self.normalize_keys is not None: + # Serialise as a list for YAML / JSON friendliness + config["normalize_keys"] = sorted(self.normalize_keys) + return config + + def state_dict(self) -> dict[str, Tensor]: + flat = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor + return flat + + def load_state_dict(self, state: Mapping[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + self._tensor_stats.setdefault(key, {})[stat_name] = tensor + + def reset(self): + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register(name="unnormalizer_processor") +class UnnormalizerProcessor: + """Inverse normalisation for observations and actions. + + Exactly mirrors :class:`NormalizerProcessor` but applies the inverse + transform. + """ + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + ) -> UnnormalizerProcessor: + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) + + def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + + self.stats = self.stats or {} + self._tensor_stats = _convert_stats_to_tensors(self.stats) + + def _unnormalize_obs(self, observation): + if observation is None: + return None + keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] + processed = dict(observation) + for key in keys: + if key not in processed or key not in self._tensor_stats: + continue + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = tensor * std + mean + elif "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + return processed + + def _unnormalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return tensor * std + mean + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return (tensor + 1) / 2 * (max_val - min_val) + min_val + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) + + # Create a new transition with unnormalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + + def state_dict(self) -> dict[str, Tensor]: + flat = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor + return flat + + def load_state_dict(self, state: Mapping[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + self._tensor_stats.setdefault(key, {})[stat_name] = tensor + + def reset(self): + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py new file mode 100644 index 000000000..7d63db238 --- /dev/null +++ b/src/lerobot/processor/observation_processor.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import einops +import numpy as np +import torch +from torch import Tensor + +from lerobot.configs.types import PolicyFeature +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry + + +@dataclass +@ProcessorStepRegistry.register(name="observation_processor") +class VanillaObservationProcessor(ObservationProcessor): + """ + Processes environment observations into the LeRobot format by handling both images and states. + + Image processing: + - Converts channel-last (H, W, C) images to channel-first (C, H, W) + - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) + - Adds a batch dimension if missing + - Supports single images and image dictionaries + + State processing: + - Maps 'environment_state' to observation.environment_state + - Maps 'agent_pos' to observation.state + - Converts numpy arrays to tensors + - Adds a batch dimension if missing + """ + + def _process_single_image(self, img: np.ndarray) -> Tensor: + """Process a single image array.""" + # Convert to tensor + img_tensor = torch.from_numpy(img) + + # Add batch dimension if needed + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) + + # Validate image format + _, h, w, c = img_tensor.shape + if not (c < h and c < w): + raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}") + + if img_tensor.dtype != torch.uint8: + raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}") + + # Convert to channel-first format + img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() + + # Convert to float32 and normalize to [0, 1] + img_tensor = img_tensor.type(torch.float32) / 255.0 + + return img_tensor + + def _process_observation(self, observation): + """ + Processes both image and state observations. + """ + + processed_obs = observation.copy() + + if "pixels" in processed_obs: + pixels = processed_obs.pop("pixels") + + if isinstance(pixels, dict): + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()} + else: + imgs = {OBS_IMAGE: pixels} + + for imgkey, img in imgs.items(): + processed_obs[imgkey] = self._process_single_image(img) + + if "environment_state" in processed_obs: + env_state_np = processed_obs.pop("environment_state") + env_state = torch.from_numpy(env_state_np).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + processed_obs[OBS_ENV_STATE] = env_state + + if "agent_pos" in processed_obs: + agent_pos_np = processed_obs.pop("agent_pos") + agent_pos = torch.from_numpy(agent_pos_np).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + processed_obs[OBS_STATE] = agent_pos + + return processed_obs + + def observation(self, observation): + return self._process_observation(observation) + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms feature keys to a standardized contract. + + This method handles several renaming patterns: + - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). + - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). + - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). + - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). + - environment_state -> OBS_ENV_STATE, + - agent_pos -> OBS_STATE, + - observation.environment_state -> OBS_ENV_STATE, + - observation.agent_pos -> OBS_STATE + """ + exact_pairs = { + "pixels": OBS_IMAGE, + "environment_state": OBS_ENV_STATE, + "agent_pos": OBS_STATE, + } + + prefix_pairs = { + "pixels.": f"{OBS_IMAGES}.", + } + + for key in list(features.keys()): + matched_prefix = False + for old_prefix, new_prefix in prefix_pairs.items(): + prefixed_old = f"observation.{old_prefix}" + if key.startswith(prefixed_old): + suffix = key[len(prefixed_old) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if key.startswith(old_prefix): + suffix = key[len(old_prefix) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if matched_prefix: + continue + + for old, new in exact_pairs.items(): + if key == old or key == f"observation.{old}": + if key in features: + features[new] = features.pop(key) + break + + return features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py new file mode 100644 index 000000000..6e1b2a2cb --- /dev/null +++ b/src/lerobot/processor/pipeline.py @@ -0,0 +1,1264 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import importlib +import json +import os +from collections.abc import Callable, Iterable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Protocol, TypedDict + +import torch +from huggingface_hub import ModelHubMixin, hf_hub_download +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_file, save_file + +from lerobot.configs.types import PolicyFeature + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) + + +class ProcessorStepRegistry: + """Registry for processor steps that enables saving/loading by name instead of module path.""" + + _registry: dict[str, type] = {} + + @classmethod + def register(cls, name: str = None): + """Decorator to register a processor step class. + + Args: + name: Optional registration name. If not provided, uses class name. + + Example: + @ProcessorStepRegistry.register("adaptive_normalizer") + class AdaptiveObservationNormalizer: + ... + """ + + def decorator(step_class: type) -> type: + registration_name = name if name is not None else step_class.__name__ + + if registration_name in cls._registry: + raise ValueError( + f"Processor step '{registration_name}' is already registered. " + f"Use a different name or unregister the existing one first." + ) + + cls._registry[registration_name] = step_class + # Store the registration name on the class for later reference + step_class._registry_name = registration_name + return step_class + + return decorator + + @classmethod + def get(cls, name: str) -> type: + """Get a registered processor step class by name. + + Args: + name: The registration name of the step. + + Returns: + The registered step class. + + Raises: + KeyError: If the step is not registered. + """ + if name not in cls._registry: + available = list(cls._registry.keys()) + raise KeyError( + f"Processor step '{name}' not found in registry. " + f"Available steps: {available}. " + f"Make sure the step is registered using @ProcessorStepRegistry.register()" + ) + return cls._registry[name] + + @classmethod + def unregister(cls, name: str) -> None: + """Remove a step from the registry.""" + cls._registry.pop(name, None) + + @classmethod + def list(cls) -> list[str]: + """List all registered step names.""" + return list(cls._registry.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registrations.""" + cls._registry.clear() + + +class ProcessorStep(Protocol): + """Structural typing interface for a single processor step. + + A step is any callable accepting a full `EnvTransition` dict and + returning a (possibly modified) dict of the same structure. Implementers + are encouraged—but not required—to expose the optional helper methods + listed below. When present, these hooks let `RobotProcessor` + automatically serialise the step's configuration and learnable state using + a safe-to-share JSON + SafeTensors format. + + + **Required**: + - ``__call__(transition: EnvTransition) -> EnvTransition`` + - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` + + Optional helper protocol: + * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable + configuration and state. YOU decide what to save here. This is where all + non-tensor state goes (e.g., name, counter, threshold, window_size). + The config dict will be passed to your class constructor when loading. + * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. + This is exclusively for torch.Tensor objects (e.g., learned weights, + running statistics as tensors). Never put simple Python types here. + * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict + containing torch tensors only. + * ``reset()`` – Clear internal buffers at episode boundaries. + + Example separation: + - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} + - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: ... + + def get_config(self) -> dict[str, Any]: ... + + def state_dict(self) -> dict[str, torch.Tensor]: ... + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... + + def reset(self) -> None: ... + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + + +def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 + """Convert a *batch* dict coming from Learobot replay/dataset code into an + ``EnvTransition`` dictionary. + + The function maps well known keys to the EnvTransition structure. Missing keys are + filled with sane defaults (``None`` or ``0.0``/``False``). + + Keys recognised (case-sensitive): + + * "observation.*" (keys starting with "observation." are grouped into observation dict) + * "action" + * "next.reward" + * "next.done" + * "next.truncated" + * "info" + + Additional keys are ignored so that existing dataloaders can carry extra + metadata without breaking the processor. + """ + + # Extract observation keys + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation = observation_keys if observation_keys else None + + # Extract padding and task keys for complementary data + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} + + transition: EnvTransition = { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: batch.get("action"), + TransitionKey.REWARD: batch.get("next.reward", 0.0), + TransitionKey.DONE: batch.get("next.done", False), + TransitionKey.TRUNCATED: batch.get("next.truncated", False), + TransitionKey.INFO: batch.get("info", {}), + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + return transition + + +def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 + """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with + the canonical field names used throughout *LeRobot*. + """ + + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add padding and task data from complementary_data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data: + pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} + batch.update(pad_data) + + if "task" in complementary_data: + batch["task"] = complementary_data["task"] + + # Handle observation - flatten dict to observation.* keys if it's a dict + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch + + +@dataclass +class RobotProcessor(ModelHubMixin): + """ + Composable, debuggable post-processing processor for robot transitions. + + The class orchestrates an ordered collection of small, functional transforms—steps—executed + left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts + and batch dictionaries, automatically converting between formats as needed. + + Args: + steps: Ordered list of processing steps executed on every call. Defaults to empty list. + name: Human-readable identifier that is persisted inside the JSON config. + Defaults to "RobotProcessor". + to_transition: Function to convert batch dict to EnvTransition dict. + Defaults to _default_batch_to_transition. + to_output: Function to convert EnvTransition dict to the desired output format. + Usually it is a batch dict or EnvTransition dict. + Defaults to _default_transition_to_batch. + before_step_hooks: List of hooks called before each step. Each hook receives the step + index and transition, and can optionally return a modified transition. + after_step_hooks: List of hooks called after each step. Each hook receives the step + index and transition, and can optionally return a modified transition. + + Hook Semantics: + - Hooks are executed sequentially in the order they were registered. There is no way to + reorder hooks after registration without creating a new pipeline. + - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called + with the step index and current transition for logging, debugging, or monitoring purposes. + - All hooks for a given type (before/after) are executed for every step, or none at all if + an error occurs. There is no partial execution of hooks. + - Hooks should generally be stateless to maintain predictable behavior. If you need stateful + processing, consider implementing a proper ProcessorStep instead. + - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. + - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format + passed to __call__. This ensures consistent hook behavior whether processing batch dicts + or EnvTransition objects. + """ + + steps: Sequence[ProcessorStep] = field(default_factory=list) + name: str = "RobotProcessor" + + to_transition: Callable[[dict[str, Any]], EnvTransition] = field( + default_factory=lambda: _default_batch_to_transition, repr=False + ) + to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field( + default_factory=lambda: _default_transition_to_batch, repr=False + ) + + # Processor-level hooks for observation/monitoring + # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes + before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + + def __call__(self, data: EnvTransition | dict[str, Any]): + """Process data through all steps. + + The method accepts either the classic EnvTransition dict or a batch dictionary + (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied + it is first converted to the internal dict format using to_transition; after all + steps are executed the dict is transformed back into a batch dict with to_batch and the + result is returned – thereby preserving the caller's original data type. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Returns: + The processed data in the same format as the input (EnvTransition or batch dict). + + Raises: + ValueError: If the transition is not a valid EnvTransition format. + """ + # Check if we need to convert back to batch format at the end + _, called_with_batch = self._prepare_transition(data) + + # Use step_through to get the iterator + step_iterator = self.step_through(data) + + # Get initial state (before any steps) + current_transition = next(step_iterator) + + # Process each step with hooks + for idx, next_transition in enumerate(step_iterator): + # Apply before hooks with current state (before step execution) + for hook in self.before_step_hooks: + hook(idx, current_transition) + + # Move to next state (after step execution) + current_transition = next_transition + + # Apply after hooks with updated state + for hook in self.after_step_hooks: + hook(idx, current_transition) + + # Convert back to original format if needed + return self.to_output(current_transition) if called_with_batch else current_transition + + def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: + """Prepare and validate transition data for processing. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Returns: + A tuple of (prepared_transition, called_with_batch_flag) + + Raises: + ValueError: If the transition is not a valid EnvTransition format. + """ + # Check if data is already an EnvTransition or needs conversion + if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): + # It's a batch dict, convert it + called_with_batch = True + transition = self.to_transition(data) + else: + # It's already an EnvTransition + called_with_batch = False + transition = data + + # Basic validation + if not isinstance(transition, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") + + return transition, called_with_batch + + def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: + """Yield the intermediate results after each processor step. + + This is a low-level method that does NOT apply hooks. It simply executes each step + and yields the intermediate results. This allows users to debug the pipeline or + apply custom logic between steps if needed. + + Note: This method always yields EnvTransition objects regardless of input format. + If you need the results in the original input format, you'll need to convert them + using `to_output()`. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Yields: + The intermediate EnvTransition results after each step. + """ + transition, _ = self._prepare_transition(data) + + # Yield initial state + yield transition + + # Process each step WITHOUT hooks (low-level method) + for processor_step in self.steps: + transition = processor_step(transition) + yield transition + + def _save_pretrained(self, save_directory: Path, **kwargs): + """Internal save method for ModelHubMixin compatibility.""" + # Extract config_filename from kwargs if provided + config_filename = kwargs.pop("config_filename", None) + self.save_pretrained(save_directory, config_filename=config_filename) + + def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): + """Serialize the processor definition and parameters to *save_directory*. + + Args: + save_directory: Directory where the processor will be saved. + config_filename: Optional custom config filename. If not provided, defaults to + "{self.name}.json" where self.name is sanitized for filesystem compatibility. + """ + os.makedirs(str(save_directory), exist_ok=True) + + # Sanitize processor name for use in filenames + import re + + # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + + # Use sanitized name for config if not provided + if config_filename is None: + config_filename = f"{sanitized_name}.json" + + config: dict[str, Any] = { + "name": self.name, + "steps": [], + } + + for step_index, processor_step in enumerate(self.steps): + # Check if step was registered + registry_name = getattr(processor_step.__class__, "_registry_name", None) + + step_entry: dict[str, Any] = {} + if registry_name: + # Use registry name for registered steps + step_entry["registry_name"] = registry_name + else: + # Fall back to full module path for unregistered steps + step_entry["class"] = ( + f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" + ) + + if hasattr(processor_step, "get_config"): + step_entry["config"] = processor_step.get_config() + + if hasattr(processor_step, "state_dict"): + state = processor_step.state_dict() + if state: + # Clone tensors to avoid shared memory issues + # This ensures each tensor has its own memory allocation + # The reason is to avoid the following error: + # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk + # and potential differences when loading them again + # ------------------------------------------------------------------------------ + # Since the state_dict of processor will be light, we can just clone the tensors + # and save them to the disk. + cloned_state = {} + for key, tensor in state.items(): + cloned_state[key] = tensor.clone() + + # Include pipeline name and step index to ensure unique filenames + # This prevents conflicts when multiple processors are saved in the same directory + if registry_name: + state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" + else: + state_filename = f"{sanitized_name}_step_{step_index}.safetensors" + + save_file(cloned_state, os.path.join(str(save_directory), state_filename)) + step_entry["state_file"] = state_filename + + config["steps"].append(step_entry) + + with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: + json.dump(config, file_pointer, indent=2) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + config_filename: str | None = None, + overrides: dict[str, Any] | None = None, + **kwargs, + ) -> RobotProcessor: + """Load a serialized processor from source (local path or Hugging Face Hub identifier). + + Args: + pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier + (e.g., "username/processor-name"). + config_filename: Optional specific config filename to load. If not provided, will: + - For local paths: look for any .json file in the directory (error if multiple found) + - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") + overrides: Optional dictionary mapping step names to configuration overrides. + Keys must match exact step class names (for unregistered steps) or registry names + (for registered steps). Values are dictionaries containing parameter overrides + that will be merged with the saved configuration. This is useful for providing + non-serializable objects like environment instances. + + Returns: + A RobotProcessor instance loaded from the saved configuration. + + Raises: + ImportError: If a processor step class cannot be loaded or imported. + ValueError: If a step cannot be instantiated with the provided configuration. + KeyError: If an override key doesn't match any step in the saved configuration. + + Examples: + Basic loading: + ```python + processor = RobotProcessor.from_pretrained("path/to/processor") + ``` + + Loading specific config file: + ```python + processor = RobotProcessor.from_pretrained( + "username/multi-processor-repo", config_filename="preprocessor.json" + ) + ``` + + Loading with overrides for non-serializable objects: + ```python + import gym + + env = gym.make("CartPole-v1") + processor = RobotProcessor.from_pretrained( + "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} + ) + ``` + + Multiple overrides: + ```python + processor = RobotProcessor.from_pretrained( + "path/to/processor", + overrides={ + "CustomStep": {"param1": "new_value"}, + "device_processor": {"device": "cuda:1"}, # For registered steps + }, + ) + ``` + """ + # Use the local variable name 'source' for clarity + source = str(pretrained_model_name_or_path) + + if Path(source).is_dir(): + # Local path - use it directly + base_path = Path(source) + + if config_filename is None: + # Look for any .json file in the directory + json_files = list(base_path.glob("*.json")) + if len(json_files) == 0: + raise FileNotFoundError(f"No .json configuration files found in {source}") + elif len(json_files) > 1: + raise ValueError( + f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " + f"Please specify which one to load using the config_filename parameter." + ) + config_filename = json_files[0].name + + with open(base_path / config_filename) as file_pointer: + loaded_config: dict[str, Any] = json.load(file_pointer) + else: + # Hugging Face Hub - download all required files + if config_filename is None: + # Try common config names + common_names = [ + "processor.json", + "preprocessor.json", + "postprocessor.json", + "robotprocessor.json", + ] + config_path = None + for name in common_names: + try: + config_path = hf_hub_download( + source, + name, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + config_filename = name + break + except (FileNotFoundError, OSError, HfHubHTTPError): + # FileNotFoundError: local file issues + # OSError: network/system errors + # HfHubHTTPError: file not found on Hub (404) or other HTTP errors + continue + + if config_path is None: + raise FileNotFoundError( + f"No processor configuration file found in {source}. " + f"Tried: {common_names}. Please specify the config_filename parameter." + ) + else: + # Download specific config file + config_path = hf_hub_download( + source, + config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + with open(config_path) as file_pointer: + loaded_config = json.load(file_pointer) + + # Store downloaded files in the same directory as the config + base_path = Path(config_path).parent + + # Handle None overrides + if overrides is None: + overrides = {} + + # Validate that all override keys will be matched + override_keys = set(overrides.keys()) + + steps: list[ProcessorStep] = [] + for step_entry in loaded_config["steps"]: + # Check if step uses registry name or module path + if "registry_name" in step_entry: + # Load from registry + try: + step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + step_key = step_entry["registry_name"] + except KeyError as e: + raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e + else: + # Fall back to module path loading for backward compatibility + full_class_path = step_entry["class"] + module_path, class_name = full_class_path.rsplit(".", 1) + + # Import the module containing the step class + try: + module = importlib.import_module(module_path) + step_class = getattr(module, class_name) + step_key = class_name + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to load processor step '{full_class_path}'. " + f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " + f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. " + f"Error: {str(e)}" + ) from e + + # Instantiate the step with its config + try: + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + step_instance: ProcessorStep = step_class(**merged_cfg) + + # Track which override keys were used + if step_key in override_keys: + override_keys.discard(step_key) + + except Exception as e: + step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) + raise ValueError( + f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " + f"Error: {str(e)}" + ) from e + + # Load state if available + if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): + if Path(source).is_dir(): + # Local path - read directly + state_path = str(base_path / step_entry["state_file"]) + else: + # Hugging Face Hub - download the state file + state_path = hf_hub_download( + source, + step_entry["state_file"], + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + step_instance.load_state_dict(load_file(state_path)) + + steps.append(step_instance) + + # Check for unused override keys + if override_keys: + available_keys = [] + for step_entry in loaded_config["steps"]: + if "registry_name" in step_entry: + available_keys.append(step_entry["registry_name"]) + else: + full_class_path = step_entry["class"] + class_name = full_class_path.rsplit(".", 1)[1] + available_keys.append(class_name) + + raise KeyError( + f"Override keys {list(override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." + ) + + return cls(steps, loaded_config.get("name", "RobotProcessor")) + + def __len__(self) -> int: + """Return the number of steps in the processor.""" + return len(self.steps) + + def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor: + """Indexing helper exposing underlying steps. + * ``int`` – returns the idx-th ProcessorStep. + * ``slice`` – returns a new RobotProcessor with the sliced steps. + """ + if isinstance(idx, slice): + return RobotProcessor(self.steps[idx], self.name) + return self.steps[idx] + + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed before every processor step.""" + self.before_step_hooks.append(fn) + + def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered before_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.before_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." + ) from None + + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed after every processor step.""" + self.after_step_hooks.append(fn) + + def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered after_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.after_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." + ) from None + + def reset(self): + """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + for step in self.steps: + if hasattr(step, "reset"): + step.reset() # type: ignore[attr-defined] + + def __repr__(self) -> str: + """Return a readable string representation of the processor.""" + step_names = [step.__class__.__name__ for step in self.steps] + + if not step_names: + steps_repr = "steps=0: []" + elif len(step_names) <= 3: + steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]" + else: + # Show first 2 and last 1 with ellipsis for long lists + displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}" + steps_repr = f"steps={len(step_names)}: [{displayed}]" + + parts = [f"name='{self.name}'", steps_repr] + + return f"RobotProcessor({', '.join(parts)})" + + def __post_init__(self): + for i, step in enumerate(self.steps): + if not callable(step): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" + ) + + fc = getattr(step, "feature_contract", None) + if not callable(fc): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" + ) + + def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Apply ALL steps in order. Each step must implement + feature_contract(features) and return a dict (full or incremental schema). + """ + features: dict[str, PolicyFeature] = deepcopy(initial_features) + + for _, step in enumerate(self.steps): + out = step.feature_contract(features) + if not isinstance(out, dict): + raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + features = out + return features + + +class ObservationProcessor: + """Base class for processors that modify only the observation component of a transition. + + Subclasses should override the `observation` method to implement custom observation processing. + This class handles the boilerplate of extracting and reinserting the processed observation + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class MyObservationScaler(ObservationProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def observation(self, observation): + return observation * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific observation processing logic. + """ + + def observation(self, observation): + """Process the observation component. + + Args: + observation: The observation to process + + Returns: + The processed observation + """ + return observation + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + processed_observation = self.observation(observation) + # Create a new transition dict with the processed observation + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ActionProcessor: + """Base class for processors that modify only the action component of a transition. + + Subclasses should override the `action` method to implement custom action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class ActionClipping(ActionProcessor): + def __init__(self, min_val, max_val): + self.min_val = min_val + self.max_val = max_val + + def action(self, action): + return np.clip(action, self.min_val, self.max_val) + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific action processing logic. + """ + + def action(self, action): + """Process the action component. + + Args: + action: The action to process + + Returns: + The processed action + """ + return action + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + + processed_action = self.action(action) + # Create a new transition dict with the processed action + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class RewardProcessor: + """Base class for processors that modify only the reward component of a transition. + + Subclasses should override the `reward` method to implement custom reward processing. + This class handles the boilerplate of extracting and reinserting the processed reward + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class RewardScaler(RewardProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def reward(self, reward): + return reward * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific reward processing logic. + """ + + def reward(self, reward): + """Process the reward component. + + Args: + reward: The reward to process + + Returns: + The processed reward + """ + return reward + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + if reward is None: + return transition + + processed_reward = self.reward(reward) + # Create a new transition dict with the processed reward + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = processed_reward + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class DoneProcessor: + """Base class for processors that modify only the done flag of a transition. + + Subclasses should override the `done` method to implement custom done flag processing. + This class handles the boilerplate of extracting and reinserting the processed done flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class TimeoutDone(DoneProcessor): + def __init__(self, max_steps): + self.steps = 0 + self.max_steps = max_steps + + def done(self, done): + self.steps += 1 + return done or self.steps >= self.max_steps + + def reset(self): + self.steps = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific done flag processing logic. + """ + + def done(self, done): + """Process the done flag. + + Args: + done: The done flag to process + + Returns: + The processed done flag + """ + return done + + def __call__(self, transition: EnvTransition) -> EnvTransition: + done = transition.get(TransitionKey.DONE) + if done is None: + return transition + + processed_done = self.done(done) + # Create a new transition dict with the processed done flag + new_transition = transition.copy() + new_transition[TransitionKey.DONE] = processed_done + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class TruncatedProcessor: + """Base class for processors that modify only the truncated flag of a transition. + + Subclasses should override the `truncated` method to implement custom truncated flag processing. + This class handles the boilerplate of extracting and reinserting the processed truncated flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class EarlyTruncation(TruncatedProcessor): + def __init__(self, threshold): + self.threshold = threshold + + def truncated(self, truncated): + # Additional truncation condition + return truncated or some_condition > self.threshold + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific truncated flag processing logic. + """ + + def truncated(self, truncated): + """Process the truncated flag. + + Args: + truncated: The truncated flag to process + + Returns: + The processed truncated flag + """ + return truncated + + def __call__(self, transition: EnvTransition) -> EnvTransition: + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is None: + return transition + + processed_truncated = self.truncated(truncated) + # Create a new transition dict with the processed truncated flag + new_transition = transition.copy() + new_transition[TransitionKey.TRUNCATED] = processed_truncated + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class InfoProcessor: + """Base class for processors that modify only the info dictionary of a transition. + + Subclasses should override the `info` method to implement custom info processing. + This class handles the boilerplate of extracting and reinserting the processed info + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class InfoAugmenter(InfoProcessor): + def __init__(self): + self.step_count = 0 + + def info(self, info): + info = info.copy() # Create a copy to avoid modifying the original + info["steps"] = self.step_count + self.step_count += 1 + return info + + def reset(self): + self.step_count = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific info dictionary processing logic. + """ + + def info(self, info): + """Process the info dictionary. + + Args: + info: The info dictionary to process + + Returns: + The processed info dictionary + """ + return info + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO) + if info is None: + return transition + + processed_info = self.info(info) + # Create a new transition dict with the processed info + new_transition = transition.copy() + new_transition[TransitionKey.INFO] = processed_info + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ComplementaryDataProcessor: + """Base class for processors that modify only the complementary data of a transition. + + Subclasses should override the `complementary_data` method to implement custom complementary data processing. + This class handles the boilerplate of extracting and reinserting the processed complementary data + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + """ + + def complementary_data(self, complementary_data): + """Process the complementary data. + + Args: + complementary_data: The complementary data to process + + Returns: + The processed complementary data + """ + return complementary_data + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return transition + + processed_complementary_data = self.complementary_data(complementary_data) + # Create a new transition dict with the processed complementary data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class IdentityProcessor: + """Identity processor that does nothing.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py new file mode 100644 index 000000000..4fe4105a5 --- /dev/null +++ b/src/lerobot/processor/rename_processor.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Any + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import ( + ObservationProcessor, + ProcessorStepRegistry, +) + + +@dataclass +@ProcessorStepRegistry.register(name="rename_processor") +class RenameProcessor(ObservationProcessor): + """Rename processor that renames keys in the observation.""" + + rename_map: dict[str, str] = field(default_factory=dict) + + def observation(self, observation): + processed_obs = {} + for key, value in observation.items(): + if key in self.rename_map: + processed_obs[self.rename_map[key]] = value + else: + processed_obs[key] = value + + return processed_obs + + def get_config(self) -> dict[str, Any]: + return {"rename_map": self.rename_map} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms: + - Each key in the observation that appears in `rename_map` is renamed to its value. + - Keys not in `rename_map` remain unchanged. + """ + return {self.rename_map.get(k, k): v for k, v in features.items()} diff --git a/tests/conftest.py b/tests/conftest.py index 69dd3049b..7940cc5ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ import traceback import pytest from serial import SerialException +from lerobot.configs.types import FeatureType, PolicyFeature from tests.utils import DEVICE # Import fixture modules as plugins @@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch): print(text) monkeypatch.setattr("builtins.input", print_text) + + +@pytest.fixture +def policy_feature_factory(): + """PolicyFeature factory""" + + def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature: + return PolicyFeature(type=ft, shape=shape) + + return _pf + + +def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: + assert isinstance(features, dict) + assert all(isinstance(k, str) for k in features.keys()) + assert all(isinstance(v, PolicyFeature) for v in features.values()) diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py new file mode 100644 index 000000000..63894025d --- /dev/null +++ b/tests/processor/test_batch_conversion.py @@ -0,0 +1,282 @@ +import torch + +from lerobot.processor.pipeline import ( + RobotProcessor, + TransitionKey, + _default_batch_to_transition, + _default_transition_to_batch, +) + + +def _dummy_batch(): + """Create a dummy batch using the new format with observation.* and next.* keys.""" + return { + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.image.right": torch.randn(1, 3, 128, 128), + "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + "action": torch.tensor([[0.5]]), + "next.reward": 1.0, + "next.done": False, + "next.truncated": False, + "info": {"key": "value"}, + } + + +def test_observation_grouping_roundtrip(): + """Test that observation.* keys are properly grouped and ungrouped.""" + proc = RobotProcessor([]) + batch_in = _dummy_batch() + batch_out = proc(batch_in) + + # Check that all observation.* keys are preserved + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + + assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) + + # Check tensor values + assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) + assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) + assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + + # Check other fields + assert torch.allclose(batch_out["action"], batch_in["action"]) + assert batch_out["next.reward"] == batch_in["next.reward"] + assert batch_out["next.done"] == batch_in["next.done"] + assert batch_out["next.truncated"] == batch_in["next.truncated"] + assert batch_out["info"] == batch_in["info"] + + +def test_batch_to_transition_observation_grouping(): + """Test that _default_batch_to_transition correctly groups observation.* keys.""" + batch = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.state": [1, 2, 3, 4], + "action": "action_data", + "next.reward": 1.5, + "next.done": True, + "next.truncated": False, + "info": {"episode": 42}, + } + + transition = _default_batch_to_transition(batch) + + # Check observation is a dict with all observation.* keys + assert isinstance(transition[TransitionKey.OBSERVATION], dict) + assert "observation.image.top" in transition[TransitionKey.OBSERVATION] + assert "observation.image.left" in transition[TransitionKey.OBSERVATION] + assert "observation.state" in transition[TransitionKey.OBSERVATION] + + # Check values are preserved + assert torch.allclose( + transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + ) + assert torch.allclose( + transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + ) + assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + + # Check other fields + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 1.5 + assert transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"episode": 42} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_transition_to_batch_observation_flattening(): + """Test that _default_transition_to_batch correctly flattens observation dict.""" + observation_dict = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.state": [1, 2, 3, 4], + } + + transition = { + TransitionKey.OBSERVATION: observation_dict, + TransitionKey.ACTION: "action_data", + TransitionKey.REWARD: 1.5, + TransitionKey.DONE: True, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {"episode": 42}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + batch = _default_transition_to_batch(transition) + + # Check that observation.* keys are flattened back to batch + assert "observation.image.top" in batch + assert "observation.image.left" in batch + assert "observation.state" in batch + + # Check values are preserved + assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) + assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) + assert batch["observation.state"] == [1, 2, 3, 4] + + # Check other fields are mapped to next.* format + assert batch["action"] == "action_data" + assert batch["next.reward"] == 1.5 + assert batch["next.done"] + assert not batch["next.truncated"] + assert batch["info"] == {"episode": 42} + + +def test_no_observation_keys(): + """Test behavior when there are no observation.* keys.""" + batch = { + "action": "action_data", + "next.reward": 2.0, + "next.done": False, + "next.truncated": True, + "info": {"test": "no_obs"}, + } + + transition = _default_batch_to_transition(batch) + + # Observation should be None when no observation.* keys + assert transition[TransitionKey.OBSERVATION] is None + + # Check other fields + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 2.0 + assert not transition[TransitionKey.DONE] + assert transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"test": "no_obs"} + + # Round trip should work + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] == "action_data" + assert reconstructed_batch["next.reward"] == 2.0 + assert not reconstructed_batch["next.done"] + assert reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {"test": "no_obs"} + + +def test_minimal_batch(): + """Test with minimal batch containing only observation.* and action.""" + batch = {"observation.state": "minimal_state", "action": "minimal_action"} + + transition = _default_batch_to_transition(batch) + + # Check observation + assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.ACTION] == "minimal_action" + + # Check defaults + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch["action"] == "minimal_action" + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_empty_batch(): + """Test behavior with empty batch.""" + batch = {} + + transition = _default_batch_to_transition(batch) + + # All fields should have defaults + assert transition[TransitionKey.OBSERVATION] is None + assert transition[TransitionKey.ACTION] is None + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] is None + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_complex_nested_observation(): + """Test with complex nested observation data.""" + batch = { + "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + "observation.state": torch.randn(7), + "action": torch.randn(8), + "next.reward": 3.14, + "next.done": False, + "next.truncated": True, + "info": {"episode_length": 200, "success": True}, + } + + transition = _default_batch_to_transition(batch) + reconstructed_batch = _default_transition_to_batch(transition) + + # Check that all observation keys are preserved + original_obs_keys = {k for k in batch if k.startswith("observation.")} + reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} + + assert original_obs_keys == reconstructed_obs_keys + + # Check tensor values + assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + + # Check nested dict with tensors + assert torch.allclose( + batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + ) + assert torch.allclose( + batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + ) + + # Check action tensor + assert torch.allclose(batch["action"], reconstructed_batch["action"]) + + # Check other fields + assert batch["next.reward"] == reconstructed_batch["next.reward"] + assert batch["next.done"] == reconstructed_batch["next.done"] + assert batch["next.truncated"] == reconstructed_batch["next.truncated"] + assert batch["info"] == reconstructed_batch["info"] + + +def test_custom_converter(): + """Test that custom converters can still be used.""" + + def to_tr(batch): + # Custom converter that modifies the reward + tr = _default_batch_to_transition(batch) + # Double the reward + reward = tr.get(TransitionKey.REWARD, 0.0) + new_tr = tr.copy() + new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0 + return new_tr + + def to_batch(tr): + batch = _default_transition_to_batch(tr) + return batch + + processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) + + batch = { + "observation.state": torch.randn(1, 4), + "action": torch.randn(1, 2), + "next.reward": 1.0, + "next.done": False, + } + + result = processor(batch) + + # Check the reward was doubled by our custom converter + assert result["next.reward"] == 2.0 + assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py new file mode 100644 index 000000000..26aea56c7 --- /dev/null +++ b/tests/processor/test_normalize_processor.py @@ -0,0 +1,628 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +import numpy as np +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.processor.normalize_processor import ( + NormalizerProcessor, + UnnormalizerProcessor, + _convert_stats_to_tensors, +) +from lerobot.processor.pipeline import RobotProcessor, TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_numpy_conversion(): + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) + assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + + +def test_tensor_conversion(): + stats = { + "action": { + "mean": torch.tensor([0.0, 0.0]), + "std": torch.tensor([1.0, 1.0]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert tensor_stats["action"]["mean"].dtype == torch.float32 + assert tensor_stats["action"]["std"].dtype == torch.float32 + + +def test_scalar_conversion(): + stats = { + "reward": { + "mean": 0.5, + "std": 0.1, + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) + assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) + + +def test_list_conversion(): + stats = { + "observation.state": { + "min": [0.0, -1.0, -2.0], + "max": [1.0, 1.0, 2.0], + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + + +def test_unsupported_type(): + stats = { + "bad_key": { + "mean": "string_value", + } + } + with pytest.raises(TypeError, match="Unsupported type"): + _convert_stats_to_tensors(stats) + + +# Helper functions to create feature maps and norm maps +def _create_observation_features(): + return { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + + +def _create_observation_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + } + + +# Fixtures for observation normalisation tests using NormalizerProcessor +@pytest.fixture +def observation_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + +@pytest.fixture +def observation_normalizer(observation_stats): + """Return a NormalizerProcessor that only has observation stats (no action).""" + features = _create_observation_features() + norm_map = _create_observation_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + + +def test_mean_std_normalization(observation_normalizer): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check mean/std normalization + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(normalized_obs["observation.image"], expected_image) + + +def test_min_max_normalization(observation_normalizer): + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check min/max normalization to [-1, 1] + # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 + # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_selective_normalization(observation_stats): + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor( + features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + ) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Only image should be normalized + assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + # State should remain unchanged + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(observation_stats): + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + assert normalized_obs["observation.image"].device.type == "cuda" + + +def test_from_lerobot_dataset(): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0], "std": [1.0]}, + } + + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (1,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + + # Both observation and action statistics should be present in tensor stats + assert "observation.image" in normalizer._tensor_stats + assert "action" in normalizer._tensor_stats + + +def test_state_dict_save_load(observation_normalizer): + # Save state + state_dict = observation_normalizer.state_dict() + + # Create new normalizer and load state + features = _create_observation_features() + norm_map = _create_observation_norm_map() + new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + new_normalizer.load_state_dict(state_dict) + + # Test that it works the same + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + transition = create_transition(observation=observation) + + result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] + result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] + + assert torch.allclose(result1["observation.image"], result2["observation.image"]) + + +# Fixtures for ActionUnnormalizer tests +@pytest.fixture +def action_stats_mean_std(): + return { + "mean": np.array([0.0, 0.0, 0.0]), + "std": np.array([1.0, 2.0, 0.5]), + } + + +@pytest.fixture +def action_stats_min_max(): + return { + "min": np.array([-1.0, -2.0, 0.0]), + "max": np.array([1.0, 2.0, 1.0]), + } + + +def _create_action_features(): + return { + "action": PolicyFeature(FeatureType.ACTION, (3,)), + } + + +def _create_action_norm_map_mean_std(): + return { + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + +def _create_action_norm_map_min_max(): + return { + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + +def test_mean_std_unnormalization(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + normalized_action = torch.tensor([1.0, -0.5, 2.0]) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + # action * std + mean + expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_min_max_unnormalization(action_stats_min_max): + features = _create_action_features() + norm_map = _create_action_norm_map_min_max() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_min_max} + ) + + # Actions in [-1, 1] + normalized_action = torch.tensor([0.0, -1.0, 1.0]) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + # Map from [-1, 1] to [min, max] + # (action + 1) / 2 * (max - min) + min + expected = torch.tensor( + [ + (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0 + (-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0 + (1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0 + ] + ) + assert torch.allclose(unnormalized_action, expected) + + +def test_numpy_action_input(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + assert isinstance(unnormalized_action, torch.Tensor) + expected = torch.tensor([1.0, -1.0, 1.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_none_action(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + transition = create_transition() + result = unnormalizer(transition) + + # Should return transition unchanged + assert result == transition + + +def test_action_from_lerobot_dataset(): + mock_dataset = Mock() + mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} + features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + assert "mean" in unnormalizer._tensor_stats["action"] + + +# Fixtures for NormalizerProcessor tests +@pytest.fixture +def full_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 2.0]), + }, + } + + +def _create_full_features(): + return { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + +def _create_full_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + +@pytest.fixture +def normalizer_processor(full_stats): + features = _create_full_features() + norm_map = _create_full_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) + + +def test_combined_normalization(normalizer_processor): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + processed_transition = normalizer_processor(transition) + + # Check normalized observations + processed_obs = processed_transition[TransitionKey.OBSERVATION] + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(processed_obs["observation.image"], expected_image) + + # Check normalized action + processed_action = processed_transition[TransitionKey.ACTION] + expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]) + assert torch.allclose(processed_action, expected_action) + + # Check other fields remain unchanged + assert processed_transition[TransitionKey.REWARD] == 1.0 + assert not processed_transition[TransitionKey.DONE] + + +def test_processor_from_lerobot_dataset(full_stats): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = full_stats + + features = _create_full_features() + norm_map = _create_full_norm_map() + + processor = NormalizerProcessor.from_lerobot_dataset( + mock_dataset, features, norm_map, normalize_keys={"observation.image"} + ) + + assert processor.normalize_keys == {"observation.image"} + assert "observation.image" in processor._tensor_stats + assert "action" in processor._tensor_stats + + +def test_get_config(full_stats): + features = _create_full_features() + norm_map = _create_full_norm_map() + processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) + + config = processor.get_config() + expected_config = { + "normalize_keys": ["observation.image"], + "eps": 1e-6, + "features": { + "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, + "observation.state": {"type": "STATE", "shape": (2,)}, + "action": {"type": "ACTION", "shape": (2,)}, + }, + "norm_map": { + "VISUAL": "MEAN_STD", + "STATE": "MIN_MAX", + "ACTION": "MEAN_STD", + }, + } + assert config == expected_config + + +def test_integration_with_robot_processor(normalizer_processor): + """Test integration with RobotProcessor pipeline""" + robot_processor = RobotProcessor([normalizer_processor]) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + processed_transition = robot_processor(transition) + + # Verify the processing worked + assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict) + assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor) + + +# Edge case tests +def test_empty_observation(): + stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + + transition = create_transition() + result = normalizer(transition) + + assert result == transition + + +def test_empty_stats(): + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + observation = {"observation.image": torch.tensor([0.5])} + transition = create_transition(observation=observation) + + result = normalizer(transition) + # Should return observation unchanged since no stats are available + assert torch.allclose( + result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] + ) + + +def test_partial_stats(): + """If statistics are incomplete, the value should pass through unchanged.""" + stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + observation = {"observation.image": torch.tensor([0.7])} + transition = create_transition(observation=observation) + + processed = normalizer(transition)[TransitionKey.OBSERVATION] + assert torch.allclose(processed["observation.image"], observation["observation.image"]) + + +def test_missing_action_stats_no_error(): + mock_dataset = Mock() + mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + # The tensor stats should not contain the 'action' key + assert "action" not in processor._tensor_stats + + +def test_serialization_roundtrip(full_stats): + """Test that features and norm_map can be serialized and deserialized correctly.""" + features = _create_full_features() + norm_map = _create_full_norm_map() + original_processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) + + # Get config (serialization) + config = original_processor.get_config() + + # Create a new processor from the config (deserialization) + new_processor = NormalizerProcessor( + features=config["features"], + norm_map=config["norm_map"], + stats=full_stats, + normalize_keys=set(config["normalize_keys"]), + eps=config["eps"], + ) + + # Test that both processors work the same way + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + result1 = original_processor(transition) + result2 = new_processor(transition) + + # Compare results + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) + + # Verify features and norm_map are correctly reconstructed + assert new_processor.features.keys() == original_processor.features.keys() + for key in new_processor.features: + assert new_processor.features[key].type == original_processor.features[key].type + assert new_processor.features[key].shape == original_processor.features[key].shape + + assert new_processor.norm_map == original_processor.norm_map diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py new file mode 100644 index 000000000..e48b6bc08 --- /dev/null +++ b/tests/processor/test_observation_processor.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from lerobot.configs.types import FeatureType +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import VanillaObservationProcessor +from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_process_single_image(): + """Test processing a single image.""" + processor = VanillaObservationProcessor() + + # Create a mock image (H, W, C) format, uint8 + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that the image was processed correctly + assert "observation.image" in processed_obs + processed_img = processed_obs["observation.image"] + + # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width + assert processed_img.shape == (1, 3, 64, 64) + + # Check dtype and range + assert processed_img.dtype == torch.float32 + assert processed_img.min() >= 0.0 + assert processed_img.max() <= 1.0 + + +def test_process_image_dict(): + """Test processing multiple images in a dictionary.""" + processor = VanillaObservationProcessor() + + # Create mock images + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + + observation = {"pixels": {"camera1": image1, "camera2": image2}} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both images were processed + assert "observation.images.camera1" in processed_obs + assert "observation.images.camera2" in processed_obs + + # Check shapes + assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) + assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + + +def test_process_batched_image(): + """Test processing already batched images.""" + processor = VanillaObservationProcessor() + + # Create a batched image (B, H, W, C) + image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimension is preserved + assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + + +def test_invalid_image_format(): + """Test error handling for invalid image formats.""" + processor = VanillaObservationProcessor() + + # Test wrong channel order (channels first) + image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Expected channel-last images"): + processor(transition) + + +def test_invalid_image_dtype(): + """Test error handling for invalid image dtype.""" + processor = VanillaObservationProcessor() + + # Test wrong dtype + image = np.random.rand(64, 64, 3).astype(np.float32) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Expected torch.uint8 images"): + processor(transition) + + +def test_no_pixels_in_observation(): + """Test processor when no pixels are in observation.""" + processor = VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve other data unchanged + assert "other_data" in processed_obs + np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3])) + + +def test_none_observation(): + """Test processor with None observation.""" + processor = VanillaObservationProcessor() + + transition = create_transition() + result = processor(transition) + + assert result == transition + + +def test_serialization_methods(): + """Test serialization methods.""" + processor = VanillaObservationProcessor() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + + # Test load_state_dict (should not raise) + processor.load_state_dict(state) + + # Test reset (should not raise) + processor.reset() + + +def test_process_environment_state(): + """Test processing environment_state.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + observation = {"environment_state": env_state} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that environment_state was renamed and processed + assert "observation.environment_state" in processed_obs + assert "environment_state" not in processed_obs + + processed_state = processed_obs["observation.environment_state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) + + +def test_process_agent_pos(): + """Test processing agent_pos.""" + processor = VanillaObservationProcessor() + + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that agent_pos was renamed and processed + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + processed_state = processed_obs["observation.state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) + + +def test_process_batched_states(): + """Test processing already batched states.""" + processor = VanillaObservationProcessor() + + env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimensions are preserved + assert processed_obs["observation.environment_state"].shape == (2, 2) + assert processed_obs["observation.state"].shape == (2, 2) + + +def test_process_both_states(): + """Test processing both environment_state and agent_pos.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "keep_me" + + +def test_no_states_in_observation(): + """Test processor when no states are in observation.""" + processor = VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve data unchanged + np.testing.assert_array_equal(processed_obs, observation) + + +def test_complete_observation_processing(): + """Test processing a complete observation with both images and states.""" + processor = VanillaObservationProcessor() + + # Create mock data + image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = { + "pixels": image, + "environment_state": env_state, + "agent_pos": agent_pos, + "other_data": "preserve_me", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that image was processed + assert "observation.image" in processed_obs + assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + + # Check that states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "pixels" not in processed_obs + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "preserve_me" + + +def test_image_only_processing(): + """Test processing observation with only images.""" + processor = VanillaObservationProcessor() + + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.image" in processed_obs + assert len(processed_obs) == 1 + + +def test_state_only_processing(): + """Test processing observation with only states.""" + processor = VanillaObservationProcessor() + + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + +def test_empty_observation(): + """Test processing empty observation.""" + processor = VanillaObservationProcessor() + + observation = {} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs == {} + + +def test_equivalent_to_original_function(): + """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" + # Import the original function for comparison + from lerobot.envs.utils import preprocess_observation + + processor = VanillaObservationProcessor() + + # Create test data similar to what the original function expects + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_equivalent_with_image_dict(): + """Test equivalence with dictionary of images.""" + from lerobot.envs.utils import preprocess_observation + + processor = VanillaObservationProcessor() + + # Create test data with multiple cameras + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + + observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] + assert "pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] + assert "observation.pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + } + out = processor.feature_contract(features.copy()) + + assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] + assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] + assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] + assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): + proc = VanillaObservationProcessor() + features = { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + } + out = proc.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert_contract_is_typed(out) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py new file mode 100644 index 000000000..5665d5a7d --- /dev/null +++ b/tests/processor/test_pipeline.py @@ -0,0 +1,1919 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor +from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +@dataclass +class MockStep: + """Mock pipeline step for testing - demonstrates best practices. + + This example shows the proper separation: + - JSON-serializable attributes (name, counter) go in get_config() + - Only torch tensors go in state_dict() + + Note: The counter is part of the configuration, so it will be restored + when the step is recreated from config during loading. + """ + + name: str = "mock_step" + counter: int = 0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Add a counter to the complementary_data.""" + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = {} if comp_data is None else dict(comp_data) # Make a copy + + comp_data[f"{self.name}_counter"] = self.counter + self.counter += 1 + + # Create a new transition with updated complementary_data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self) -> dict[str, Any]: + # Return all JSON-serializable attributes that should be persisted + # These will be passed to __init__ when loading + return {"name": self.name, "counter": self.counter} + + def state_dict(self) -> dict[str, torch.Tensor]: + # Only return torch tensors (empty in this case since we have no tensor state) + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + # No tensor state to load + pass + + def reset(self) -> None: + self.counter = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@dataclass +class MockStepWithoutOptionalMethods: + """Mock step that only implements the required __call__ method.""" + + multiplier: float = 2.0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Multiply reward by multiplier.""" + reward = transition.get(TransitionKey.REWARD) + + if reward is not None: + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = reward * self.multiplier + return new_transition + + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@dataclass +class MockStepWithTensorState: + """Mock step demonstrating mixed JSON attributes and tensor state.""" + + name: str = "tensor_step" + learning_rate: float = 0.01 + window_size: int = 10 + + def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10): + self.name = name + self.learning_rate = learning_rate + self.window_size = window_size + # Tensor state + self.running_mean = torch.zeros(window_size) + self.running_count = torch.tensor(0) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Update running statistics.""" + reward = transition.get(TransitionKey.REWARD) + + if reward is not None: + # Update running mean + idx = self.running_count % self.window_size + self.running_mean[idx] = reward + self.running_count += 1 + + return transition + + def get_config(self) -> dict[str, Any]: + # Only JSON-serializable attributes + return { + "name": self.name, + "learning_rate": self.learning_rate, + "window_size": self.window_size, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + # Only tensor state + return { + "running_mean": self.running_mean, + "running_count": self.running_count, + } + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + self.running_mean = state["running_mean"] + self.running_count = state["running_count"] + + def reset(self) -> None: + self.running_mean.zero_() + self.running_count.zero_() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +def test_empty_pipeline(): + """Test pipeline with no steps.""" + pipeline = RobotProcessor() + + transition = create_transition() + result = pipeline(transition) + + assert result == transition + assert len(pipeline) == 0 + + +def test_single_step_pipeline(): + """Test pipeline with a single step.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + transition = create_transition() + result = pipeline(transition) + + assert len(pipeline) == 1 + assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 + + # Call again to test counter increment + result = pipeline(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1 + + +def test_multiple_steps_pipeline(): + """Test pipeline with multiple steps.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + transition = create_transition() + result = pipeline(transition) + + assert len(pipeline) == 2 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0 + + +def test_invalid_transition_format(): + """Test pipeline with invalid transition format.""" + pipeline = RobotProcessor([MockStep()]) + + # Test with wrong type (tuple instead of dict) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict + + # Test with wrong type (string) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline("not a dict") + + +def test_step_through(): + """Test step_through method with dict input.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + transition = create_transition() + + results = list(pipeline.step_through(transition)) + + assert len(results) == 3 # Original + 2 steps + assert results[0] == transition # Original + assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1 + assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2 + + # Ensure all results are dicts (same format as input) + for result in results: + assert isinstance(result, dict) + assert all(isinstance(k, TransitionKey) for k in result.keys()) + + +def test_step_through_with_dict(): + """Test step_through method with dict input.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + batch = { + "observation.image": None, + "action": None, + "next.reward": 0.0, + "next.done": False, + "next.truncated": False, + "info": {}, + } + + results = list(pipeline.step_through(batch)) + + assert len(results) == 3 # Original + 2 steps + + # Ensure all results are EnvTransition dicts (regardless of input format) + for result in results: + assert isinstance(result, dict) + # Check that keys are TransitionKey enums or at least valid transition keys + for key in result: + assert key in [ + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ] + + # Check that the processing worked - verify step counters in complementary_data + assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0 + + +def test_step_through_no_hooks(): + """Test that step_through doesn't execute hooks.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + hook_calls = [] + + def tracking_hook(idx: int, transition: EnvTransition): + hook_calls.append(f"hook_called_step_{idx}") + + # Register hooks + pipeline.register_before_step_hook(tracking_hook) + pipeline.register_after_step_hook(tracking_hook) + + # Use step_through + transition = create_transition() + results = list(pipeline.step_through(transition)) + + # Verify step was executed (counter should increment) + assert len(results) == 2 # Initial + 1 step + assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 + + # Verify hooks were NOT called + assert len(hook_calls) == 0 + + # Now use __call__ to verify hooks ARE called there + hook_calls.clear() + pipeline(transition) + + # Verify hooks were called (before and after for 1 step = 2 calls) + assert len(hook_calls) == 2 + assert hook_calls == ["hook_called_step_0", "hook_called_step_0"] + + +def test_indexing(): + """Test pipeline indexing.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + # Test integer indexing + assert pipeline[0] is step1 + assert pipeline[1] is step2 + + # Test slice indexing + sub_pipeline = pipeline[0:1] + assert isinstance(sub_pipeline, RobotProcessor) + assert len(sub_pipeline) == 1 + assert sub_pipeline[0] is step1 + + +def test_hooks(): + """Test before/after step hooks.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + before_calls = [] + after_calls = [] + + def before_hook(idx: int, transition: EnvTransition): + before_calls.append(idx) + + def after_hook(idx: int, transition: EnvTransition): + after_calls.append(idx) + + pipeline.register_before_step_hook(before_hook) + pipeline.register_after_step_hook(after_hook) + + transition = create_transition() + pipeline(transition) + + assert before_calls == [0] + assert after_calls == [0] + + +def test_unregister_hooks(): + """Test unregistering hooks from the pipeline.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + # Test before_step_hook + before_calls = [] + + def before_hook(idx: int, transition: EnvTransition): + before_calls.append(idx) + + pipeline.register_before_step_hook(before_hook) + + # Verify hook is registered + transition = create_transition() + pipeline(transition) + assert len(before_calls) == 1 + + # Unregister and verify it's no longer called + pipeline.unregister_before_step_hook(before_hook) + before_calls.clear() + pipeline(transition) + assert len(before_calls) == 0 + + # Test after_step_hook + after_calls = [] + + def after_hook(idx: int, transition: EnvTransition): + after_calls.append(idx) + + pipeline.register_after_step_hook(after_hook) + pipeline(transition) + assert len(after_calls) == 1 + + pipeline.unregister_after_step_hook(after_hook) + after_calls.clear() + pipeline(transition) + assert len(after_calls) == 0 + + +def test_unregister_nonexistent_hook(): + """Test error handling when unregistering hooks that don't exist.""" + pipeline = RobotProcessor([MockStep()]) + + def some_hook(idx: int, transition: EnvTransition): + pass + + def reset_hook(): + pass + + # Test unregistering hooks that were never registered + with pytest.raises(ValueError, match="not found in before_step_hooks"): + pipeline.unregister_before_step_hook(some_hook) + + with pytest.raises(ValueError, match="not found in after_step_hooks"): + pipeline.unregister_after_step_hook(some_hook) + + +def test_multiple_hooks_and_selective_unregister(): + """Test registering multiple hooks and selectively unregistering them.""" + pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) + + calls_1 = [] + calls_2 = [] + calls_3 = [] + + def hook1(idx: int, transition: EnvTransition): + calls_1.append(f"hook1_step{idx}") + + def hook2(idx: int, transition: EnvTransition): + calls_2.append(f"hook2_step{idx}") + + def hook3(idx: int, transition: EnvTransition): + calls_3.append(f"hook3_step{idx}") + + # Register multiple hooks + pipeline.register_before_step_hook(hook1) + pipeline.register_before_step_hook(hook2) + pipeline.register_before_step_hook(hook3) + + # Run pipeline - all hooks should be called for both steps + transition = create_transition() + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == ["hook2_step0", "hook2_step1"] + assert calls_3 == ["hook3_step0", "hook3_step1"] + + # Clear calls + calls_1.clear() + calls_2.clear() + calls_3.clear() + + # Unregister middle hook + pipeline.unregister_before_step_hook(hook2) + + # Run again - only hook1 and hook3 should be called + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == [] # hook2 was unregistered + assert calls_3 == ["hook3_step0", "hook3_step1"] + + +def test_hook_execution_order_documentation(): + """Test and document that hooks are executed sequentially in registration order.""" + pipeline = RobotProcessor([MockStep("step")]) + + execution_order = [] + + def hook_a(idx: int, transition: EnvTransition): + execution_order.append("A") + + def hook_b(idx: int, transition: EnvTransition): + execution_order.append("B") + + def hook_c(idx: int, transition: EnvTransition): + execution_order.append("C") + + # Register in specific order: A, B, C + pipeline.register_before_step_hook(hook_a) + pipeline.register_before_step_hook(hook_b) + pipeline.register_before_step_hook(hook_c) + + transition = create_transition() + pipeline(transition) + + # Verify execution order matches registration order + assert execution_order == ["A", "B", "C"] + + # Test that after unregistering B and re-registering it, it goes to the end + pipeline.unregister_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C"] # B is gone + + # Re-register B - it should now be at the end + pipeline.register_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C", "B"] # B is now last + + +def test_save_and_load_pretrained(): + """Test saving and loading pipeline. + + This test demonstrates that JSON-serializable attributes (like counter) + are saved in the config and restored when the step is recreated. + """ + step1 = MockStep("step1") + step2 = MockStep("step2") + + # Increment counters to have some state + step1.counter = 5 + step2.counter = 10 + + pipeline = RobotProcessor([step1, step2], name="TestPipeline") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check files were created + config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline" + assert config_path.exists() + + # Check config content + with open(config_path) as f: + config = json.load(f) + + assert config["name"] == "TestPipeline" + assert len(config["steps"]) == 2 + + # Verify counters are saved in config, not in separate state files + assert config["steps"][0]["config"]["counter"] == 5 + assert config["steps"][1]["config"]["counter"] == 10 + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "TestPipeline" + assert len(loaded_pipeline) == 2 + + # Check that counter was restored from config + assert loaded_pipeline.steps[0].counter == 5 + assert loaded_pipeline.steps[1].counter == 10 + + +def test_step_without_optional_methods(): + """Test pipeline with steps that don't implement optional methods.""" + step = MockStepWithoutOptionalMethods(multiplier=3.0) + pipeline = RobotProcessor([step]) + + transition = create_transition(reward=2.0) + result = pipeline(transition) + + assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0 + + # Reset should work even if step doesn't implement reset + pipeline.reset() + + # Save/load should work even without optional methods + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + assert len(loaded_pipeline) == 1 + + +def test_mixed_json_and_tensor_state(): + """Test step with both JSON attributes and tensor state.""" + step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) + pipeline = RobotProcessor([step]) + + # Process some transitions with rewards + for i in range(10): + transition = create_transition(reward=float(i)) + pipeline(transition) + + # Check state + assert step.running_count.item() == 10 + assert step.learning_rate == 0.05 + + # Save and load + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check that both config and state files were created + config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" + state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors" + assert config_path.exists() + assert state_path.exists() + + # Load and verify + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_step = loaded_pipeline.steps[0] + + # Check JSON attributes were restored + assert loaded_step.name == "stats" + assert loaded_step.learning_rate == 0.05 + assert loaded_step.window_size == 5 + + # Check tensor state was restored + assert loaded_step.running_count.item() == 10 + assert torch.allclose(loaded_step.running_mean, step.running_mean) + + +class MockModuleStep(nn.Module): + """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" + + def __init__(self, input_dim: int = 10, hidden_dim: int = 5): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.linear = nn.Linear(input_dim, hidden_dim) + self.running_mean = nn.Parameter(torch.zeros(hidden_dim), requires_grad=False) + self.counter = 0 # Non-tensor state + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process transition and update running mean.""" + obs = transition.get(TransitionKey.OBSERVATION) + + if obs is not None and isinstance(obs, torch.Tensor): + # Process observation through linear layer + processed = self.forward(obs[:, : self.input_dim]) + + # Update running mean in-place (don't reassign the parameter) + with torch.no_grad(): + self.running_mean.mul_(0.9).add_(processed.mean(dim=0), alpha=0.1) + + self.counter += 1 + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "counter": self.counter, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Override to return all module parameters and buffers.""" + # Get the module's state dict (includes all parameters and buffers) + return super().state_dict() + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Override to load all module parameters and buffers.""" + # Use the module's load_state_dict + super().load_state_dict(state) + + def reset(self) -> None: + self.running_mean.zero_() + self.counter = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockNonModuleStepWithState: + """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. + + This tests the state_dict/load_state_dict path for regular classes. + """ + + def __init__(self, name: str = "non_module_step", feature_dim: int = 10): + self.name = name + self.feature_dim = feature_dim + + # Initialize tensor state - these are regular tensors, not nn.Parameters + self.weights = torch.randn(feature_dim, feature_dim) + self.bias = torch.zeros(feature_dim) + self.running_stats = torch.zeros(feature_dim) + self.step_count = torch.tensor(0) + + # Non-tensor state + self.config_value = 42 + self.history = [] + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process transition using tensor operations.""" + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim: + # Perform some tensor operations + flat_obs = obs.flatten()[: self.feature_dim] + + # Simple linear transformation (ensure dimensions match for matmul) + output = torch.matmul(self.weights.T, flat_obs) + self.bias + + # Update running stats + self.running_stats = 0.9 * self.running_stats + 0.1 * output + self.step_count += 1 + + # Add to complementary data + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_mean_output"] = output.mean().item() + comp_data[f"{self.name}_steps"] = self.step_count.item() + + # Return updated transition + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "name": self.name, + "feature_dim": self.feature_dim, + "config_value": self.config_value, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return only tensor state.""" + return { + "weights": self.weights, + "bias": self.bias, + "running_stats": self.running_stats, + "step_count": self.step_count, + } + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load tensor state.""" + self.weights = state["weights"] + self.bias = state["bias"] + self.running_stats = state["running_stats"] + self.step_count = state["step_count"] + + def reset(self) -> None: + """Reset statistics but keep learned parameters.""" + self.running_stats.zero_() + self.step_count.zero_() + self.history.clear() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +# Tests for overrides functionality +@dataclass +class MockStepWithNonSerializableParam: + """Mock step that requires a non-serializable parameter.""" + + def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): + self.name = name + # Add type validation for multiplier + if isinstance(multiplier, str): + raise ValueError(f"multiplier must be a number, got string '{multiplier}'") + if not isinstance(multiplier, (int, float)): + raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") + self.multiplier = float(multiplier) + self.env = env # Non-serializable parameter (like gym.Env) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + # Use the env parameter if provided + if self.env is not None: + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_env_info"] = str(self.env) + + # Apply multiplier to reward + new_transition = transition.copy() + if reward is not None: + new_transition[TransitionKey.REWARD] = reward * self.multiplier + + if comp_data: + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + + return new_transition + + def get_config(self) -> dict[str, Any]: + # Note: env is intentionally NOT included here as it's not serializable + return { + "name": self.name, + "multiplier": self.multiplier, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@ProcessorStepRegistry.register("registered_mock_step") +@dataclass +class RegisteredMockStep: + """Mock step registered in the registry.""" + + value: int = 42 + device: str = "cpu" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["registered_step_value"] = self.value + comp_data["registered_step_device"] = self.device + + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "value": self.value, + "device": self.device, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockEnvironment: + """Mock environment for testing non-serializable parameters.""" + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return f"MockEnvironment({self.name})" + + +def test_from_pretrained_with_overrides(): + """Test loading processor with parameter overrides.""" + # Create a processor with steps that need overrides + env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) + registered_step = RegisteredMockStep(value=100, device="cpu") + + pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save the pipeline + pipeline.save_pretrained(tmp_dir) + + # Create a mock environment for override + mock_env = MockEnvironment("test_env") + + # Load with overrides + overrides = { + "MockStepWithNonSerializableParam": { + "env": mock_env, + "multiplier": 3.0, # Override the multiplier too + }, + "registered_mock_step": {"device": "cuda", "value": 200}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Verify the pipeline was loaded correctly + assert len(loaded_pipeline) == 2 + assert loaded_pipeline.name == "TestOverrides" + + # Test the loaded steps + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # Check that overrides were applied + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert "env_step_env_info" in comp_data + assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" + assert comp_data["registered_step_value"] == 200 + assert comp_data["registered_step_device"] == "cuda" + + # Check that multiplier override was applied + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier) + + +def test_from_pretrained_with_partial_overrides(): + """Test loading processor with overrides for only some steps.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one step + overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}} + + # The current implementation applies overrides to ALL steps with the same class name + # Both steps will get the override + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # The reward should be affected by both steps, both getting the override + # First step: 1.0 * 5.0 = 5.0 (overridden) + # Second step: 5.0 * 5.0 = 25.0 (also overridden) + assert result[TransitionKey.REWARD] == 25.0 + + +def test_from_pretrained_invalid_override_key(): + """Test that invalid override keys raise KeyError.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override a non-existent step + overrides = {"NonExistentStep": {"param": "value"}} + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_multiple_invalid_override_keys(): + """Test that multiple invalid override keys are reported.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override multiple non-existent steps + overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "NonExistentStep1" in error_msg + assert "NonExistentStep2" in error_msg + assert "Available step keys" in error_msg + + +def test_from_pretrained_registered_step_override(): + """Test overriding registered steps using registry names.""" + registered_step = RegisteredMockStep(value=50, device="cpu") + pipeline = RobotProcessor([registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override using registry name + overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test that overrides were applied + transition = create_transition() + result = loaded_pipeline(transition) + + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["registered_step_value"] == 999 + assert comp_data["registered_step_device"] == "cuda" + + +def test_from_pretrained_mixed_registered_and_unregistered(): + """Test overriding both registered and unregistered steps.""" + unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) + registered_step = RegisteredMockStep(value=10, device="cpu") + + pipeline = RobotProcessor([unregistered_step, registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + mock_env = MockEnvironment("mixed_test") + + overrides = { + "MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0}, + "registered_mock_step": {"value": 777}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test both steps + transition = create_transition(reward=2.0) + result = loaded_pipeline(transition) + + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" + assert comp_data["registered_step_value"] == 777 + assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0 + + +def test_from_pretrained_no_overrides(): + """Test that from_pretrained works without overrides (backward compatibility).""" + step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert len(loaded_pipeline) == 1 + + # Test that the step works (env will be None) + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 + + +def test_from_pretrained_empty_overrides(): + """Test that from_pretrained works with empty overrides dict.""" + step = MockStepWithNonSerializableParam(multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with empty overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + + assert len(loaded_pipeline) == 1 + + # Test that the step works normally + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 2.0 + + +def test_from_pretrained_override_instantiation_error(): + """Test that instantiation errors with overrides are properly reported.""" + step = MockStepWithNonSerializableParam(multiplier=1.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with invalid parameter type + overrides = { + "MockStepWithNonSerializableParam": { + "multiplier": "invalid_type" # Should be float, not string + } + } + + with pytest.raises(ValueError, match="Failed to instantiate processor step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_with_state_and_overrides(): + """Test that overrides work correctly with steps that have tensor state.""" + step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) + pipeline = RobotProcessor([step]) + + # Process some data to create state + for i in range(10): + transition = create_transition(reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with overrides + overrides = { + "MockStepWithTensorState": { + "learning_rate": 0.05, # Override learning rate + "window_size": 3, # Override window size + } + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_step = loaded_pipeline.steps[0] + + # Check that config overrides were applied + assert loaded_step.learning_rate == 0.05 + assert loaded_step.window_size == 3 + + # Check that tensor state was preserved + assert loaded_step.running_count.item() == 10 + + # The running_mean should still have the original window_size (5) from saved state + # but the new step will use window_size=3 for future operations + assert loaded_step.running_mean.shape[0] == 5 # From saved state + + +def test_from_pretrained_override_error_messages(): + """Test that error messages for override failures are helpful.""" + step1 = MockStepWithNonSerializableParam(name="step1") + step2 = RegisteredMockStep() + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Test with invalid override key + overrides = {"WrongStepName": {"param": "value"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "WrongStepName" in error_msg + assert "Available step keys" in error_msg + assert "MockStepWithNonSerializableParam" in error_msg + assert "registered_mock_step" in error_msg + + +def test_repr_empty_processor(): + """Test __repr__ with empty processor.""" + pipeline = RobotProcessor() + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" + assert repr_str == expected + + +def test_repr_single_step(): + """Test __repr__ with single step.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_multiple_steps_under_limit(): + """Test __repr__ with 2-3 steps (all shown).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == expected + + # Test with 3 steps (boundary case) + step3 = MockStepWithTensorState() + pipeline = RobotProcessor([step1, step2, step3]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" + assert repr_str == expected + + +def test_repr_many_steps_truncated(): + """Test __repr__ with more than 3 steps (truncated with ellipsis).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockModuleStep() + step5 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4, step5]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +def test_repr_with_custom_name(): + """Test __repr__ with custom processor name.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step], name="CustomProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_with_seed(): + """Test __repr__ with seed parameter.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_with_custom_name_and_seed(): + """Test __repr__ with both custom name and seed.""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2], name="MyProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == expected + + +def test_repr_without_seed(): + """Test __repr__ when seed is explicitly None (should not show seed).""" + step = MockStep("test_step") + pipeline = RobotProcessor([step], name="TestProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_various_step_types(): + """Test __repr__ with different types of steps to verify class name extraction.""" + step1 = MockStep() + step2 = MockStepWithTensorState() + step3 = MockModuleStep() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +def test_repr_edge_case_long_names(): + """Test __repr__ handles steps with long class names properly.""" + step1 = MockStepWithNonSerializableParam() + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +# Tests for config filename features and multiple processors +def test_save_with_custom_config_filename(): + """Test saving processor with custom config filename.""" + step = MockStep("test") + pipeline = RobotProcessor([step], name="TestProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save with custom filename + pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json") + + # Check file exists + config_path = Path(tmp_dir) / "my_custom_config.json" + assert config_path.exists() + + # Check content + with open(config_path) as f: + config = json.load(f) + assert config["name"] == "TestProcessor" + + # Load with specific filename + loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") + assert loaded.name == "TestProcessor" + + +def test_multiple_processors_same_directory(): + """Test saving multiple processors to the same directory with different config files.""" + # Create different processors + preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") + + postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save both to same directory + preprocessor.save_pretrained(tmp_dir) + postprocessor.save_pretrained(tmp_dir) + + # Check both config files exist + assert (Path(tmp_dir) / "preprocessor.json").exists() + assert (Path(tmp_dir) / "postprocessor.json").exists() + + # Load them back + loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + + assert loaded_pre.name == "preprocessor" + assert loaded_post.name == "postprocessor" + assert len(loaded_pre) == 2 + assert len(loaded_post) == 1 + + +def test_auto_detect_single_config(): + """Test automatic config detection when there's only one JSON file.""" + step = MockStepWithTensorState() + pipeline = RobotProcessor([step], name="SingleConfig") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without specifying config_filename + loaded = RobotProcessor.from_pretrained(tmp_dir) + assert loaded.name == "SingleConfig" + + +def test_error_multiple_configs_no_filename(): + """Test error when multiple configs exist and no filename specified.""" + proc1 = RobotProcessor([MockStep()], name="processor1") + proc2 = RobotProcessor([MockStep()], name="processor2") + + with tempfile.TemporaryDirectory() as tmp_dir: + proc1.save_pretrained(tmp_dir) + proc2.save_pretrained(tmp_dir) + + # Should raise error + with pytest.raises(ValueError, match="Multiple .json files found"): + RobotProcessor.from_pretrained(tmp_dir) + + +def test_state_file_naming_with_indices(): + """Test that state files include pipeline name and step indices to avoid conflicts.""" + # Create multiple steps of same type with state + step1 = MockStepWithTensorState(name="norm1", window_size=5) + step2 = MockStepWithTensorState(name="norm2", window_size=10) + step3 = MockModuleStep(input_dim=5) + + pipeline = RobotProcessor([step1, step2, step3]) + + # Process some data to create state + for i in range(5): + transition = create_transition(observation=torch.randn(2, 5), reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check state files have indices + state_files = sorted(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 3 + + # Files should be named with pipeline name prefix and indices + expected_names = [ + "robotprocessor_step_0.safetensors", + "robotprocessor_step_1.safetensors", + "robotprocessor_step_2.safetensors", + ] + actual_names = [f.name for f in state_files] + assert actual_names == expected_names + + +def test_state_file_naming_with_registry(): + """Test state file naming for registered steps includes pipeline name, index and registry name.""" + + # Register a test step + @ProcessorStepRegistry.register("test_stateful_step") + @dataclass + class TestStatefulStep: + value: int = 0 + + def __init__(self, value: int = 0): + self.value = value + self.state_tensor = torch.randn(3, 3) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self): + return {"value": self.value} + + def state_dict(self): + return {"state_tensor": self.state_tensor} + + def load_state_dict(self, state): + self.state_tensor = state["state_tensor"] + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + # Create pipeline with registered steps + step1 = TestStatefulStep(1) + step2 = TestStatefulStep(2) + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check state files + state_files = sorted(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 2 + + # Should include pipeline name, index and registry name + expected_names = [ + "robotprocessor_step_0_test_stateful_step.safetensors", + "robotprocessor_step_1_test_stateful_step.safetensors", + ] + actual_names = [f.name for f in state_files] + assert actual_names == expected_names + + finally: + # Cleanup registry + ProcessorStepRegistry.unregister("test_stateful_step") + + +# More comprehensive override tests +def test_override_with_nested_config(): + """Test overrides with nested configuration dictionaries.""" + + @ProcessorStepRegistry.register("complex_config_step") + @dataclass + class ComplexConfigStep: + name: str = "complex" + simple_param: int = 42 + nested_config: dict = None + + def __post_init__(self): + if self.nested_config is None: + self.nested_config = {"level1": {"level2": "default"}} + + def __call__(self, transition: EnvTransition) -> EnvTransition: + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = dict(comp_data) + comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing") + + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self): + return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = ComplexConfigStep() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with nested override + loaded = RobotProcessor.from_pretrained( + tmp_dir, + overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}, + ) + + # Test that override worked + transition = create_transition() + result = loaded(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden" + finally: + ProcessorStepRegistry.unregister("complex_config_step") + + +def test_override_preserves_defaults(): + """Test that overrides only affect specified parameters.""" + step = MockStepWithNonSerializableParam(name="test", multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one parameter + loaded = RobotProcessor.from_pretrained( + tmp_dir, + overrides={ + "MockStepWithNonSerializableParam": { + "multiplier": 5.0 # Only override multiplier + } + }, + ) + + # Check that name was preserved from saved config + loaded_step = loaded.steps[0] + assert loaded_step.name == "test" # Original value + assert loaded_step.multiplier == 5.0 # Overridden value + + +def test_override_type_validation(): + """Test that type errors in overrides are caught properly.""" + step = MockStepWithTensorState(learning_rate=0.01) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with wrong type + overrides = { + "MockStepWithTensorState": { + "window_size": "not_an_int" # Should be int + } + } + + with pytest.raises(ValueError, match="Failed to instantiate"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_override_with_callables(): + """Test overriding with callable objects.""" + + @ProcessorStepRegistry.register("callable_step") + @dataclass + class CallableStep: + name: str = "callable_step" + transform_fn: Any = None + + def __call__(self, transition: EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION) + if obs is not None and self.transform_fn is not None: + processed_obs = {} + for k, v in obs.items(): + processed_obs[k] = self.transform_fn(v) + + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition + return transition + + def get_config(self): + return {"name": self.name} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = CallableStep() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Define a transform function + def double_values(x): + if isinstance(x, (int, float)): + return x * 2 + elif isinstance(x, torch.Tensor): + return x * 2 + return x + + # Load with callable override + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"callable_step": {"transform_fn": double_values}} + ) + + # Test it works + transition = create_transition(observation={"value": torch.tensor(5.0)}) + result = loaded(transition) + assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0 + finally: + ProcessorStepRegistry.unregister("callable_step") + + +def test_override_multiple_same_class_warning(): + """Test behavior when multiple steps of same class exist.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override affects all instances of the class + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}} + ) + + # Both steps get the same override + assert loaded.steps[0].multiplier == 10.0 + assert loaded.steps[1].multiplier == 10.0 + + # But original names are preserved + assert loaded.steps[0].name == "step1" + assert loaded.steps[1].name == "step2" + + +def test_config_filename_special_characters(): + """Test config filenames with special characters are sanitized.""" + # Processor name with special characters + pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check that filename was sanitized + json_files = list(Path(tmp_dir).glob("*.json")) + assert len(json_files) == 1 + + # Should have replaced special chars with underscores + expected_name = "my_processor_with_special_chars.json" + assert json_files[0].name == expected_name + + +def test_state_file_naming_with_multiple_processors(): + """Test that state files are properly prefixed with pipeline names to avoid conflicts.""" + # Create two processors with state + step1 = MockStepWithTensorState(name="norm", window_size=5) + preprocessor = RobotProcessor([step1], name="PreProcessor") + + step2 = MockStepWithTensorState(name="norm", window_size=10) + postprocessor = RobotProcessor([step2], name="PostProcessor") + + # Process some data to create state + for i in range(3): + transition = create_transition(reward=float(i)) + preprocessor(transition) + postprocessor(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save both processors to the same directory + preprocessor.save_pretrained(tmp_dir) + postprocessor.save_pretrained(tmp_dir) + + # Check that all files exist and are distinct + assert (Path(tmp_dir) / "preprocessor.json").exists() + assert (Path(tmp_dir) / "postprocessor.json").exists() + assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists() + assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists() + + # Load both back and verify they work correctly + loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + + assert loaded_pre.name == "PreProcessor" + assert loaded_post.name == "PostProcessor" + assert loaded_pre.steps[0].window_size == 5 + assert loaded_post.steps[0].window_size == 10 + + +def test_override_with_device_strings(): + """Test overriding device parameters with string values.""" + + @ProcessorStepRegistry.register("device_aware_step") + @dataclass + class DeviceAwareStep: + device: str = "cpu" + + def __init__(self, device: str = "cpu"): + self.device = device + self.buffer = torch.zeros(10, device=device) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self): + return {"device": str(self.device)} + + def state_dict(self): + return {"buffer": self.buffer} + + def load_state_dict(self, state): + self.buffer = state["buffer"] + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = DeviceAwareStep(device="cpu") + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override device + if torch.cuda.is_available(): + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}} + ) + + loaded_step = loaded.steps[0] + assert loaded_step.device == "cuda:0" + # Note: buffer will still be on CPU from saved state + # until .to() is called on the processor + + finally: + ProcessorStepRegistry.unregister("device_aware_step") + + +def test_from_pretrained_nonexistent_path(): + """Test error handling when loading from non-existent sources.""" + from huggingface_hub.errors import HfHubHTTPError, HFValidationError + + # Test with an invalid repo ID (too many slashes) - caught by HF validation + with pytest.raises(HFValidationError): + RobotProcessor.from_pretrained("/path/that/does/not/exist") + + # Test with a non-existent but valid Hub repo format + with pytest.raises((FileNotFoundError, HfHubHTTPError)): + RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo") + + # Test with a local directory that exists but has no config files + with tempfile.TemporaryDirectory() as tmp_dir: + with pytest.raises(FileNotFoundError, match="No .json configuration files found"): + RobotProcessor.from_pretrained(tmp_dir) + + +def test_save_load_with_custom_converter_functions(): + """Test that custom to_transition and to_output functions are NOT saved.""" + + def custom_to_transition(batch): + # Custom conversion logic + return { + TransitionKey.OBSERVATION: batch.get("obs"), + TransitionKey.ACTION: batch.get("act"), + TransitionKey.REWARD: batch.get("rew", 0.0), + TransitionKey.DONE: batch.get("done", False), + TransitionKey.TRUNCATED: batch.get("truncated", False), + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + def custom_to_output(transition): + # Custom output format + return { + "obs": transition.get(TransitionKey.OBSERVATION), + "act": transition.get(TransitionKey.ACTION), + "rew": transition.get(TransitionKey.REWARD), + "done": transition.get(TransitionKey.DONE), + "truncated": transition.get(TransitionKey.TRUNCATED), + } + + # Create processor with custom converters + pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load - should use default converters + loaded = RobotProcessor.from_pretrained(tmp_dir) + + # Verify it uses default converters by checking with standard batch format + batch = { + "observation.image": torch.randn(1, 3, 32, 32), + "action": torch.randn(1, 7), + "next.reward": torch.tensor([1.0]), + "next.done": torch.tensor([False]), + "next.truncated": torch.tensor([False]), + "info": {}, + } + + # Should work with standard format (wouldn't work with custom converter) + result = loaded(batch) + assert "observation.image" in result # Standard format preserved + + +class NonCompliantStep: + """Intentionally non-compliant: missing feature_contract.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + +def test_construction_rejects_step_without_feature_contract(): + with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): + RobotProcessor([NonCompliantStep()]) + + +class NonCallableStep: + """Intentionally non-compliant: missing __call__.""" + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +def test_construction_rejects_step_without_call(): + with pytest.raises(TypeError, match=r"must define __call__"): + RobotProcessor([NonCallableStep()]) + + +@dataclass +class FeatureContractAddStep: + """Adds a PolicyFeature""" + + key: str = "a" + value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.value + return features + + +@dataclass +class FeatureContractMutateStep: + """Mutates a PolicyFeature""" + + key: str = "a" + fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.fn(features.get(self.key)) + return features + + +@dataclass +class FeatureContractBadReturnStep: + """Returns a non-dict""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return ["not-a-dict"] + + +@dataclass +class FeatureContractRemoveStep: + """Removes a PolicyFeature""" + + key: str + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features.pop(self.key, None) + return features + + +def test_feature_contract_orders_and_merges(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), + FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), + ] + ) + out = p.feature_contract({}) + + assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) + assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) + assert_contract_is_typed(out) + + +def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): + initial = { + "seed": policy_feature_factory(FeatureType.STATE, (7,)), + "nested": policy_feature_factory(FeatureType.ENV, (0,)), + } + p = RobotProcessor( + [ + FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), + FeatureContractMutateStep( + "nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,)) + ), + ] + ) + out = p.feature_contract(initial_features=initial) + + assert out["seed"].shape == (8,) + assert out["nested"].shape == (5,) + # Initial dict must be preserved + assert initial["seed"].shape == (7,) + assert initial["nested"].shape == (0,) + + assert_contract_is_typed(out) + + +def test_feature_contract_type_error_on_bad_step(): + p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) + with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): + _ = p.feature_contract({}) + + +def test_feature_contract_execution_order_tracking(): + class Track: + def __init__(self, label): + self.label = label + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + code = {"A": 1, "B": 2, "C": 3}[self.label] + pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) + features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) + return features + + out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) + assert out["order"].shape == (1, 2, 3) + + +def test_feature_contract_remove_key(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractRemoveStep("a"), + ] + ) + out = p.feature_contract({}) + assert "a" not in out + + +def test_feature_contract_remove_from_initial(policy_feature_factory): + initial = { + "keep": policy_feature_factory(FeatureType.STATE, (1,)), + "drop": policy_feature_factory(FeatureType.STATE, (1,)), + } + p = RobotProcessor([FeatureContractRemoveStep("drop")]) + out = p.feature_contract(initial_features=initial) + assert "drop" not in out and out["keep"] == initial["keep"] diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py new file mode 100644 index 000000000..229d57f9f --- /dev/null +++ b/tests/processor/test_rename_processor.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from lerobot.configs.types import FeatureType +from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_basic_renaming(): + """Test basic key renaming functionality.""" + rename_map = { + "old_key1": "new_key1", + "old_key2": "new_key2", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "old_key1": torch.tensor([1.0, 2.0]), + "old_key2": np.array([3.0, 4.0]), + "unchanged_key": "keep_me", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renamed keys + assert "new_key1" in processed_obs + assert "new_key2" in processed_obs + assert "old_key1" not in processed_obs + assert "old_key2" not in processed_obs + + # Check values are preserved + torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0])) + np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0])) + + # Check unchanged key is preserved + assert processed_obs["unchanged_key"] == "keep_me" + + +def test_empty_rename_map(): + """Test processor with empty rename map (should pass through unchanged).""" + processor = RenameProcessor(rename_map={}) + + observation = { + "key1": torch.tensor([1.0]), + "key2": "value2", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # All keys should be unchanged + assert processed_obs.keys() == observation.keys() + torch.testing.assert_close(processed_obs["key1"], observation["key1"]) + assert processed_obs["key2"] == observation["key2"] + + +def test_none_observation(): + """Test processor with None observation.""" + processor = RenameProcessor(rename_map={"old": "new"}) + + transition = create_transition() + result = processor(transition) + + # Should return transition unchanged + assert result == transition + + +def test_overlapping_rename(): + """Test renaming when new names might conflict.""" + rename_map = { + "a": "b", + "b": "c", # This creates a potential conflict + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "a": 1, + "b": 2, + "x": 3, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that renaming happens correctly + assert "a" not in processed_obs + assert processed_obs["b"] == 1 # 'a' renamed to 'b' + assert processed_obs["c"] == 2 # original 'b' renamed to 'c' + assert processed_obs["x"] == 3 + + +def test_partial_rename(): + """Test renaming only some keys.""" + rename_map = { + "observation.state": "observation.proprio_state", + "pixels": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.state": torch.randn(10), + "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), + "reward": 1.0, + "info": {"episode": 1}, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renamed keys + assert "observation.proprio_state" in processed_obs + assert "observation.image" in processed_obs + assert "observation.state" not in processed_obs + assert "pixels" not in processed_obs + + # Check unchanged keys + assert processed_obs["reward"] == 1.0 + assert processed_obs["info"] == {"episode": 1} + + +def test_get_config(): + """Test configuration serialization.""" + rename_map = { + "old1": "new1", + "old2": "new2", + } + processor = RenameProcessor(rename_map=rename_map) + + config = processor.get_config() + assert config == {"rename_map": rename_map} + + +def test_state_dict(): + """Test state dict (should be empty for RenameProcessor).""" + processor = RenameProcessor(rename_map={"old": "new"}) + + state = processor.state_dict() + assert state == {} + + # Load state dict should work even with empty dict + processor.load_state_dict({}) + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor pipeline.""" + rename_map = { + "agent_pos": "observation.state", + "pixels": "observation.image", + } + rename_processor = RenameProcessor(rename_map=rename_map) + + pipeline = RobotProcessor([rename_processor]) + + observation = { + "agent_pos": np.array([1.0, 2.0, 3.0]), + "pixels": np.zeros((32, 32, 3), dtype=np.uint8), + "other_data": "preserve_me", + } + transition = create_transition( + observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={} + ) + + result = pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renaming worked through pipeline + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert "agent_pos" not in processed_obs + assert "pixels" not in processed_obs + assert processed_obs["other_data"] == "preserve_me" + + # Check other transition elements unchanged + assert result[TransitionKey.REWARD] == 0.5 + assert result[TransitionKey.DONE] is False + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with RobotProcessor.""" + rename_map = { + "old_state": "observation.state", + "old_image": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + pipeline = RobotProcessor([processor], name="TestRenameProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check files were created + config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor" + assert config_path.exists() + + # No state files should be created for RenameProcessor + state_files = list(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 0 + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "TestRenameProcessor" + assert len(loaded_pipeline) == 1 + + # Check that loaded processor works correctly + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == rename_map + + # Test functionality after loading + observation = {"old_state": [1, 2, 3], "old_image": "image_data"} + transition = create_transition(observation=observation) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert processed_obs["observation.state"] == [1, 2, 3] + assert processed_obs["observation.image"] == "image_data" + + +def test_registry_functionality(): + """Test that RenameProcessor is properly registered.""" + # Check that it's registered + assert "rename_processor" in ProcessorStepRegistry.list() + + # Get from registry + retrieved_class = ProcessorStepRegistry.get("rename_processor") + assert retrieved_class is RenameProcessor + + # Create instance from registry + instance = retrieved_class(rename_map={"old": "new"}) + assert isinstance(instance, RenameProcessor) + assert instance.rename_map == {"old": "new"} + + +def test_registry_based_save_load(): + """Test save/load using registry name instead of module path.""" + processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) + pipeline = RobotProcessor([processor]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save and load + pipeline.save_pretrained(tmp_dir) + + # Verify config uses registry name + import json + + with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor" + config = json.load(f) + + assert "registry_name" in config["steps"][0] + assert config["steps"][0]["registry_name"] == "rename_processor" + assert "class" not in config["steps"][0] # Should use registry, not module path + + # Load should work + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == {"key1": "renamed_key1"} + + +def test_chained_rename_processors(): + """Test multiple RenameProcessors in a pipeline.""" + # First processor: rename raw keys to intermediate format + processor1 = RenameProcessor( + rename_map={ + "pos": "agent_position", + "img": "camera_image", + } + ) + + # Second processor: rename to final format + processor2 = RenameProcessor( + rename_map={ + "agent_position": "observation.state", + "camera_image": "observation.image", + } + ) + + pipeline = RobotProcessor([processor1, processor2]) + + observation = { + "pos": np.array([1.0, 2.0]), + "img": "image_data", + "extra": "keep_me", + } + transition = create_transition(observation=observation) + + # Step through to see intermediate results + results = list(pipeline.step_through(transition)) + + # After first processor + assert "agent_position" in results[1][TransitionKey.OBSERVATION] + assert "camera_image" in results[1][TransitionKey.OBSERVATION] + + # After second processor + final_obs = results[2][TransitionKey.OBSERVATION] + assert "observation.state" in final_obs + assert "observation.image" in final_obs + assert final_obs["extra"] == "keep_me" + + # Original keys should be gone + assert "pos" not in final_obs + assert "img" not in final_obs + assert "agent_position" not in final_obs + assert "camera_image" not in final_obs + + +def test_nested_observation_rename(): + """Test renaming with nested observation structures.""" + rename_map = { + "observation.images.left": "observation.camera.left_view", + "observation.images.right": "observation.camera.right_view", + "observation.proprio": "observation.proprioception", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.images.left": torch.randn(3, 64, 64), + "observation.images.right": torch.randn(3, 64, 64), + "observation.proprio": torch.randn(7), + "observation.gripper": torch.tensor([0.0]), # Not renamed + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renames + assert "observation.camera.left_view" in processed_obs + assert "observation.camera.right_view" in processed_obs + assert "observation.proprioception" in processed_obs + + # Check unchanged key + assert "observation.gripper" in processed_obs + + # Check old keys removed + assert "observation.images.left" not in processed_obs + assert "observation.images.right" not in processed_obs + assert "observation.proprio" not in processed_obs + + +def test_value_types_preserved(): + """Test that various value types are preserved during renaming.""" + rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} + processor = RenameProcessor(rename_map=rename_map) + + tensor_value = torch.randn(3, 3) + array_value = np.random.rand(2, 2) + + observation = { + "old_tensor": tensor_value, + "old_array": array_value, + "old_scalar": 42, + "old_string": "hello", + "old_dict": {"nested": "value"}, + "old_list": [1, 2, 3], + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that values and types are preserved + assert torch.equal(processed_obs["new_tensor"], tensor_value) + assert np.array_equal(processed_obs["new_array"], array_value) + assert processed_obs["new_scalar"] == 42 + assert processed_obs["old_string"] == "hello" + assert processed_obs["old_dict"] == {"nested": "value"} + assert processed_obs["old_list"] == [1, 2, 3] + + +def test_feature_contract_basic_renaming(policy_feature_factory): + processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (2,)), + "b": policy_feature_factory(FeatureType.ACTION, (3,)), + "c": policy_feature_factory(FeatureType.ENV, (1,)), + } + + out = processor.feature_contract(features.copy()) + + # Values preserved and typed + assert out["x"] == features["a"] + assert out["y"] == features["b"] + assert out["c"] == features["c"] + + assert_contract_is_typed(out) + # Input not mutated + assert set(features) == {"a", "b", "c"} + + +def test_feature_contract_overlapping_keys(policy_feature_factory): + # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c' + processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (1,)), + "b": policy_feature_factory(FeatureType.STATE, (2,)), + } + out = processor.feature_contract(features) + + assert set(out) == {"b", "c"} + assert out["b"] == features["a"] # 'a' renamed to'b' + assert out["c"] == features["b"] # 'b' renamed to 'c' + assert_contract_is_typed(out) + + +def test_feature_contract_chained_processors(policy_feature_factory): + # Chain two rename processors at the contract level + processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) + processor2 = RenameProcessor( + rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + ) + pipeline = RobotProcessor([processor1, processor2]) + + spec = { + "pos": policy_feature_factory(FeatureType.STATE, (7,)), + "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "extra": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = pipeline.feature_contract(initial_features=spec) + + assert set(out) == {"observation.state", "observation.image", "extra"} + assert out["observation.state"] == spec["pos"] + assert out["observation.image"] == spec["img"] + assert out["extra"] == spec["extra"] + assert_contract_is_typed(out) From 49ecbeb33f4148c5731eba41e8d827dab41b9dd2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 6 Aug 2025 20:10:47 +0200 Subject: [PATCH 069/158] fix(deps): ceil torch pkg versions (#1689) * fix(deps): ceil torch pkg versions * chore(Docs): add todo comment --- pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a1db99c24..d26513404 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,15 +68,16 @@ dependencies = [ "einops>=0.8.0", "opencv-python-headless>=4.9.0", "av>=14.2.0", - "torch>=2.2.1", - "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", - "torchvision>=0.21.0", "jsonlines>=4.0.0", "packaging>=24.2", "pynput>=1.7.7", "pyserial>=3.5", "wandb>=0.20.0", + "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency + "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency + "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency + "draccus==0.10.0", # TODO: Remove == "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency From b883328e6c95681ca90a18b102e4ae5e1f91e2bf Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 6 Aug 2025 20:29:48 +0200 Subject: [PATCH 070/158] chore: Bump to 0.3.3 (#1690) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d26513404..4696a2ae6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.4.0" +version = "0.3.3" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From c66cd401767e60baece16e1cf68da2824227e076 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 6 Aug 2025 21:07:54 +0200 Subject: [PATCH 071/158] chore: Bump to 0.3.4 (#1691) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4696a2ae6..2bc57c076 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.3.3" +version = "0.3.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } From ce3b9f627e55223d6d1c449d348c6b351b35d082 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 7 Aug 2025 14:25:44 +0200 Subject: [PATCH 072/158] chore(docs): prioritize use of entry points in docs + fix nightly badge (#1692) * chore(docs): fix typo in nightly badge * chore(docs): prioritize the use of entrypoints for consistency --- .github/PULL_REQUEST_TEMPLATE.md | 2 +- Makefile | 18 ++++++------ README.md | 12 ++++---- docs/source/cameras.mdx | 2 +- docs/source/hilserl.mdx | 4 +-- docs/source/hope_jr.mdx | 22 +++++++-------- docs/source/il_robots.mdx | 16 +++++------ docs/source/il_sim.mdx | 4 +-- docs/source/koch.mdx | 10 +++---- docs/source/lekiwi.mdx | 8 +++--- docs/source/smolvla.mdx | 6 ++-- docs/source/so100.mdx | 10 +++---- docs/source/so101.mdx | 10 +++---- examples/4_train_policy_with_script.md | 28 +++++++++---------- examples/backward_compatibility/replay.py | 2 +- src/lerobot/calibrate.py | 2 +- src/lerobot/cameras/opencv/camera_opencv.py | 5 ++-- .../cameras/realsense/camera_realsense.py | 5 ++-- src/lerobot/find_cameras.py | 2 +- src/lerobot/find_port.py | 2 +- src/lerobot/motors/motors_bus.py | 4 +-- src/lerobot/policies/pi0/modeling_pi0.py | 4 +-- .../policies/pi0fast/modeling_pi0fast.py | 4 +-- .../policies/smolvla/modeling_smolvla.py | 4 +-- src/lerobot/record.py | 4 +-- src/lerobot/replay.py | 4 +-- src/lerobot/robots/viperx/README.md | 4 +-- src/lerobot/scripts/eval.py | 4 +-- src/lerobot/setup_motors.py | 2 +- src/lerobot/teleoperate.py | 4 +-- .../templates/lerobot_modelcard_template.md | 4 +-- 31 files changed, 105 insertions(+), 107 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 22f1ee3d7..d37b1a92f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something ``` ```bash -python -m lerobot.scripts.train --some.option=true +lerobot-train --some.option=true ``` ## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR diff --git a/Makefile b/Makefile index 5bfbe76a2..fbe8a5bae 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval test-act-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=act \ --policy.dim_model=64 \ --policy.n_action_steps=20 \ @@ -68,12 +68,12 @@ test-act-ete-train: --output_dir=tests/outputs/act/ test-act-ete-train-resume: - python -m lerobot.scripts.train \ + lerobot-train \ --config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \ --resume=true test-act-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ @@ -82,7 +82,7 @@ test-act-ete-eval: --eval.batch_size=1 test-diffusion-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=diffusion \ --policy.down_dims='[64,128,256]' \ --policy.diffusion_step_embed_dim=32 \ @@ -106,7 +106,7 @@ test-diffusion-ete-train: --output_dir=tests/outputs/diffusion/ test-diffusion-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=pusht \ @@ -115,7 +115,7 @@ test-diffusion-ete-eval: --eval.batch_size=1 test-tdmpc-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ @@ -137,7 +137,7 @@ test-tdmpc-ete-train: --output_dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=xarm \ @@ -148,7 +148,7 @@ test-tdmpc-ete-eval: test-smolvla-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=smolvla \ --policy.n_action_steps=20 \ --policy.chunk_size=20 \ @@ -171,7 +171,7 @@ test-smolvla-ete-train: --output_dir=tests/outputs/smolvla/ test-smolvla-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ diff --git a/README.md b/README.md index 7255ed3ef..b5e666aa8 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@
-[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain) +[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain) [![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE) [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) @@ -276,7 +276,7 @@ Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/ We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht): ```bash -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=lerobot/diffusion_pusht \ --env.type=pusht \ --eval.batch_size=10 \ @@ -288,10 +288,10 @@ python -m lerobot.scripts.eval \ Note: After training your own policy, you can re-evaluate the checkpoints with: ```bash -python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model +lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model ``` -See `python -m lerobot.scripts.eval --help` for more instructions. +See `lerobot-eval --help` for more instructions. ### Train your own policy @@ -303,7 +303,7 @@ A link to the wandb logs for the run will also show up in yellow in your termina \WandB logs example -Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions. +Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions. #### Reproduce state-of-the-art (SOTA) @@ -311,7 +311,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler You can reproduce their training by loading the config from their run. Simply running: ```bash -python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht +lerobot-train --config_path=lerobot/diffusion_pusht ``` reproduces SOTA results for Diffusion Policy on the PushT task. diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 604863d74..5c35be0ba 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. This identifier might cha To find the camera indices of the cameras plugged into your system, run the following script: ```bash -python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras +lerobot-find-cameras opencv # or realsense for Intel Realsense cameras ``` The output will look something like this if you have two cameras connected: diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index 2f73d0964..f8a5c69b2 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -412,7 +412,7 @@ Example configuration for training the [reward classifier](https://huggingface.c To train the classifier, use the `train.py` script with your configuration: ```bash -python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json +lerobot-train --config_path path/to/reward_classifier_train_config.json ``` **Deploying and Testing the Model** @@ -458,7 +458,7 @@ The reward classifier will automatically provide rewards based on the visual inp 3. **Train the classifier**: ```bash - python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json + lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json ``` 4. **Test the classifier**: diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 72aa8f923..856febb95 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -19,7 +19,7 @@ pip install -e ".[hopejr]" Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton: ```bash -python -m lerobot.find_port +lerobot-find-port ``` This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts. @@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. Calibrati ### 1.1 Calibrate Robot Hand ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=blue \ @@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav ### 1.2 Calibrate Teleoperator Glove ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=homunculus_glove \ --teleop.port=/dev/tty.usbmodem11201 \ --teleop.id=red \ @@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo ### 1.3 Calibrate Robot Arm ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=hope_jr_arm \ --robot.port=/dev/tty.usbserial-1110 \ --robot.id=white @@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e ### 1.4 Calibrate Teleoperator Exoskeleton ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=homunculus_arm \ --teleop.port=/dev/tty.usbmodem11201 \ --teleop.id=black @@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar ### Hand ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=blue \ @@ -194,7 +194,7 @@ python -m lerobot.teleoperate \ ### Arm ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=hope_jr_arm \ --robot.port=/dev/tty.usbserial-1110 \ --robot.id=white \ @@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental. This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings). ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ @@ -236,7 +236,7 @@ python -m lerobot.record \ ### Replay ```bash -python -m lerobot.replay \ +lerobot-replay \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ @@ -248,7 +248,7 @@ python -m lerobot.replay \ ### Train ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=nepyope/hand_record_test_with_video_data \ --policy.type=act \ --output_dir=outputs/train/hopejr_hand \ @@ -263,7 +263,7 @@ python -m lerobot.scripts.train \ This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino). ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index ec5491b2a..905046bef 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -174,7 +174,7 @@ Now you can record a dataset. To record 5 episodes and upload your dataset to th ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem585A0076841 \ --robot.id=my_awesome_follower_arm \ @@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or ```bash -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. For example ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/so101_test \ --policy.type=act \ --output_dir=outputs/train/act_so101_test \ @@ -453,7 +453,7 @@ Training should take several hours. You will find checkpoints in `outputs/train/ To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \ --resume=true ``` @@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/ttyACM1 \ --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 761e24e0f..3dd80dc4b 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -96,10 +96,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online]( ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/il_gym \ --policy.type=act \ --output_dir=outputs/train/il_sim_test \ diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx index d0b991e74..3e94899a8 100644 --- a/docs/source/koch.mdx +++ b/docs/source/koch.mdx @@ -31,7 +31,7 @@ pip install -e ".[dynamixel]" To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video]( ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=koch_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step ``` @@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=koch_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index a5bdb19cf..14c06e444 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. One to attach to the mobile ba To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=lekiwi \ --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step ``` @@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=lekiwi \ --robot.id=my_awesome_kiwi # <- Give the robot a unique name ``` @@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index 880beaa1a..89c475a90 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [![Google Co Pass your dataset to the training script using `--dataset.repo_id`. If you want to test your installation, run the following command where we use one of the datasets we collected for the [SmolVLA Paper](https://huggingface.co/papers/2506.01844). ```bash -cd lerobot && python -m lerobot.scripts.train \ +cd lerobot && lerobot-train \ --policy.path=lerobot/smolvla_base \ --dataset.repo_id=${HF_USER}/mydataset \ --batch_size=64 \ @@ -73,7 +73,7 @@ cd lerobot && python -m lerobot.scripts.train \ Fine-tuning is an art. For a complete overview of the options for finetuning, run ```bash -python -m lerobot.scripts.train --help +lerobot-train --help ```

@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i Once you are logged in, you can run inference in your setup by doing: ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so101_follower \ --robot.port=/dev/ttyACM0 \ # <- Use your port --robot.id=my_blue_follower_arm \ # <- Use your robot id diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx index d9ff922c5..8578e1e8d 100644 --- a/docs/source/so100.mdx +++ b/docs/source/so100.mdx @@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video]( ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` @@ -168,7 +168,7 @@ Do the same steps for the leader arm. ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index a20a3fa9f..b9fb9cab4 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` @@ -316,7 +316,7 @@ Do the same steps for the leader arm. ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so101_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so101_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index d6cd6cc23..ffa7de66e 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=lerobot/pusht \ --policy.type=diffusion \ --env.type=pusht @@ -77,7 +77,7 @@ Let's break this down: Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro We can then simply load the config values from this file using: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 ``` @@ -137,7 +137,7 @@ python -m lerobot.scripts.train \ Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 --policy.n_action_steps=80 @@ -148,7 +148,7 @@ python -m lerobot.scripts.train \ `--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running: ```bash -python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht +lerobot-train --config_path=lerobot/diffusion_pusht ``` will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) @@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f Let's reuse the command from the previous run and add a few more options: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100 Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true ``` @@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai You could double the number of steps of the previous run with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true \ --steps=200000 @@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path` For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial. #### Train a policy from scratch – CLI ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ # <- select 'act' policy --env.type=pusht \ # <- select 'pusht' environment --dataset.repo_id=lerobot/pusht # <- train on this dataset @@ -279,7 +279,7 @@ python -m lerobot.scripts.train \ #### Train a policy from scratch - config file + CLI ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=path/to/pretrained_model \ # <- can also be a repo_id --policy.n_action_steps=80 # <- you may still override values ``` @@ -287,7 +287,7 @@ python -m lerobot.scripts.train \ #### Resume/continue a training run ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=checkpoint/pretrained_model/ \ --resume=true \ --steps=200000 # <- you can change some training parameters @@ -296,7 +296,7 @@ python -m lerobot.scripts.train \ #### Fine-tuning ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index cc3397543..6c680f204 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot. Example: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ diff --git a/src/lerobot/calibrate.py b/src/lerobot/calibrate.py index 0dda80ba2..0aa61a2f9 100644 --- a/src/lerobot/calibrate.py +++ b/src/lerobot/calibrate.py @@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator). Example: ```shell -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ --teleop.id=blue diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index aad19819a..3665a909f 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -60,7 +60,7 @@ class OpenCVCamera(Camera): or port changes, especially on Linux. Use the provided utility script to find available camera indices or paths: ```bash - python -m lerobot.find_cameras opencv + lerobot-find-cameras opencv ``` The camera's default settings (FPS, resolution, color mode) are used unless @@ -165,8 +165,7 @@ class OpenCVCamera(Camera): self.videocapture.release() self.videocapture = None raise ConnectionError( - f"Failed to open {self}." - f"Run `python -m lerobot.find_cameras opencv` to find available cameras." + f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras." ) self._configure_capture_settings() diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 918c5592e..12ce89c91 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -51,7 +51,7 @@ class RealSenseCamera(Camera): Use the provided utility script to find available camera indices and default profiles: ```bash - python -m lerobot.find_cameras realsense + lerobot-find-cameras realsense ``` A `RealSenseCamera` instance requires a configuration object specifying the @@ -176,8 +176,7 @@ class RealSenseCamera(Camera): self.rs_profile = None self.rs_pipeline = None raise ConnectionError( - f"Failed to open {self}." - "Run `python -m lerobot.find_cameras realsense` to find available cameras." + f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras." ) from e self._configure_capture_settings() diff --git a/src/lerobot/find_cameras.py b/src/lerobot/find_cameras.py index 8f88d3107..ec8f5ff30 100644 --- a/src/lerobot/find_cameras.py +++ b/src/lerobot/find_cameras.py @@ -20,7 +20,7 @@ Helper to find the camera devices available in your system. Example: ```shell -python -m lerobot.find_cameras +lerobot-find-cameras ``` """ diff --git a/src/lerobot/find_port.py b/src/lerobot/find_port.py index babe0288e..e32b9cb99 100644 --- a/src/lerobot/find_port.py +++ b/src/lerobot/find_port.py @@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus. Example: ```shell -python -m lerobot.find_port +lerobot-find-port ``` """ diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 597bcd3c4..97830fc35 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -222,7 +222,7 @@ class MotorsBus(abc.ABC): A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). To find the port, you can run our utility script: ```bash - python -m lerobot.find_port.py + lerobot-find-port.py >>> Finding all available ports for the MotorsBus. >>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"] >>> Remove the usb cable from your MotorsBus and press Enter when done. @@ -446,7 +446,7 @@ class MotorsBus(abc.ABC): except (FileNotFoundError, OSError, serial.SerialException) as e: raise ConnectionError( f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." - "\nTry running `python -m lerobot.find_port`\n" + "\nTry running `lerobot-find-port`\n" ) from e @abc.abstractmethod diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index e56946ac8..de41e2bd4 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -30,7 +30,7 @@ pip install -e ".[pi0]" Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/pi0 \ --dataset.repo_id=danaaubakirova/koch_test ``` @@ -38,7 +38,7 @@ python -m lerobot.scripts.train \ Example of finetuning the pi0 neural network with PaliGemma and expert Gemma pretrained with VLM default parameters before pi0 finetuning: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=pi0 \ --dataset.repo_id=danaaubakirova/koch_test ``` diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 80e10bc02..88727b581 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/pi0fast_base \ --dataset.repo_id=danaaubakirova/koch_test ``` Example of training the pi0+FAST neural network with from scratch: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=pi0fast \ --dataset.repo_id=danaaubakirova/koch_test ``` diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 469645e84..18f2fc58a 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -28,7 +28,7 @@ pip install -e ".[smolvla]" Example of finetuning the smolvla pretrained model (`smolvla_base`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/smolvla_base \ --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ --batch_size=64 \ @@ -38,7 +38,7 @@ python -m lerobot.scripts.train \ Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, and an action expert. ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=smolvla \ --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ --batch_size=64 \ diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 575fcb94d..09fa33fe3 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -18,7 +18,7 @@ Records a dataset. Actions for the robot can be either generated by teleoperatio Example: ```shell -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \ @@ -36,7 +36,7 @@ python -m lerobot.record \ Example recording with bimanual so100: ```shell -python -m lerobot.record \ +lerobot-record \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index a9dceb741..2b62fd67f 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot. Examples: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ @@ -28,7 +28,7 @@ python -m lerobot.replay \ Example replay with bimanual so100: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index 5cdb152a2..bbc9f7223 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -141,10 +141,10 @@ python lerobot/scripts/control_robot.py \ ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/aloha_test \ --policy.type=act \ --output_dir=outputs/train/act_aloha_test \ diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 6a6c02a24..13d30c686 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di for 10 episodes. ``` -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=lerobot/diffusion_pusht \ --env.type=pusht \ --eval.batch_size=10 \ @@ -32,7 +32,7 @@ python -m lerobot.scripts.eval \ OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. ``` -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \ --env.type=pusht \ --eval.batch_size=10 \ diff --git a/src/lerobot/setup_motors.py b/src/lerobot/setup_motors.py index 76cdca56d..c1d256c21 100644 --- a/src/lerobot/setup_motors.py +++ b/src/lerobot/setup_motors.py @@ -18,7 +18,7 @@ Helper to set motor ids and baudrate. Example: ```shell -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 ``` diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 3c72caf79..e7be6967b 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -18,7 +18,7 @@ Simple script to control a robot from teleoperation. Example: ```shell -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ @@ -32,7 +32,7 @@ python -m lerobot.teleoperate \ Example teleoperation with bimanual so100: ```shell -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 7b7aaa84a..9293d6ba7 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -44,7 +44,7 @@ Below is the short version on how to train and run inference/eval: ### Train from scratch ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/ \ --policy.type=act \ --output_dir=outputs/train/ \ @@ -59,7 +59,7 @@ _Writes checkpoints to `outputs/train//checkpoints/`._ ### Evaluate the policy/run inference ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --dataset.repo_id=/eval_ \ --policy.path=/ \ From 11e6bd762a6fd558b470849d5d72253dceb784ef Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 8 Aug 2025 10:46:14 +0200 Subject: [PATCH 073/158] fix(busy_wait): fix busy_wait implementation for Windows platforms and removing erronous TODO (#1695) --- src/lerobot/utils/robot_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index e6c0cfe6d..8069b3662 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -17,10 +17,9 @@ import time def busy_wait(seconds): - if platform.system() == "Darwin": - # On Mac, `time.sleep` is not accurate and we need to use this while loop trick, + if platform.system() == "Darwin" or platform.system() == "Windows": + # On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick, # but it consumes CPU cycles. - # TODO(rcadene): find an alternative: from python 11, time.sleep is precise end_time = time.perf_counter() + seconds while time.perf_counter() < end_time: pass From 0878c6880fa4fbadf0742751cf7b015f2d63a769 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sat, 9 Aug 2025 00:21:42 +0200 Subject: [PATCH 074/158] fix(ci): inverted names (#1705) --- .github/workflows/nightly.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index b42b92f6b..03f26a792 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,8 +29,8 @@ on: env: UV_VERSION: "0.8.0" PYTHON_VERSION: "3.10" - DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest - DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest + DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest + DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest # Ensures that only the latest commit is built, canceling older runs. concurrency: From 55198de096f46a8e0447a8795129dd9ee84c088c Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 14 Aug 2025 11:12:06 +0200 Subject: [PATCH 075/158] fix(ci): rename libegl1-mesa in deb13 trixie (#1735) --- docker/Dockerfile.user | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index 4cfbb437a..bcd067637 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \ # Install system dependencies and uv (as root) RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \ + build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ && curl -LsSf https://astral.sh/uv/install.sh | sh \ && mv /root/.local/bin/uv /usr/local/bin/uv \ From 7f70b78f3221a2fa64ae09795b5989a58a61931d Mon Sep 17 00:00:00 2001 From: Jack Vial Date: Wed, 20 Aug 2025 11:24:05 -0400 Subject: [PATCH 076/158] Add missing encoding table entries for Koch arm (#1534) --- src/lerobot/motors/dynamixel/tables.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/motors/dynamixel/tables.py b/src/lerobot/motors/dynamixel/tables.py index 8b67bbf38..5417d8cee 100644 --- a/src/lerobot/motors/dynamixel/tables.py +++ b/src/lerobot/motors/dynamixel/tables.py @@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = { "Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1], "Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1], "Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1], + "Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1], + "Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1], "Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1], "Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1], "Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1], From b0923ab74b7fb7ed688ef2abbe79607f3dee390a Mon Sep 17 00:00:00 2001 From: lxk <53181378+lxk-221@users.noreply.github.com> Date: Fri, 22 Aug 2025 21:24:02 +0800 Subject: [PATCH 077/158] fix(dataset): Use provided episode_data in save_episode (#1740) The 'episode_data' parameter was previously ignored, causing an error if provided. This change ensures it is correctly used, which allows for asynchronous episode saving by passing a copy of the episode buffer, preventing conflicts with the main data collection loop. --- src/lerobot/datasets/lerobot_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 617ac297f..a869cb920 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -825,6 +825,8 @@ class LeRobotDataset(torch.utils.data.Dataset): """ if not episode_data: episode_buffer = self.episode_buffer + else: + episode_buffer = episode_data validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) From 577cd10974b84bea1f06b6472eb9e5e74e07f77a Mon Sep 17 00:00:00 2001 From: mgiac-hexagon Date: Mon, 25 Aug 2025 12:39:32 +0200 Subject: [PATCH 078/158] Removed dupicate lines of code (#1709) --- src/lerobot/scripts/server/robot_client.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/scripts/server/robot_client.py index 0599e068e..939d5cea8 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/scripts/server/robot_client.py @@ -302,11 +302,6 @@ class RobotClient: self.logger.debug(f"Current latest action: {latest_action}") - # Get queue state before changes - old_size, old_timesteps = self._inspect_action_queue() - if not old_timesteps: - old_timesteps = [latest_action] # queue was empty - # Get queue state before changes old_size, old_timesteps = self._inspect_action_queue() if not old_timesteps: From 61b0eeae4b41902fd1f09886412974861d800cfa Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 28 Aug 2025 11:18:54 +0200 Subject: [PATCH 079/158] Add feetech firmware update docs (#1793) * Add feetech firmware update docs * add bonus * formatting * adapt text * feedback pr --- docs/source/_toctree.yml | 2 ++ docs/source/feetech.mdx | 71 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 docs/source/feetech.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1af96d79d..af44c512b 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -39,6 +39,8 @@ - sections: - local: notebooks title: Notebooks + - local: feetech + title: Updating Feetech Firmware title: "Resources" - sections: - local: contributing diff --git a/docs/source/feetech.mdx b/docs/source/feetech.mdx new file mode 100644 index 000000000..bba60e4cc --- /dev/null +++ b/docs/source/feetech.mdx @@ -0,0 +1,71 @@ +# Feetech Motor Firmware Update + +This tutorial guides you through updating the firmware of Feetech motors using the official Feetech software. + +## Prerequisites + +- Windows computer (Feetech software is only available for Windows) +- Feetech motor control board +- USB cable to connect the control board to your computer +- Feetech motors connected to the control board + +## Step 1: Download Feetech Software + +1. Visit the official Feetech software download page: [https://www.feetechrc.com/software.html](https://www.feetechrc.com/software.html) +2. Download the latest version of the Feetech debugging software (FD) +3. Install the software on your Windows computer + +## Step 2: Hardware Setup + +1. Connect your Feetech motors to the motor control board +2. Connect the motor control board to your Windows computer via USB cable +3. Ensure power is supplied to the motors + +## Step 3: Configure Connection + +1. Launch the Feetech debugging software +2. Select the correct COM port from the port dropdown menu + - If unsure which port to use, check Windows Device Manager under "Ports (COM & LPT)" +3. Set the appropriate baud rate (typically 1000000 for most Feetech motors) +4. Click "Open" to establish communication with the control board + +## Step 4: Scan for Motors + +1. Once connected, click the "Search" button to detect all connected motors +2. The software will automatically discover and list all motors on the bus +3. Each motor will appear with its ID number + +## Step 5: Update Firmware + +For each motor you want to update: + +1. **Select the motor** from the list by clicking on it +2. **Click on Upgrade tab**: +3. **Click on Online button**: + - If an potential firmware update is found, it will be displayed in the box +4. **Click on Upgrade button**: + - The update progress will be displayed + +## Step 6: Verify Update + +1. After the update completes, the software should automatically refresh the motor information +2. Verify that the firmware version has been updated to the expected version + +## Important Notes + +⚠️ **Warning**: Do not disconnect power or USB during firmware updates, it will potentially brick the motor. + +## Bonus: Motor Debugging on Linux/macOS + +For debugging purposes only, you can use the open-source Feetech Debug Tool: + +- **Repository**: [FT_SCServo_Debug_Qt](https://github.com/CarolinePascal/FT_SCServo_Debug_Qt/tree/fix/port-search-timer) + +### Installation Instructions + +Follow the instructions in the repository to install the tool, for Ubuntu you can directly install it, for MacOS you need to build it from source. + +**Limitations:** + +- This tool is for debugging and parameter adjustment only +- Firmware updates must still be done on Windows with official Feetech software From 882c80d446a63a44868c67ae535467af32ce0e80 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:06:55 +0200 Subject: [PATCH 080/158] Lower limits by 50% for current and torque for gripper motor (#1809) Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/robots/so100_follower/so100_follower.py | 5 +++++ src/lerobot/robots/so101_follower/so101_follower.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/src/lerobot/robots/so100_follower/so100_follower.py b/src/lerobot/robots/so100_follower/so100_follower.py index ac52293ff..1e117e80b 100644 --- a/src/lerobot/robots/so100_follower/so100_follower.py +++ b/src/lerobot/robots/so100_follower/so100_follower.py @@ -161,6 +161,11 @@ class SO100Follower(Robot): self.bus.write("I_Coefficient", motor, 0) self.bus.write("D_Coefficient", motor, 32) + if motor == "gripper": + self.bus.write("Max_Torque_Limit", motor, 500) # 50% of max torque to avoid burnout + self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout + self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded + def setup_motors(self) -> None: for motor in reversed(self.bus.motors): input(f"Connect the controller board to the '{motor}' motor only and press enter.") diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py index 3ef66d702..31b06c2fd 100644 --- a/src/lerobot/robots/so101_follower/so101_follower.py +++ b/src/lerobot/robots/so101_follower/so101_follower.py @@ -157,6 +157,13 @@ class SO101Follower(Robot): self.bus.write("I_Coefficient", motor, 0) self.bus.write("D_Coefficient", motor, 32) + if motor == "gripper": + self.bus.write( + "Max_Torque_Limit", motor, 500 + ) # 50% of the max torque limit to avoid burnout + self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout + self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded + def setup_motors(self) -> None: for motor in reversed(self.bus.motors): input(f"Connect the controller board to the '{motor}' motor only and press enter.") From d74494d92b74e951ce2e92bce55b0462c6d3324c Mon Sep 17 00:00:00 2001 From: Justin Huang Date: Fri, 5 Sep 2025 00:58:47 -0700 Subject: [PATCH 081/158] Allow max_relative_target to be a float (#1837) * Remove unused max_relative_target for stretch3 * Fix type annotation and allow integer max_relative_target values * Configure max_relative_target to be floats instead of ints * Update docs and types to reflect that max_relative_target can be a dict * Remove unnecessary isinstance check for ints * Fix typo in name --------- Co-authored-by: Justin Huang --- .../robots/bi_so100_follower/config_bi_so100_follower.py | 4 ++-- src/lerobot/robots/hope_jr/config_hope_jr.py | 6 +++--- src/lerobot/robots/koch_follower/config_koch_follower.py | 6 +++--- src/lerobot/robots/lekiwi/config_lekiwi.py | 6 +++--- src/lerobot/robots/so100_follower/config_so100_follower.py | 6 +++--- src/lerobot/robots/so101_follower/config_so101_follower.py | 6 +++--- src/lerobot/robots/stretch3/configuration_stretch3.py | 5 ----- src/lerobot/robots/utils.py | 2 +- src/lerobot/robots/viperx/config_viperx.py | 6 +++--- 9 files changed, 21 insertions(+), 26 deletions(-) diff --git a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py index 00643b85f..5806d7415 100644 --- a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py +++ b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py @@ -29,10 +29,10 @@ class BiSO100FollowerConfig(RobotConfig): # Optional left_arm_disable_torque_on_disconnect: bool = True - left_arm_max_relative_target: int | None = None + left_arm_max_relative_target: float | dict[str, float] | None = None left_arm_use_degrees: bool = False right_arm_disable_torque_on_disconnect: bool = True - right_arm_max_relative_target: int | None = None + right_arm_max_relative_target: float | dict[str, float] | None = None right_arm_use_degrees: bool = False # cameras (shared between both arms) diff --git a/src/lerobot/robots/hope_jr/config_hope_jr.py b/src/lerobot/robots/hope_jr/config_hope_jr.py index 747e98e01..f2af5f47c 100644 --- a/src/lerobot/robots/hope_jr/config_hope_jr.py +++ b/src/lerobot/robots/hope_jr/config_hope_jr.py @@ -44,8 +44,8 @@ class HopeJrArmConfig(RobotConfig): disable_torque_on_disconnect: bool = True # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/koch_follower/config_koch_follower.py b/src/lerobot/robots/koch_follower/config_koch_follower.py index a7c9249ae..02a95ef4e 100644 --- a/src/lerobot/robots/koch_follower/config_koch_follower.py +++ b/src/lerobot/robots/koch_follower/config_koch_follower.py @@ -28,9 +28,9 @@ class KochFollowerConfig(RobotConfig): disable_torque_on_disconnect: bool = True # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None # cameras cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/lekiwi/config_lekiwi.py b/src/lerobot/robots/lekiwi/config_lekiwi.py index f0f8c24b3..acaf5f0ec 100644 --- a/src/lerobot/robots/lekiwi/config_lekiwi.py +++ b/src/lerobot/robots/lekiwi/config_lekiwi.py @@ -39,9 +39,9 @@ class LeKiwiConfig(RobotConfig): disable_torque_on_disconnect: bool = True # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index ea8b9f1c2..561790e77 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -30,9 +30,9 @@ class SO100FollowerConfig(RobotConfig): disable_torque_on_disconnect: bool = True # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None # cameras cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/so101_follower/config_so101_follower.py b/src/lerobot/robots/so101_follower/config_so101_follower.py index be630e6ac..03c3530c2 100644 --- a/src/lerobot/robots/so101_follower/config_so101_follower.py +++ b/src/lerobot/robots/so101_follower/config_so101_follower.py @@ -30,9 +30,9 @@ class SO101FollowerConfig(RobotConfig): disable_torque_on_disconnect: bool = True # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. + max_relative_target: float | dict[str, float] | None = None # cameras cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/stretch3/configuration_stretch3.py b/src/lerobot/robots/stretch3/configuration_stretch3.py index 9fcf8f742..d4e217ca0 100644 --- a/src/lerobot/robots/stretch3/configuration_stretch3.py +++ b/src/lerobot/robots/stretch3/configuration_stretch3.py @@ -24,11 +24,6 @@ from ..config import RobotConfig @RobotConfig.register_subclass("stretch3") @dataclass class Stretch3RobotConfig(RobotConfig): - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - # cameras cameras: dict[str, CameraConfig] = field( default_factory=lambda: { diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 7486ee499..befd96424 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -70,7 +70,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot: def ensure_safe_goal_position( - goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float] + goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float] ) -> dict[str, float]: """Caps relative action target magnitude for safety.""" diff --git a/src/lerobot/robots/viperx/config_viperx.py b/src/lerobot/robots/viperx/config_viperx.py index 4922f1d18..ed3876a9c 100644 --- a/src/lerobot/robots/viperx/config_viperx.py +++ b/src/lerobot/robots/viperx/config_viperx.py @@ -28,15 +28,15 @@ class ViperXConfig(RobotConfig): # /!\ FOR SAFETY, READ THIS /!\ # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. + # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor + # names to the max_relative_target value for that motor. # For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default. # When you feel more confident with teleoperation or running the policy, you can extend # this safety limit and even removing it by setting it to `null`. # Also, everything is expected to work safely out-of-the-box, but we highly advise to # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml), # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully - max_relative_target: int | None = 5 + max_relative_target: float | dict[str, float] = 5.0 # cameras cameras: dict[str, CameraConfig] = field(default_factory=dict) From 6a3d57031aab37adff8eec2e11049510654cc5bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABlle=20Lannuzel?= Date: Fri, 5 Sep 2025 11:03:14 +0200 Subject: [PATCH 082/158] 2 add reachy 2 to updated lerobot (#1767) * Start adding Reachy 2 (no camera) * Fix joint shape * Remove print * Modify observation_features * Fix observation state * Try adding a fake Reachy teleoperator * Saving test scripts * Add reachy2camera to cameras * Add teleop_left camera to observation * Create test_reachy2_camera.py * Update utils.py * Add all rgb cameras * Future depth work * Try adding mobile_base velocity * Update tests * Update data_acquisition_server.py * Update with use_external_commands * Replay * Usable with or without mobile base * No need for new isntance * Use same ip for cameras * Remove useless imports * Add resume * Divide joints in multiple dicts * Divide joinits into several dicts in teleoperator * Fix forgotten method call * Create test_robot_client.py * Open gripper on start * Add arguments for cameras * Modify get_frame() requested size * Call generate_joints_dict on _init_ * black + isort * Add reachy2 in imports * Add reachy2 dependencies * Add documentation * Update reachy2.mdx * Update reachy2.mdx * Clean files and add types * Fix type in send_action * Remove print * Delete test files * Clean code * Update cameras * Disconnect from camera * Run pre-commit hooks * Update pyproject.toml * Create test_reachy2.py * Fix generate_joints * Update test_reachy2.py * Update send_action test * Update reachy2_cameras depth + CameraManager * Update reachy2_camera tests * Remove useless import and args * Rename reachy2_teleoperator * Create test_reachy2_teleoperator.py * Fix remainging fake_teleoperator * Remove useless elements * Mock cameras in test_reachy2 * Delete commented lines * Add use_present_position to teleoperator * Add cameras tests * Add check no part + test * Use disable_torque_on_disconnect * Use odometry for vel with present_position * Update documentation * Fix vel value type * Use ensure_safe_goal_position * Import joints dict from classes * Update reachy2.mdx * Update reachy2.mdx * Update minimal version * Update minimal version * fix(tests) fixes for reachy2 tests; removing reachy2 references from the script * Add reachy2_sdk fake as plugins --------- Co-authored-by: Michel Aractingi --- docs/source/_toctree.yml | 2 + docs/source/reachy2.mdx | 288 ++++++++++++++++ pyproject.toml | 2 + .../cameras/reachy2_camera/__init__.py | 16 + .../configuration_reachy2_camera.py | 78 +++++ .../cameras/reachy2_camera/reachy2_camera.py | 288 ++++++++++++++++ src/lerobot/cameras/utils.py | 8 +- src/lerobot/record.py | 10 +- src/lerobot/replay.py | 1 + src/lerobot/robots/reachy2/__init__.py | 25 ++ .../robots/reachy2/configuration_reachy2.py | 107 ++++++ src/lerobot/robots/reachy2/robot_reachy2.py | 230 ++++++++++++ src/lerobot/robots/utils.py | 4 + .../reachy2_teleoperator/__init__.py | 25 ++ .../config_reachy2_teleoperator.py | 51 +++ .../reachy2_teleoperator.py | 164 +++++++++ src/lerobot/teleoperators/utils.py | 4 + tests/cameras/test_reachy2_camera.py | 177 ++++++++++ tests/conftest.py | 1 + tests/plugins/reachy2_sdk.py | 30 ++ tests/robots/test_reachy2.py | 326 ++++++++++++++++++ .../test_reachy2_teleoperator.py | 150 ++++++++ 22 files changed, 1984 insertions(+), 3 deletions(-) create mode 100644 docs/source/reachy2.mdx create mode 100644 src/lerobot/cameras/reachy2_camera/__init__.py create mode 100644 src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py create mode 100644 src/lerobot/cameras/reachy2_camera/reachy2_camera.py create mode 100644 src/lerobot/robots/reachy2/__init__.py create mode 100644 src/lerobot/robots/reachy2/configuration_reachy2.py create mode 100644 src/lerobot/robots/reachy2/robot_reachy2.py create mode 100644 src/lerobot/teleoperators/reachy2_teleoperator/__init__.py create mode 100644 src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py create mode 100644 src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py create mode 100644 tests/cameras/test_reachy2_camera.py create mode 100644 tests/plugins/reachy2_sdk.py create mode 100644 tests/robots/test_reachy2.py create mode 100644 tests/teleoperators/test_reachy2_teleoperator.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index af44c512b..1a4558f93 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -35,6 +35,8 @@ title: Koch v1.1 - local: lekiwi title: LeKiwi + - local: reachy2 + title: Reachy 2 title: "Robots" - sections: - local: notebooks diff --git a/docs/source/reachy2.mdx b/docs/source/reachy2.mdx new file mode 100644 index 000000000..7d3dc1b60 --- /dev/null +++ b/docs/source/reachy2.mdx @@ -0,0 +1,288 @@ +# Reachy 2 + +Reachy 2 is an open-source humanoid robot made by Pollen Robotics, specifically designed for the development of embodied AI and real-world applications. +Check out [Pollen Robotics website](https://www.pollen-robotics.com/reachy/), or access [Reachy 2 documentation](https://docs.pollen-robotics.com/) for more information on the platform! + +## Teleoperate Reachy 2 + +Currently, there are two ways to teleoperate Reachy 2: + +- Pollen Robotics’ VR teleoperation (not included in LeRobot). +- Robot-to-robot teleoperation (use one Reachy 2 to control another). + +## Reachy 2 Simulation + +**(Linux only)** You can run Reachy 2 in simulation (Gazebo or MuJoCo) using the provided [Docker image](https://hub.docker.com/r/pollenrobotics/reachy2_core). + +1. Install [Docker Engine](https://docs.docker.com/engine/). +2. Run (for MuJoCo): + +``` +docker run --rm -it \ + --name reachy \ + --privileged \ + --network host \ + --ipc host \ + --device-cgroup-rule='c 189:* rwm' \ + --group-add audio \ + -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \ + -e DISPLAY="$DISPLAY" \ + -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \ + -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \ + -v /dev:/dev \ + -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \ + -v "$HOME/.reachy.log":/home/reachy/.ros/log \ + -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \ + --entrypoint /package/launch.sh \ + pollenrobotics/reachy2_core:1.7.5.9_deploy \ + start_rviz:=true start_sdk_server:=true mujoco:=true +``` + +> If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance: +> +> ``` +> docker run --rm -it \ +> --name reachy \ +> --privileged \ +> --network host \ +> --ipc host \ +> --device-cgroup-rule='c 189:* rwm' \ +> --group-add audio \ +> -e ROS_DOMAIN_ID="$ROS_DOMAIN_ID" \ +> -e DISPLAY="$DISPLAY" \ +> -e RCUTILS_CONSOLE_OUTPUT_FORMAT="[{severity}]: {message}" \ +> -e REACHY2_CORE_SERVICE_FAKE="${REACHY2_CORE_SERVICE_FAKE:-true}" \ +> -e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \ +> -v /dev:/dev \ +> -v "$HOME/.reachy_config":/home/reachy/.reachy_config_override \ +> -v "$HOME/.reachy.log":/home/reachy/.ros/log \ +> -v /usr/lib/x86_64-linux-gnu:/opt/host-libs \ +> --entrypoint /package/launch.sh \ +> pollenrobotics/reachy2_core:1.7.5.9_deploy \ +> start_rviz:=true start_sdk_server:=true mujoco:=true +> ``` + +## Setup + +### Prerequisites + +- On your robot, check the **service images** meet the minimum versions: + - **reachy2-core >= 1.7.5.2** + - **webrtc >= 2.0.1.1** + +Then, if you want to use VR teleoperation: + +- Install the [Reachy 2 teleoperation application](https://docs.pollen-robotics.com/teleoperation/teleoperation-introduction/discover-teleoperation/). + Use version **>=v1.2.0** + +We recommend using two computers: one for teleoperation (Windows required) and another for recording with LeRobot. + +### Install LeRobot + +Follow the [installation instructions](https://github.com/huggingface/lerobot#installation) to install LeRobot. + +Install LeRobot with Reachy 2 dependencies: + +```bash +pip install -e ".[reachy2]" +``` + +### (Optional but recommended) Install pollen_data_acquisition_server + +How you manage Reachy 2 recording sessions is up to you, but the **easiest** way is to use this server so you can control sessions directly from the VR teleoperation app. + +> **Note:** Currently, only the VR teleoperation application works as a client for this server, so this step primarily targets teleoperation. You’re free to develop custom clients to manage sessions to your needs. + +In your LeRobot environment, install the server from source: + +```bash +git clone https://github.com/pollen-robotics/pollen_data_acquisition_server.git +cd pollen_data_acquisition_server +pip install -e . +``` + +Find the [pollen_data_acquisition_server documentation here](https://github.com/pollen-robotics/pollen_data_acquisition_server). + +## Step 1: Recording + +### Get Reachy 2 IP address + +Before starting teleoperation and data recording, find the [robot's IP address](https://docs.pollen-robotics.com/getting-started/setup-reachy2/connect-reachy2/). +We strongly recommend connecting all devices (PC and robot) via **Ethernet**. + +### Launch recording + +There are two ways to manage recording sessions when using the Reachy 2 VR teleoperation application: + +- **Using the data acquisition server (recommended for VR teleop)**: The VR app orchestrates sessions (via the server it tells LeRobot when to create datasets, start/stop episodes) while also controlling the robot’s motions. +- **Using LeRobot’s record script**: LeRobot owns session control and decides when to start/stop episodes. If you also use the VR teleop app, it’s only for motion control. + +### Option 1: Using Pollen data acquisition server (recommended for VR teleop) + +Make sure you have installed pollen_data_acquisition_server, as explained in the Setup section. + +Launch the data acquisition server to be able to manage your session directly from the teleoperation application: + +```bash +python -m pollen_data_acquisition_server.server +``` + +Then get into the teleoperation application and choose "Data acquisition session". +You can finally setup your session by following the screens displayed. + +> Even without the VR app, you can use the `pollen_data_acquisition_server` with your own client implementation. + +### Option 2: Using lerobot.record + +Reachy 2 is fully supported by LeRobot’s recording features. +If you choose this option but still want to use the VR teleoperation application, select "Standard session" in the app. + +**Example: start a recording without the mobile base:** +First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command: + +```bash +python -m lerobot.record \ + --robot.type=reachy2 \ + --robot.ip_address=192.168.0.200 \ + --robot.id=r2-0000 \ + --robot.use_external_commands=true \ + --robot.with_mobile_base=false \ + --teleop.type=reachy2_teleoperator \ + --teleop.ip_address=192.168.0.200 \ + --teleop.with_mobile_base=false \ + --dataset.repo_id=pollen_robotics/record_test \ + --dataset.single_task="Reachy 2 recording test" \ + --dataset.num_episodes=1 \ + --dataset.episode_time_s=5 \ + --dataset.fps=15 \ + --dataset.push_to_hub=true \ + --dataset.private=true \ + --display_data=true +``` + +#### Specific Options + +**Extended setup overview (all options included):** + +```bash +python -m lerobot.record \ + --robot.type=reachy2 \ + --robot.ip_address=192.168.0.200 \ + --robot.use_external_commands=true \ + --robot.with_mobile_base=true \ + --robot.with_l_arm=true \ + --robot.with_r_arm=true \ + --robot.with_neck=true \ + --robot.with_antennas=true \ + --robot.with_left_teleop_camera=true \ + --robot.with_right_teleop_camera=true \ + --robot.with_torso_camera=false \ + --robot.disable_torque_on_disconnect=false \ + --robot.max_relative_target=5.0 \ + --teleop.type=reachy2_teleoperator \ + --teleop.ip_address=192.168.0.200 \ + --teleop.use_present_position=false \ + --teleop.with_mobile_base=false \ + --teleop.with_l_arm=true \ + --teleop.with_r_arm=true \ + --teleop.with_neck=true \ + --teleop.with_antennas=true \ + --dataset.repo_id=pollen_robotics/record_test \ + --dataset.single_task="Reachy 2 recording test" \ + --dataset.num_episodes=1 \ + --dataset.episode_time_s=5 \ + --dataset.fps=15 \ + --dataset.push_to_hub=true \ + --dataset.private=true \ + --display_data=true +``` + +##### `--robot.use_external_commands` + +Determine whether LeRobot robot.send_action() sends commands to the robot. +**Must** be set to false while using the VR teleoperation application, as the app already sends commands. + +##### `--teleop.use_present_position` + +Determine whether the teleoperator reads the goal or present position of the robot. +Must be set to true if a compliant Reachy 2 is used to control another one. + +##### Use the relevant parts + +From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies. +To avoid this, you can exclude specific parts from recording and replay using: + +```` +--robot.with_=false +```, +with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. +It determine whether the corresponding part is recorded in the observations. True if not set. + +By default, **all parts are recorded**. + +The same per-part mechanism is available in `reachy2_teleoperator` as well. + +```` + +--teleop.with\_ + +``` +with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. +Determine whether the corresponding part is recorded in the actions. True if not set. + +> **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator. +For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`. + +##### Use the relevant cameras + +You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with: + +``` + +--robot.with_left_teleop_camera= +--robot.with_right_teleop_camera= +--robot.with_torso_camera= + +```` + + +## Step 2: Replay + +Make sure the robot is configured with the same parts as the dataset: + +```bash +python -m lerobot.replay \ + --robot.type=reachy2 \ + --robot.ip_address=192.168.0.200 \ + --robot.use_external_commands=false \ + --robot.with_mobile_base=false \ + --dataset.repo_id=pollen_robotics/record_test \ + --dataset.episode=0 + --display_data=true +```` + +## Step 3: Train + +```bash +python -m lerobot.scripts.train \ + --dataset.repo_id=pollen_robotics/record_test \ + --policy.type=act \ + --output_dir=outputs/train/reachy2_test \ + --job_name=reachy2 \ + --policy.device=mps \ + --wandb.enable=true \ + --policy.repo_id=pollen_robotics/record_test_policy +``` + +## Step 4: Evaluate + +```bash +python -m lerobot.record \ + --robot.type=reachy2 \ + --robot.ip_address=192.168.0.200 \ + --display_data=false \ + --dataset.repo_id=pollen_robotics/eval_record_test \ + --dataset.single_task="Evaluate reachy2 policy" \ + --dataset.num_episodes=10 \ + --policy.path=outputs/train/reachy2_test/checkpoints/last/pretrained_model +``` diff --git a/pyproject.toml b/pyproject.toml index 2bc57c076..50cd207e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31"] gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"] +reachy2 = ["reachy2_sdk>=1.0.14"] kinematics = ["lerobot[placo-dep]"] intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", @@ -141,6 +142,7 @@ all = [ "lerobot[gamepad]", "lerobot[hopejr]", "lerobot[lekiwi]", + "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", "lerobot[pi0]", diff --git a/src/lerobot/cameras/reachy2_camera/__init__.py b/src/lerobot/cameras/reachy2_camera/__init__.py new file mode 100644 index 000000000..72e45f32a --- /dev/null +++ b/src/lerobot/cameras/reachy2_camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_reachy2_camera import Reachy2CameraConfig +from .reachy2_camera import Reachy2Camera diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py new file mode 100644 index 000000000..5b2303ff2 --- /dev/null +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -0,0 +1,78 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..configs import CameraConfig, ColorMode + + +@CameraConfig.register_subclass("reachy2_camera") +@dataclass +class Reachy2CameraConfig(CameraConfig): + """Configuration class for Reachy 2 camera devices. + + This class provides configuration options for Reachy 2 cameras, + supporting both the teleop and depth cameras. It includes settings + for resolution, frame rate, color mode, and the selection of the cameras. + + Example configurations: + ```python + # Basic configurations + Reachy2CameraConfig( + name="teleop", + image_type="left", + ip_address="192.168.0.200", # IP address of the robot + fps=15, + width=640, + height=480, + color_mode=ColorMode.RGB, + ) # Left teleop camera, 640x480 @ 15FPS + ``` + + Attributes: + name: Name of the camera device. Can be "teleop" or "depth". + image_type: Type of image stream. For "teleop" camera, can be "left" or "right". + For "depth" camera, can be "rgb" or "depth". (depth is not supported yet) + fps: Requested frames per second for the color stream. + width: Requested frame width in pixels for the color stream. + height: Requested frame height in pixels for the color stream. + color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. + ip_address: IP address of the robot. Defaults to "localhost". + port: Port number for the camera server. Defaults to 50065. + + Note: + - Only 3-channel color output (RGB/BGR) is currently supported. + """ + + name: str + image_type: str + color_mode: ColorMode = ColorMode.RGB + ip_address: str | None = "localhost" + port: int = 50065 + # use_depth: bool = False + + def __post_init__(self): + if self.name not in ["teleop", "depth"]: + raise ValueError(f"`name` is expected to be 'teleop' or 'depth', but {self.name} is provided.") + if (self.name == "teleop" and self.image_type not in ["left", "right"]) or ( + self.name == "depth" and self.image_type not in ["rgb", "depth"] + ): + raise ValueError( + f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided." + ) + + if self.color_mode not in ["rgb", "bgr"]: + raise ValueError( + f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." + ) diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py new file mode 100644 index 000000000..0daeb6bbb --- /dev/null +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -0,0 +1,288 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager. +""" + +import logging +import os +import platform +import time +from threading import Event, Lock, Thread +from typing import Any + +# Fix MSMF hardware transform compatibility for Windows before importing cv2 +if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" not in os.environ: + os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" +import cv2 +import numpy as np +from reachy2_sdk.media.camera import CameraView +from reachy2_sdk.media.camera_manager import CameraManager + +from lerobot.errors import DeviceNotConnectedError + +from ..camera import Camera +from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig + +logger = logging.getLogger(__name__) + + +class Reachy2Camera(Camera): + """ + Manages Reachy 2 camera using Reachy 2 CameraManager. + + This class provides a high-level interface to connect to, configure, and read + frames from Reachy 2 cameras. It supports both synchronous and asynchronous + frame reading. + + An Reachy2Camera instance requires a camera name (e.g., "teleop") and an image + type (e.g., "left") to be specified in the configuration. + + The camera's default settings (FPS, resolution, color mode) are used unless + overridden in the configuration. + """ + + def __init__(self, config: Reachy2CameraConfig): + """ + Initializes the Reachy2Camera instance. + + Args: + config: The configuration settings for the camera. + """ + super().__init__(config) + + self.config = config + + self.fps = config.fps + self.color_mode = config.color_mode + + self.cam_manager: CameraManager | None = None + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: np.ndarray | None = None + self.new_frame_event: Event = Event() + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})" + + @property + def is_connected(self) -> bool: + """Checks if the camera is currently connected and opened.""" + if self.config.name == "teleop": + return self.cam_manager._grpc_connected and self.cam_manager.teleop if self.cam_manager else False + elif self.config.name == "depth": + return self.cam_manager._grpc_connected and self.cam_manager.depth if self.cam_manager else False + else: + raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.") + + def connect(self, warmup: bool = True): + """ + Connects to the Reachy2 CameraManager as specified in the configuration. + """ + self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port) + self.cam_manager.initialize_cameras() + + logger.info(f"{self} connected.") + + @staticmethod + def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]: + """ + Detects available Reachy 2 cameras. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'name', 'stereo', + and the default profile properties (width, height, fps). + """ + initialized_cameras = [] + camera_manager = CameraManager(host=ip_address, port=port) + + for camera in [camera_manager.teleop, camera_manager.depth]: + if camera is None: + continue + + height, width, _, _, _, _, _ = camera.get_parameters() + + camera_info = { + "name": camera._cam_info.name, + "stereo": camera._cam_info.stereo, + "default_profile": { + "width": width, + "height": height, + "fps": 30, + }, + } + initialized_cameras.append(camera_info) + + camera_manager.disconnect() + return initialized_cameras + + def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Reads a single frame synchronously from the camera. + + This is a blocking call. + + Args: + color_mode (Optional[ColorMode]): If specified, overrides the default + color mode (`self.color_mode`) for this read operation (e.g., + request RGB even if default is BGR). + + Returns: + np.ndarray: The captured frame as a NumPy array in the format + (height, width, channels), using the specified or default + color mode and applying any configured rotation. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start_time = time.perf_counter() + + frame = None + + if self.cam_manager is None: + raise DeviceNotConnectedError(f"{self} is not connected.") + else: + if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): + if self.config.image_type == "left": + frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0] + elif self.config.image_type == "right": + frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0] + elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"): + if self.config.image_type == "depth": + frame = self.cam_manager.depth.get_depth_frame()[0] + elif self.config.image_type == "rgb": + frame = self.cam_manager.depth.get_frame(size=(640, 480))[0] + + if frame is None: + return np.empty((0, 0, 3), dtype=np.uint8) + + if self.config.color_mode == "rgb": + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return frame + + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + On each iteration: + 1. Reads a color frame + 2. Stores result in latest_frame (thread-safe) + 3. Sets new_frame_event to notify listeners + + Stops on DeviceNotConnectedError, logs other errors and continues. + """ + while not self.stop_event.is_set(): + try: + color_image = self.read() + + with self.frame_lock: + self.latest_frame = color_image + self.new_frame_event.set() + + except DeviceNotConnectedError: + break + except Exception as e: + logger.warning(f"Error reading frame in background thread for {self}: {e}") + + def _start_read_thread(self) -> None: + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") + self.thread.daemon = True + self.thread.start() + + def _stop_read_thread(self) -> None: + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + def async_read(self, timeout_ms: float = 200) -> np.ndarray: + """ + Reads the latest available frame asynchronously. + + This method retrieves the most recent frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + but may wait up to timeout_ms for the background thread to provide a frame. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms (0.2 seconds). + + Returns: + np.ndarray: The latest captured frame as a NumPy array in the format + (height, width, channels), processed according to configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame becomes available within the specified timeout. + RuntimeError: If an unexpected error occurs. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + thread_alive = self.thread is not None and self.thread.is_alive() + raise TimeoutError( + f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") + + return frame + + def disconnect(self): + """ + Stops the background read thread (if running). + + Raises: + DeviceNotConnectedError: If the camera is already disconnected. + """ + if not self.is_connected and self.thread is None: + raise DeviceNotConnectedError(f"{self} not connected.") + + if self.thread is not None: + self._stop_read_thread() + + if self.cam_manager is not None: + self.cam_manager.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index 1eb69840b..dfac33e17 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -37,8 +37,14 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s from .realsense.camera_realsense import RealSenseCamera cameras[key] = RealSenseCamera(cfg) + + elif cfg.type == "reachy2_camera": + from .reachy2_camera.reachy2_camera import Reachy2Camera + + cameras[key] = Reachy2Camera(cfg) + else: - raise ValueError(f"The motor type '{cfg.type}' is not valid.") + raise ValueError(f"The camera type '{cfg.type}' is not valid.") return cameras diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 09fa33fe3..8eee59558 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -67,7 +67,6 @@ from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 -from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer @@ -209,7 +208,14 @@ def record_loop( ( t for t in teleop - if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader)) + if isinstance( + t, + ( + so100_leader.SO100Leader, + so101_leader.SO101Leader, + koch_leader.KochLeader, + ), + ) ), None, ) diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index 2b62fd67f..603aa93ea 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -55,6 +55,7 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, + reachy2, so100_follower, so101_follower, ) diff --git a/src/lerobot/robots/reachy2/__init__.py b/src/lerobot/robots/reachy2/__init__.py new file mode 100644 index 000000000..1a38fd03b --- /dev/null +++ b/src/lerobot/robots/reachy2/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_reachy2 import Reachy2RobotConfig +from .robot_reachy2 import ( + REACHY2_ANTENNAS_JOINTS, + REACHY2_L_ARM_JOINTS, + REACHY2_NECK_JOINTS, + REACHY2_R_ARM_JOINTS, + REACHY2_VEL, + Reachy2Robot, +) diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py new file mode 100644 index 000000000..aa25351c6 --- /dev/null +++ b/src/lerobot/robots/reachy2/configuration_reachy2.py @@ -0,0 +1,107 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig +from lerobot.cameras.configs import ColorMode +from lerobot.cameras.reachy2_camera import Reachy2CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("reachy2") +@dataclass +class Reachy2RobotConfig(RobotConfig): + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors. + max_relative_target: float | None = None + + # IP address of the Reachy 2 robot + ip_address: str | None = "localhost" + + # If True, turn_off_smoothly() will be sent to the robot before disconnecting. + disable_torque_on_disconnect: bool = False + + # Tag for external commands control + # Set to True if you use an external commands system to control the robot, + # such as the official teleoperation application: https://github.com/pollen-robotics/Reachy2Teleoperation + # If True, robot.send_action() will not send commands to the robot. + use_external_commands: bool = False + + # Robot parts + # Set to False to not add the corresponding joints part to the robot list of joints. + # By default, all parts are set to True. + with_mobile_base: bool = True + with_l_arm: bool = True + with_r_arm: bool = True + with_neck: bool = True + with_antennas: bool = True + + # Robot cameras + # Set to True if you want to use the corresponding cameras in the observations. + # By default, only the teleop cameras are used. + with_left_teleop_camera: bool = True + with_right_teleop_camera: bool = True + with_torso_camera: bool = False + + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + def __post_init__(self) -> None: + # Add cameras with same ip_address as the robot + if self.with_left_teleop_camera: + self.cameras["teleop_left"] = Reachy2CameraConfig( + name="teleop", + image_type="left", + ip_address=self.ip_address, + fps=15, + width=640, + height=480, + color_mode=ColorMode.RGB, + ) + if self.with_right_teleop_camera: + self.cameras["teleop_right"] = Reachy2CameraConfig( + name="teleop", + image_type="right", + ip_address=self.ip_address, + fps=15, + width=640, + height=480, + color_mode=ColorMode.RGB, + ) + if self.with_torso_camera: + self.cameras["torso_rgb"] = Reachy2CameraConfig( + name="depth", + image_type="rgb", + ip_address=self.ip_address, + fps=15, + width=640, + height=480, + color_mode=ColorMode.RGB, + ) + + super().__post_init__() + + if not ( + self.with_mobile_base + or self.with_l_arm + or self.with_r_arm + or self.with_neck + or self.with_antennas + ): + raise ValueError( + "No Reachy2Robot part used.\n" + "At least one part of the robot must be set to True " + "(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)" + ) diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py new file mode 100644 index 000000000..ecc488a79 --- /dev/null +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Any + +import numpy as np +from reachy2_sdk import ReachySDK + +from lerobot.cameras.utils import make_cameras_from_configs + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .configuration_reachy2 import Reachy2RobotConfig + +# {lerobot_keys: reachy2_sdk_keys} +REACHY2_NECK_JOINTS = { + "neck_yaw.pos": "head.neck.yaw", + "neck_pitch.pos": "head.neck.pitch", + "neck_roll.pos": "head.neck.roll", +} + +REACHY2_ANTENNAS_JOINTS = { + "l_antenna.pos": "head.l_antenna", + "r_antenna.pos": "head.r_antenna", +} + +REACHY2_R_ARM_JOINTS = { + "r_shoulder_pitch.pos": "r_arm.shoulder.pitch", + "r_shoulder_roll.pos": "r_arm.shoulder.roll", + "r_elbow_yaw.pos": "r_arm.elbow.yaw", + "r_elbow_pitch.pos": "r_arm.elbow.pitch", + "r_wrist_roll.pos": "r_arm.wrist.roll", + "r_wrist_pitch.pos": "r_arm.wrist.pitch", + "r_wrist_yaw.pos": "r_arm.wrist.yaw", + "r_gripper.pos": "r_arm.gripper", +} + +REACHY2_L_ARM_JOINTS = { + "l_shoulder_pitch.pos": "l_arm.shoulder.pitch", + "l_shoulder_roll.pos": "l_arm.shoulder.roll", + "l_elbow_yaw.pos": "l_arm.elbow.yaw", + "l_elbow_pitch.pos": "l_arm.elbow.pitch", + "l_wrist_roll.pos": "l_arm.wrist.roll", + "l_wrist_pitch.pos": "l_arm.wrist.pitch", + "l_wrist_yaw.pos": "l_arm.wrist.yaw", + "l_gripper.pos": "l_arm.gripper", +} + +REACHY2_VEL = { + "mobile_base.vx": "vx", + "mobile_base.vy": "vy", + "mobile_base.vtheta": "vtheta", +} + + +class Reachy2Robot(Robot): + """ + [Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics. + """ + + config_class = Reachy2RobotConfig + name = "reachy2" + + def __init__(self, config: Reachy2RobotConfig): + super().__init__(config) + + self.config = config + self.robot_type = self.config.type + self.use_external_commands = self.config.use_external_commands + + self.reachy: None | ReachySDK = None + self.cameras = make_cameras_from_configs(config.cameras) + + self.logs: dict[str, float] = {} + + self.joints_dict: dict[str, str] = self._generate_joints_dict() + + @property + def observation_features(self) -> dict[str, Any]: + return {**self.motors_features, **self.camera_features} + + @property + def action_features(self) -> dict[str, type]: + return self.motors_features + + @property + def camera_features(self) -> dict[str, tuple[int | None, int | None, int]]: + return {cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras} + + @property + def motors_features(self) -> dict[str, type]: + if self.config.with_mobile_base: + return { + **dict.fromkeys( + self.joints_dict.keys(), + float, + ), + **dict.fromkeys( + REACHY2_VEL.keys(), + float, + ), + } + else: + return dict.fromkeys(self.joints_dict.keys(), float) + + @property + def is_connected(self) -> bool: + return self.reachy.is_connected() if self.reachy is not None else False + + def connect(self, calibrate: bool = False) -> None: + self.reachy = ReachySDK(self.config.ip_address) + if not self.is_connected: + raise ConnectionError() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + + def configure(self) -> None: + if self.reachy is not None: + self.reachy.turn_on() + self.reachy.reset_default_limits() + + @property + def is_calibrated(self) -> bool: + return True + + def calibrate(self) -> None: + pass + + def _generate_joints_dict(self) -> dict[str, str]: + joints = {} + if self.config.with_neck: + joints.update(REACHY2_NECK_JOINTS) + if self.config.with_l_arm: + joints.update(REACHY2_L_ARM_JOINTS) + if self.config.with_r_arm: + joints.update(REACHY2_R_ARM_JOINTS) + if self.config.with_antennas: + joints.update(REACHY2_ANTENNAS_JOINTS) + return joints + + def _get_state(self) -> dict[str, float]: + if self.reachy is not None: + pos_dict = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()} + if not self.config.with_mobile_base: + return pos_dict + vel_dict = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} + return {**pos_dict, **vel_dict} + else: + return {} + + def get_observation(self) -> dict[str, np.ndarray]: + obs_dict: dict[str, Any] = {} + + # Read Reachy 2 state + before_read_t = time.perf_counter() + obs_dict.update(self._get_state()) + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + obs_dict[cam_key] = cam.async_read() + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + if self.reachy is not None: + if not self.is_connected: + raise ConnectionError() + + before_write_t = time.perf_counter() + + vel = {} + goal_pos = {} + for key, val in action.items(): + if key not in self.joints_dict: + if key not in REACHY2_VEL: + raise KeyError(f"Key '{key}' is not a valid motor key in Reachy 2.") + else: + vel[REACHY2_VEL[key]] = float(val) + else: + if not self.use_external_commands and self.config.max_relative_target is not None: + goal_pos[key] = float(val) + goal_present_pos = { + key: ( + goal_pos[key], + self.reachy.joints[self.joints_dict[key]].present_position, + ) + } + safe_goal_pos = ensure_safe_goal_position( + goal_present_pos, float(self.config.max_relative_target) + ) + val = safe_goal_pos[key] + self.reachy.joints[self.joints_dict[key]].goal_position = float(val) + + if self.config.with_mobile_base: + self.reachy.mobile_base.set_goal_speed(vel["vx"], vel["vy"], vel["vtheta"]) + + # We don't send the goal positions if we control Reachy 2 externally + if not self.use_external_commands: + self.reachy.send_goal_positions() + if self.config.with_mobile_base: + self.reachy.mobile_base.send_speed_command() + + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + return action + + def disconnect(self) -> None: + if self.reachy is not None: + for cam in self.cameras.values(): + cam.disconnect() + if self.config.disable_torque_on_disconnect: + self.reachy.turn_off_smoothly() + self.reachy.disconnect() diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index befd96424..261e59a32 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -61,6 +61,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .bi_so100_follower import BiSO100Follower return BiSO100Follower(config) + elif config.type == "reachy2": + from .reachy2 import Reachy2Robot + + return Reachy2Robot(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py new file mode 100644 index 000000000..a07a4a6cd --- /dev/null +++ b/src/lerobot/teleoperators/reachy2_teleoperator/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig +from .reachy2_teleoperator import ( + REACHY2_ANTENNAS_JOINTS, + REACHY2_L_ARM_JOINTS, + REACHY2_NECK_JOINTS, + REACHY2_R_ARM_JOINTS, + REACHY2_VEL, + Reachy2Teleoperator, +) diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py new file mode 100644 index 000000000..4e615d363 --- /dev/null +++ b/src/lerobot/teleoperators/reachy2_teleoperator/config_reachy2_teleoperator.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("reachy2_teleoperator") +@dataclass +class Reachy2TeleoperatorConfig(TeleoperatorConfig): + # IP address of the Reachy 2 robot used as teleoperator + ip_address: str | None = "localhost" + + # Whether to use the present position of the joints as actions + # if False, the goal position of the joints will be used + use_present_position: bool = False + + # Which parts of the robot to use + with_mobile_base: bool = True + with_l_arm: bool = True + with_r_arm: bool = True + with_neck: bool = True + with_antennas: bool = True + + def __post_init__(self): + if not ( + self.with_mobile_base + or self.with_l_arm + or self.with_r_arm + or self.with_neck + or self.with_antennas + ): + raise ValueError( + "No Reachy2Teleoperator part used.\n" + "At least one part of the robot must be set to True " + "(with_mobile_base, with_l_arm, with_r_arm, with_neck, with_antennas)" + ) diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py new file mode 100644 index 000000000..5a427dd71 --- /dev/null +++ b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from reachy2_sdk import ReachySDK + +from ..teleoperator import Teleoperator +from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig + +logger = logging.getLogger(__name__) + +# {lerobot_keys: reachy2_sdk_keys} +REACHY2_NECK_JOINTS = { + "neck_yaw.pos": "head.neck.yaw", + "neck_pitch.pos": "head.neck.pitch", + "neck_roll.pos": "head.neck.roll", +} + +REACHY2_ANTENNAS_JOINTS = { + "l_antenna.pos": "head.l_antenna", + "r_antenna.pos": "head.r_antenna", +} + +REACHY2_R_ARM_JOINTS = { + "r_shoulder_pitch.pos": "r_arm.shoulder.pitch", + "r_shoulder_roll.pos": "r_arm.shoulder.roll", + "r_elbow_yaw.pos": "r_arm.elbow.yaw", + "r_elbow_pitch.pos": "r_arm.elbow.pitch", + "r_wrist_roll.pos": "r_arm.wrist.roll", + "r_wrist_pitch.pos": "r_arm.wrist.pitch", + "r_wrist_yaw.pos": "r_arm.wrist.yaw", + "r_gripper.pos": "r_arm.gripper", +} + +REACHY2_L_ARM_JOINTS = { + "l_shoulder_pitch.pos": "l_arm.shoulder.pitch", + "l_shoulder_roll.pos": "l_arm.shoulder.roll", + "l_elbow_yaw.pos": "l_arm.elbow.yaw", + "l_elbow_pitch.pos": "l_arm.elbow.pitch", + "l_wrist_roll.pos": "l_arm.wrist.roll", + "l_wrist_pitch.pos": "l_arm.wrist.pitch", + "l_wrist_yaw.pos": "l_arm.wrist.yaw", + "l_gripper.pos": "l_arm.gripper", +} + +REACHY2_VEL = { + "mobile_base.vx": "vx", + "mobile_base.vy": "vy", + "mobile_base.vtheta": "vtheta", +} + + +class Reachy2Teleoperator(Teleoperator): + """ + [Reachy 2](https://www.pollen-robotics.com/reachy/), by Pollen Robotics. + """ + + config_class = Reachy2TeleoperatorConfig + name = "reachy2_specific" + + def __init__(self, config: Reachy2TeleoperatorConfig): + super().__init__(config) + self.config = config + self.reachy: None | ReachySDK = None + + self.joints_dict: dict[str, str] = self._generate_joints_dict() + + def _generate_joints_dict(self) -> dict[str, str]: + joints = {} + if self.config.with_neck: + joints.update(REACHY2_NECK_JOINTS) + if self.config.with_l_arm: + joints.update(REACHY2_L_ARM_JOINTS) + if self.config.with_r_arm: + joints.update(REACHY2_R_ARM_JOINTS) + if self.config.with_antennas: + joints.update(REACHY2_ANTENNAS_JOINTS) + return joints + + @property + def action_features(self) -> dict[str, type]: + if self.config.with_mobile_base: + return { + **dict.fromkeys( + self.joints_dict.keys(), + float, + ), + **dict.fromkeys( + REACHY2_VEL.keys(), + float, + ), + } + else: + return dict.fromkeys(self.joints_dict.keys(), float) + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.reachy.is_connected() if self.reachy is not None else False + + def connect(self, calibrate: bool = True) -> None: + self.reachy = ReachySDK(self.config.ip_address) + if not self.is_connected: + raise ConnectionError() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return True + + def calibrate(self) -> None: + pass + + def configure(self) -> None: + pass + + def get_action(self) -> dict[str, float]: + start = time.perf_counter() + + if self.reachy and self.is_connected: + if self.config.use_present_position: + joint_action = { + k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items() + } + else: + joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} + + if not self.config.with_mobile_base: + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return joint_action + + if self.config.use_present_position: + vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} + else: + vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return {**joint_action, **vel_action} + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError + + def disconnect(self) -> None: + if self.reachy and self.is_connected: + self.reachy.disconnect() diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 344a95d72..02e6fd22c 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -65,5 +65,9 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .bi_so100_leader import BiSO100Leader return BiSO100Leader(config) + elif config.type == "reachy2_teleoperator": + from .reachy2_teleoperator import Reachy2Teleoperator + + return Reachy2Teleoperator(config) else: raise ValueError(config.type) diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py new file mode 100644 index 000000000..66c7675a6 --- /dev/null +++ b/tests/cameras/test_reachy2_camera.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig +from lerobot.errors import DeviceNotConnectedError + +PARAMS = [ + ("teleop", "left"), + ("teleop", "right"), + ("depth", "rgb"), + # ("depth", "depth"), # Depth camera is not available yet +] + + +def _make_cam_manager_mock(): + c = MagicMock(name="CameraManagerMock") + + teleop = MagicMock(name="TeleopCam") + teleop.width = 640 + teleop.height = 480 + teleop.get_frame = MagicMock( + side_effect=lambda *_, **__: ( + np.zeros((480, 640, 3), dtype=np.uint8), + time.time(), + ) + ) + + depth = MagicMock(name="DepthCam") + depth.width = 640 + depth.height = 480 + depth.get_frame = MagicMock( + side_effect=lambda *_, **__: ( + np.zeros((480, 640, 3), dtype=np.uint8), + time.time(), + ) + ) + + c.is_connected.return_value = True + c.teleop = teleop + c.depth = depth + + def _connect(): + c.teleop = teleop + c.depth = depth + c.is_connected.return_value = True + + def _disconnect(): + c.teleop = None + c.depth = None + c.is_connected.return_value = False + + c.connect = MagicMock(side_effect=_connect) + c.disconnect = MagicMock(side_effect=_disconnect) + + # Mock methods + c.initialize_cameras = MagicMock() + + return c + + +@pytest.fixture( + params=PARAMS, + # ids=["teleop-left", "teleop-right", "torso-rgb", "torso-depth"], + ids=["teleop-left", "teleop-right", "torso-rgb"], +) +def camera(request): + name, image_type = request.param + with ( + patch( + "lerobot.cameras.reachy2_camera.reachy2_camera.CameraManager", + side_effect=lambda *a, **k: _make_cam_manager_mock(), + ), + ): + config = Reachy2CameraConfig(name=name, image_type=image_type) + cam = Reachy2Camera(config) + yield cam + if cam.is_connected: + cam.disconnect() + + +def test_connect(camera): + camera.connect() + assert camera.is_connected + camera.cam_manager.initialize_cameras.assert_called_once() + + +def test_read(camera): + camera.connect() + + img = camera.read() + if camera.config.name == "teleop": + camera.cam_manager.teleop.get_frame.assert_called_once() + elif camera.config.name == "depth": + camera.cam_manager.depth.get_frame.assert_called_once() + assert isinstance(img, np.ndarray) + assert img.shape == (480, 640, 3) + + +def test_disconnect(camera): + camera.connect() + + camera.disconnect() + assert not camera.is_connected + + +def test_async_read(camera): + camera.connect() + try: + img = camera.async_read() + + assert camera.thread is not None + assert camera.thread.is_alive() + assert isinstance(img, np.ndarray) + finally: + if camera.is_connected: + camera.disconnect() + + +def test_async_read_timeout(camera): + camera.connect() + try: + with pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) + finally: + if camera.is_connected: + camera.disconnect() + + +def test_read_before_connect(camera): + with pytest.raises(DeviceNotConnectedError): + _ = camera.read() + + +def test_disconnect_before_connect(camera): + with pytest.raises(DeviceNotConnectedError): + camera.disconnect() + + +def test_async_read_before_connect(camera): + with pytest.raises(DeviceNotConnectedError): + _ = camera.async_read() + + +def test_wrong_camera_name(): + with pytest.raises(ValueError): + _ = Reachy2CameraConfig(name="wrong-name", image_type="left") + + +def test_wrong_image_type(): + with pytest.raises(ValueError): + _ = Reachy2CameraConfig(name="teleop", image_type="rgb") + with pytest.raises(ValueError): + _ = Reachy2CameraConfig(name="depth", image_type="left") + + +def test_wrong_color_mode(): + with pytest.raises(ValueError): + _ = Reachy2CameraConfig(name="teleop", image_type="left", color_mode="wrong-color") diff --git a/tests/conftest.py b/tests/conftest.py index 7940cc5ba..e273da50f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ pytest_plugins = [ "tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.optimizers", + "tests.plugins.reachy2_sdk", ] diff --git a/tests/plugins/reachy2_sdk.py b/tests/plugins/reachy2_sdk.py new file mode 100644 index 000000000..f56b59efb --- /dev/null +++ b/tests/plugins/reachy2_sdk.py @@ -0,0 +1,30 @@ +import sys +import types +from unittest.mock import MagicMock + + +def _install_reachy2_sdk_stub(): + sdk = types.ModuleType("reachy2_sdk") + sdk.__path__ = [] + sdk.ReachySDK = MagicMock(name="ReachySDK") + + media = types.ModuleType("reachy2_sdk.media") + media.__path__ = [] + camera = types.ModuleType("reachy2_sdk.media.camera") + camera.CameraView = MagicMock(name="CameraView") + camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager") + camera_manager.CameraManager = MagicMock(name="CameraManager") + + sdk.media = media + media.camera = camera + media.camera_manager = camera_manager + + # Register in sys.modules + sys.modules.setdefault("reachy2_sdk", sdk) + sys.modules.setdefault("reachy2_sdk.media", media) + sys.modules.setdefault("reachy2_sdk.media.camera", camera) + sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager) + + +def pytest_sessionstart(session): + _install_reachy2_sdk_stub() diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py new file mode 100644 index 000000000..c93fbeced --- /dev/null +++ b/tests/robots/test_reachy2.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from lerobot.robots.reachy2 import ( + REACHY2_ANTENNAS_JOINTS, + REACHY2_L_ARM_JOINTS, + REACHY2_NECK_JOINTS, + REACHY2_R_ARM_JOINTS, + REACHY2_VEL, + Reachy2Robot, + Reachy2RobotConfig, +) + +# {lerobot_keys: reachy2_sdk_keys} +REACHY2_JOINTS = { + **REACHY2_NECK_JOINTS, + **REACHY2_ANTENNAS_JOINTS, + **REACHY2_R_ARM_JOINTS, + **REACHY2_L_ARM_JOINTS, +} + +PARAMS = [ + {}, # default config + {"with_mobile_base": False}, + {"with_mobile_base": False, "with_l_arm": False, "with_antennas": False}, + {"with_r_arm": False, "with_neck": False, "with_antennas": False}, + {"use_external_commands": True, "disable_torque_on_disconnect": True}, + {"use_external_commands": True, "with_mobile_base": False, "with_neck": False}, + {"disable_torque_on_disconnect": False}, + {"max_relative_target": 5}, + {"with_right_teleop_camera": False}, + {"with_left_teleop_camera": False, "with_right_teleop_camera": False}, + {"with_left_teleop_camera": False, "with_torso_camera": True}, +] + + +def _make_reachy2_sdk_mock(): + class JointSpy: + __slots__ = ( + "present_position", + "_goal_position", + "_on_set", + ) + + def __init__(self, present_position=0.0, on_set=None): + self.present_position = present_position + self._goal_position = present_position + self._on_set = on_set + + @property + def goal_position(self): + return self._goal_position + + @goal_position.setter + def goal_position(self, v): + self._goal_position = v + if self._on_set: + self._on_set() + + r = MagicMock(name="ReachySDKMock") + r.is_connected.return_value = True + + def _connect(): + r.is_connected.return_value = True + + def _disconnect(): + r.is_connected.return_value = False + + # Global counter of goal_position sets + r._goal_position_set_total = 0 + + def _on_any_goal_set(): + r._goal_position_set_total += 1 + + # Mock joints with some dummy positions + joints = { + k: JointSpy( + present_position=float(i), + on_set=_on_any_goal_set, + ) + for i, k in enumerate(REACHY2_JOINTS.values()) + } + r.joints = joints + + # Mock mobile base with some dummy odometry + r.mobile_base = MagicMock() + r.mobile_base.odometry = { + "x": 0.1, + "y": -0.2, + "theta": 21.3, + "vx": 0.001, + "vy": 0.002, + "vtheta": 0.0, + } + + r.connect = MagicMock(side_effect=_connect) + r.disconnect = MagicMock(side_effect=_disconnect) + + # Mock methods + r.turn_on = MagicMock() + r.reset_default_limits = MagicMock() + r.send_goal_positions = MagicMock() + r.turn_off_smoothly = MagicMock() + r.mobile_base.set_goal_speed = MagicMock() + r.mobile_base.send_speed_command = MagicMock() + + return r + + +def _make_reachy2_camera_mock(*args, **kwargs): + cfg = args[0] if args else kwargs.get("config") + name = getattr(cfg, "name", kwargs.get("name", "cam")) + image_type = getattr(cfg, "image_type", kwargs.get("image_type", "cam")) + width = getattr(cfg, "width", kwargs.get("width", 640)) + height = getattr(cfg, "height", kwargs.get("height", 480)) + + cam = MagicMock(name=f"Reachy2CameraMock:{name}") + cam.name = name + cam.image_type = image_type + cam.width = width + cam.height = height + cam.connect = MagicMock() + cam.disconnect = MagicMock() + cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) + return cam + + +@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys())) +def reachy2(request): + with ( + patch( + "lerobot.robots.reachy2.robot_reachy2.ReachySDK", + side_effect=lambda *a, **k: _make_reachy2_sdk_mock(), + ), + patch( + "lerobot.cameras.reachy2_camera.reachy2_camera.Reachy2Camera", + side_effect=_make_reachy2_camera_mock, + ), + ): + overrides = request.param + cfg = Reachy2RobotConfig(ip_address="192.168.0.200", **overrides) + robot = Reachy2Robot(cfg) + yield robot + if robot.is_connected: + robot.disconnect() + + +def test_connect_disconnect(reachy2): + assert not reachy2.is_connected + + reachy2.connect() + assert reachy2.is_connected + + reachy2.reachy.turn_on.assert_called_once() + reachy2.reachy.reset_default_limits.assert_called_once() + + reachy2.disconnect() + assert not reachy2.is_connected + + if reachy2.config.disable_torque_on_disconnect: + reachy2.reachy.turn_off_smoothly.assert_called_once() + else: + reachy2.reachy.turn_off_smoothly.assert_not_called() + reachy2.reachy.disconnect.assert_called_once() + + +def test_get_joints_dict(reachy2): + reachy2.connect() + + if reachy2.config.with_neck: + assert "neck_yaw.pos" in reachy2.joints_dict + assert "neck_pitch.pos" in reachy2.joints_dict + assert "neck_roll.pos" in reachy2.joints_dict + else: + assert "neck_yaw.pos" not in reachy2.joints_dict + assert "neck_pitch.pos" not in reachy2.joints_dict + assert "neck_roll.pos" not in reachy2.joints_dict + + if reachy2.config.with_antennas: + assert "l_antenna.pos" in reachy2.joints_dict + assert "r_antenna.pos" in reachy2.joints_dict + else: + assert "l_antenna.pos" not in reachy2.joints_dict + assert "r_antenna.pos" not in reachy2.joints_dict + + if reachy2.config.with_r_arm: + assert "r_shoulder_pitch.pos" in reachy2.joints_dict + assert "r_shoulder_roll.pos" in reachy2.joints_dict + assert "r_elbow_yaw.pos" in reachy2.joints_dict + assert "r_elbow_pitch.pos" in reachy2.joints_dict + assert "r_wrist_roll.pos" in reachy2.joints_dict + assert "r_wrist_pitch.pos" in reachy2.joints_dict + assert "r_wrist_yaw.pos" in reachy2.joints_dict + assert "r_gripper.pos" in reachy2.joints_dict + else: + assert "r_shoulder_pitch.pos" not in reachy2.joints_dict + assert "r_shoulder_roll.pos" not in reachy2.joints_dict + assert "r_elbow_yaw.pos" not in reachy2.joints_dict + assert "r_elbow_pitch.pos" not in reachy2.joints_dict + assert "r_wrist_roll.pos" not in reachy2.joints_dict + assert "r_wrist_pitch.pos" not in reachy2.joints_dict + assert "r_wrist_yaw.pos" not in reachy2.joints_dict + assert "r_gripper.pos" not in reachy2.joints_dict + + if reachy2.config.with_l_arm: + assert "l_shoulder_pitch.pos" in reachy2.joints_dict + assert "l_shoulder_roll.pos" in reachy2.joints_dict + assert "l_elbow_yaw.pos" in reachy2.joints_dict + assert "l_elbow_pitch.pos" in reachy2.joints_dict + assert "l_wrist_roll.pos" in reachy2.joints_dict + assert "l_wrist_pitch.pos" in reachy2.joints_dict + assert "l_wrist_yaw.pos" in reachy2.joints_dict + assert "l_gripper.pos" in reachy2.joints_dict + else: + assert "l_shoulder_pitch.pos" not in reachy2.joints_dict + assert "l_shoulder_roll.pos" not in reachy2.joints_dict + assert "l_elbow_yaw.pos" not in reachy2.joints_dict + assert "l_elbow_pitch.pos" not in reachy2.joints_dict + assert "l_wrist_roll.pos" not in reachy2.joints_dict + assert "l_wrist_pitch.pos" not in reachy2.joints_dict + assert "l_wrist_yaw.pos" not in reachy2.joints_dict + assert "l_gripper.pos" not in reachy2.joints_dict + + +def test_get_observation(reachy2): + reachy2.connect() + obs = reachy2.get_observation() + + expected_keys = set(reachy2.joints_dict) + expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base) + expected_keys.update(reachy2.cameras.keys()) + assert set(obs.keys()) == expected_keys + + for motor in reachy2.joints_dict.keys(): + assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position + if reachy2.config.with_mobile_base: + for vel in REACHY2_VEL.keys(): + assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] + if reachy2.config.with_left_teleop_camera: + assert obs["teleop_left"].shape == ( + reachy2.config.cameras["teleop_left"].height, + reachy2.config.cameras["teleop_left"].width, + 3, + ) + if reachy2.config.with_right_teleop_camera: + assert obs["teleop_right"].shape == ( + reachy2.config.cameras["teleop_right"].height, + reachy2.config.cameras["teleop_right"].width, + 3, + ) + if reachy2.config.with_torso_camera: + assert obs["torso_rgb"].shape == ( + reachy2.config.cameras["torso_rgb"].height, + reachy2.config.cameras["torso_rgb"].width, + 3, + ) + + +def test_send_action(reachy2): + reachy2.connect() + + action = {k: i * 10.0 for i, k in enumerate(reachy2.joints_dict.keys(), start=1)} + if reachy2.config.with_mobile_base: + action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)}) + + previous_present_position = { + k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys() + } + returned = reachy2.send_action(action) + + if reachy2.config.max_relative_target is None: + assert returned == action + + assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict) + for motor in reachy2.joints_dict.keys(): + expected_pos = action[motor] + real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position + if reachy2.config.max_relative_target is None: + assert real_pos == expected_pos + else: + assert real_pos == previous_present_position[motor] + np.sign(expected_pos) * min( + abs(expected_pos - real_pos), reachy2.config.max_relative_target + ) + + if reachy2.config.with_mobile_base: + goal_speed = [i * 0.1 for i, _ in enumerate(REACHY2_VEL.keys(), start=1)] + reachy2.reachy.mobile_base.set_goal_speed.assert_called_once_with(*goal_speed) + + if reachy2.config.use_external_commands: + reachy2.reachy.send_goal_positions.assert_not_called() + if reachy2.config.with_mobile_base: + reachy2.reachy.mobile_base.send_speed_command.assert_not_called() + else: + reachy2.reachy.send_goal_positions.assert_called_once() + if reachy2.config.with_mobile_base: + reachy2.reachy.mobile_base.send_speed_command.assert_called_once() + + +def test_no_part_declared(): + with pytest.raises(ValueError): + _ = Reachy2RobotConfig( + ip_address="192.168.0.200", + with_mobile_base=False, + with_l_arm=False, + with_r_arm=False, + with_neck=False, + with_antennas=False, + ) diff --git a/tests/teleoperators/test_reachy2_teleoperator.py b/tests/teleoperators/test_reachy2_teleoperator.py new file mode 100644 index 000000000..5130de87d --- /dev/null +++ b/tests/teleoperators/test_reachy2_teleoperator.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.teleoperators.reachy2_teleoperator import ( + REACHY2_ANTENNAS_JOINTS, + REACHY2_L_ARM_JOINTS, + REACHY2_NECK_JOINTS, + REACHY2_R_ARM_JOINTS, + REACHY2_VEL, + Reachy2Teleoperator, + Reachy2TeleoperatorConfig, +) + +# {lerobot_keys: reachy2_sdk_keys} +REACHY2_JOINTS = { + **REACHY2_NECK_JOINTS, + **REACHY2_ANTENNAS_JOINTS, + **REACHY2_R_ARM_JOINTS, + **REACHY2_L_ARM_JOINTS, +} + +PARAMS = [ + {}, # default config + {"with_mobile_base": False}, + {"with_mobile_base": False, "with_l_arm": False, "with_antennas": False}, + {"with_r_arm": False, "with_neck": False, "with_antennas": False}, + {"with_mobile_base": False, "with_neck": False}, + {"use_present_position": True}, +] + + +def _make_reachy2_sdk_mock(): + r = MagicMock(name="ReachySDKMock") + r.is_connected.return_value = True + + def _connect(): + r.is_connected.return_value = True + + def _disconnect(): + r.is_connected.return_value = False + + # Mock joints with some dummy positions + joints = { + k: MagicMock( + present_position=float(i), + goal_position=float(i) + 0.5, + ) + for i, k in enumerate(REACHY2_JOINTS.values()) + } + r.joints = joints + + # Mock mobile base with some dummy odometry + r.mobile_base = MagicMock() + r.mobile_base.last_cmd_vel = { + "vx": -0.2, + "vy": 0.2, + "vtheta": 11.0, + } + r.mobile_base.odometry = { + "x": 1.0, + "y": 2.0, + "theta": 20.0, + "vx": 0.1, + "vy": -0.1, + "vtheta": 8.0, + } + + r.connect = MagicMock(side_effect=_connect) + r.disconnect = MagicMock(side_effect=_disconnect) + + return r + + +@pytest.fixture(params=PARAMS, ids=lambda p: "default" if not p else ",".join(p.keys())) +def reachy2(request): + with ( + patch( + "lerobot.teleoperators.reachy2_teleoperator.reachy2_teleoperator.ReachySDK", + side_effect=lambda *a, **k: _make_reachy2_sdk_mock(), + ), + ): + overrides = request.param + cfg = Reachy2TeleoperatorConfig(ip_address="192.168.0.200", **overrides) + robot = Reachy2Teleoperator(cfg) + yield robot + if robot.is_connected: + robot.disconnect() + + +def test_connect_disconnect(reachy2): + assert not reachy2.is_connected + + reachy2.connect() + assert reachy2.is_connected + + reachy2.disconnect() + assert not reachy2.is_connected + + reachy2.reachy.disconnect.assert_called_once() + + +def test_get_action(reachy2): + reachy2.connect() + action = reachy2.get_action() + + expected_keys = set(reachy2.joints_dict) + expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base) + assert set(action.keys()) == expected_keys + + for motor in reachy2.joints_dict.keys(): + if reachy2.config.use_present_position: + assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position + else: + assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position + if reachy2.config.with_mobile_base: + if reachy2.config.use_present_position: + for vel in REACHY2_VEL.keys(): + assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] + else: + for vel in REACHY2_VEL.keys(): + assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]] + + +def test_no_part_declared(): + with pytest.raises(ValueError): + _ = Reachy2TeleoperatorConfig( + ip_address="192.168.0.200", + with_mobile_base=False, + with_l_arm=False, + with_r_arm=False, + with_neck=False, + with_antennas=False, + ) From 49baccdccb2433d0c8b4c4b8610b4372dc725340 Mon Sep 17 00:00:00 2001 From: Steven Gong Date: Mon, 8 Sep 2025 02:38:13 -0700 Subject: [PATCH 083/158] Disable torque before applying calibration logic (#1889) --- src/lerobot/robots/koch_follower/koch_follower.py | 2 +- src/lerobot/teleoperators/koch_leader/koch_leader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index b09b9b8e2..563325b88 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -110,6 +110,7 @@ class KochFollower(Robot): return self.bus.is_calibrated def calibrate(self) -> None: + self.bus.disable_torque() if self.calibration: # Calibration file exists, ask user whether to use it or run new calibration user_input = input( @@ -120,7 +121,6 @@ class KochFollower(Robot): self.bus.write_calibration(self.calibration) return logger.info(f"\nRunning calibration of {self}") - self.bus.disable_torque() for motor in self.bus.motors: self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py index e0318cca5..f703d5b6e 100644 --- a/src/lerobot/teleoperators/koch_leader/koch_leader.py +++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py @@ -88,6 +88,7 @@ class KochLeader(Teleoperator): return self.bus.is_calibrated def calibrate(self) -> None: + self.bus.disable_torque() if self.calibration: # Calibration file exists, ask user whether to use it or run new calibration user_input = input( @@ -98,7 +99,6 @@ class KochLeader(Teleoperator): self.bus.write_calibration(self.calibration) return logger.info(f"\nRunning calibration of {self}") - self.bus.disable_torque() for motor in self.bus.motors: self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) From d602e8169cbad9e93a4a3b3ee1dd8b332af7ebf8 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 8 Sep 2025 18:29:39 +0200 Subject: [PATCH 084/158] fix(scripts): revert deletion of rs cam config import introduced by #1767 (#1876) --- src/lerobot/record.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 8eee59558..de397bb84 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -67,6 +67,7 @@ from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer From f55c6e89f01ad30594f8815a41720e6460656447 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 15 Sep 2025 09:53:30 +0200 Subject: [PATCH 085/158] Dataset v3 (#1412) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Remi Cadene Co-authored-by: Tavish Co-authored-by: fracapuano Co-authored-by: CarolinePascal --- README.md | 42 +- benchmarks/video/run_video_benchmark.py | 6 +- docs/source/_toctree.yml | 2 + docs/source/porting_datasets_v3.mdx | 321 +++++++ examples/1_load_lerobot_dataset.py | 6 +- examples/port_datasets/display_error_files.py | 85 ++ examples/port_datasets/port_droid.py | 430 +++++++++ .../port_datasets/slurm_aggregate_shards.py | 148 +++ examples/port_datasets/slurm_port_shards.py | 162 ++++ examples/port_datasets/slurm_upload.py | 281 ++++++ pyproject.toml | 1 - src/lerobot/datasets/aggregate.py | 502 ++++++++++ .../datasets/backward_compatibility.py | 35 +- src/lerobot/datasets/lerobot_dataset.py | 728 ++++++++++----- src/lerobot/datasets/online_buffer.py | 8 +- src/lerobot/datasets/sampler.py | 12 +- src/lerobot/datasets/utils.py | 403 ++++---- .../v2/batch_convert_dataset_v1_to_v2.py | 884 ------------------ .../datasets/v2/convert_dataset_v1_to_v2.py | 687 -------------- .../v21/_remove_language_instruction.py | 87 -- .../v21/batch_convert_dataset_v20_to_v21.py | 54 -- .../v21/convert_dataset_v20_to_v21.py | 114 --- src/lerobot/datasets/v21/convert_stats.py | 99 -- .../v30/convert_dataset_v21_to_v30.py | 500 ++++++++++ src/lerobot/datasets/video_utils.py | 120 ++- src/lerobot/record.py | 4 +- src/lerobot/replay.py | 8 +- src/lerobot/robots/viperx/README.md | 6 +- src/lerobot/scripts/rl/crop_dataset_roi.py | 3 +- src/lerobot/scripts/rl/gym_manipulator.py | 3 +- src/lerobot/scripts/train.py | 3 +- src/lerobot/scripts/visualize_dataset.py | 6 +- src/lerobot/scripts/visualize_dataset_html.py | 482 ---------- .../templates/visualize_dataset_homepage.html | 68 -- .../templates/visualize_dataset_template.html | 546 ----------- src/lerobot/utils/buffer.py | 14 +- src/lerobot/utils/utils.py | 10 + .../datasets/save_dataset_to_safetensors.py | 26 +- tests/datasets/test_aggregate.py | 292 ++++++ tests/datasets/test_datasets.py | 658 +++++++++++-- tests/datasets/test_delta_timestamps.py | 140 --- tests/datasets/test_sampler.py | 12 +- tests/datasets/test_utils.py | 33 +- tests/fixtures/constants.py | 4 +- tests/fixtures/dataset_factories.py | 304 ++++-- tests/fixtures/files.py | 231 +++-- tests/fixtures/hub.py | 134 +-- tests/policies/test_policies.py | 7 +- tests/test_control_robot.py | 21 +- tests/utils/test_replay_buffer.py | 2 +- 50 files changed, 4642 insertions(+), 4092 deletions(-) create mode 100644 docs/source/porting_datasets_v3.mdx create mode 100644 examples/port_datasets/display_error_files.py create mode 100644 examples/port_datasets/port_droid.py create mode 100644 examples/port_datasets/slurm_aggregate_shards.py create mode 100644 examples/port_datasets/slurm_port_shards.py create mode 100644 examples/port_datasets/slurm_upload.py create mode 100644 src/lerobot/datasets/aggregate.py delete mode 100644 src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py delete mode 100644 src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py delete mode 100644 src/lerobot/datasets/v21/_remove_language_instruction.py delete mode 100644 src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py delete mode 100644 src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py delete mode 100644 src/lerobot/datasets/v21/convert_stats.py create mode 100644 src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py delete mode 100644 src/lerobot/scripts/visualize_dataset_html.py delete mode 100644 src/lerobot/templates/visualize_dataset_homepage.html delete mode 100644 src/lerobot/templates/visualize_dataset_template.html create mode 100644 tests/datasets/test_aggregate.py diff --git a/README.md b/README.md index b5e666aa8..9fd45a7b7 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects: -``` +```` dataset attributes: ├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example: │ ├ observation.images.cam_high (VideoFrame): @@ -246,20 +246,30 @@ dataset attributes: │ ├ timestamp (float32): timestamp in the episode │ ├ next.done (bool): indicates the end of an episode ; True for the last frame in each episode │ └ index (int64): general index in the whole dataset - ├ episode_data_index: contains 2 tensors with the start and end indices of each episode - │ ├ from (1D int64 tensor): first frame index for each episode — shape (num episodes,) starts with 0 - │ └ to: (1D int64 tensor): last frame index for each episode — shape (num episodes,) - ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance - │ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.} - │ ... - ├ info: a dictionary of metadata on the dataset - │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with - │ ├ fps (float): frame per second the dataset is recorded/synchronized to - │ ├ video (bool): indicates if frames are encoded in mp4 video files to save space or stored as png files - │ └ encoding (dict): if video, this documents the main options that were used with ffmpeg to encode the videos - ├ videos_dir (Path): where the mp4 videos or png images are stored/accessed - └ camera_keys (list of string): the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`) -``` + ├ meta: a LeRobotDatasetMetadata object containing: + │ ├ info: a dictionary of metadata on the dataset + │ │ ├ codebase_version (str): this is to keep track of the codebase version the dataset was created with + │ │ ├ fps (int): frame per second the dataset is recorded/synchronized to + │ │ ├ features (dict): all features contained in the dataset with their shapes and types + │ │ ├ total_episodes (int): total number of episodes in the dataset + │ │ ├ total_frames (int): total number of frames in the dataset + │ │ ├ robot_type (str): robot type used for recording + │ │ ├ data_path (str): formattable string for the parquet files + │ │ └ video_path (str): formattable string for the video files (if using videos) + │ ├ episodes: a DataFrame containing episode metadata with columns: + │ │ ├ episode_index (int): index of the episode + │ │ ├ tasks (list): list of tasks for this episode + │ │ ├ length (int): number of frames in this episode + │ │ ├ dataset_from_index (int): start index of this episode in the dataset + │ │ └ dataset_to_index (int): end index of this episode in the dataset + │ ├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance + │ │ ├ observation.images.front_cam: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.} + │ │ └ ... + │ └ tasks: a DataFrame containing task information with task names as index and task_index as values + ├ root (Path): local directory where the dataset is stored + ├ image_transforms (Callable): optional image transformations to apply to visual modalities + └ delta_timestamps (dict): optional delta timestamps for temporal queries +decoding videos (e.g., 'pyav', 'torchcodec') A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely: @@ -283,7 +293,7 @@ lerobot-eval \ --eval.n_episodes=10 \ --policy.use_amp=false \ --policy.device=cuda -``` +```` Note: After training your own policy, you can re-evaluate the checkpoints with: diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index bababf636..5472551f5 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -108,7 +108,8 @@ def save_decoded_frames( def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: - ep_num_images = dataset.episode_data_index["to"][0].item() + episode_index = 0 + ep_num_images = dataset.meta.episodes["length"][episode_index] if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images: return @@ -265,7 +266,8 @@ def benchmark_encoding_decoding( overwrite=True, ) - ep_num_images = dataset.episode_data_index["to"][0].item() + episode_index = 0 + ep_num_images = dataset.meta.episodes["length"][episode_index] width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:]) num_pixels = width * height video_size_bytes = video_path.stat().st_size diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 1a4558f93..5f5a509c7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,6 +19,8 @@ title: Train RL in Simulation - local: async title: Use Async Inference + - local: porting_datasets_v3 + title: Porting Large Datasets title: "Tutorials" - sections: - local: smolvla diff --git a/docs/source/porting_datasets_v3.mdx b/docs/source/porting_datasets_v3.mdx new file mode 100644 index 000000000..46793265e --- /dev/null +++ b/docs/source/porting_datasets_v3.mdx @@ -0,0 +1,321 @@ +# Porting Large Datasets to LeRobot Dataset v3.0 + +This tutorial explains how to port large-scale robotic datasets to the LeRobot Dataset v3.0 format. We'll use the **DROID 1.0.1** dataset as our primary example, which demonstrates handling multi-terabyte datasets with thousands of shards across SLURM clusters. + +## File Organization: v2.1 vs v3.0 + +Dataset v3.0 fundamentally changes how data is organized and stored: + +**v2.1 Structure (Episode-based)**: + +``` +dataset/ +├── data/chunk-000/episode_000000.parquet +├── data/chunk-000/episode_000001.parquet +├── videos/chunk-000/camera/episode_000000.mp4 +└── meta/episodes.jsonl +``` + +**v3.0 Structure (File-based)**: + +``` +dataset/ +├── data/chunk-000/file-000.parquet # Multiple episodes per file +├── videos/camera/chunk-000/file-000.mp4 # Consolidated video chunks +└── meta/episodes/chunk-000/file-000.parquet # Structured metadata +``` + +This transition from individual episode files to file-based chunks dramatically improves performance and reduces storage overhead. + +## What's New in Dataset v3.0 + +Dataset v3.0 introduces significant improvements for handling large datasets: + +### 🏗️ **Enhanced File Organization** + +- **File-based structure**: Episodes are now grouped into chunked files rather than individual episode files +- **Configurable file sizes**: for data and video files +- **Improved storage efficiency**: Better compression and reduced overhead + +### 📊 **Modern Metadata Management** + +- **Parquet-based metadata**: Replaced JSON Lines with efficient parquet format +- **Structured episode access**: Direct pandas DataFrame access via `dataset.meta.episodes` +- **Per-episode statistics**: Enhanced statistics tracking at episode level + +### 🚀 **Performance Enhancements** + +- **Memory-mapped access**: Improved RAM usage through PyArrow memory mapping +- **Faster loading**: Significantly reduced dataset initialization time +- **Better scalability**: Designed for datasets with millions of episodes + +## Prerequisites + +Before porting large datasets, ensure you have: + +- **LeRobot installed** with v3.0 support. Follow our [Installation Guide](./installation). +- **Sufficient storage**: Raw datasets can be very large (e.g., DROID requires 2TB) +- **Cluster access** (recommended for large datasets): SLURM or similar job scheduler +- **Dataset-specific dependencies**: For DROID, you'll need TensorFlow Dataset utilities + +## Understanding the DROID Dataset + +[DROID 1.0.1](https://droid-dataset.github.io/droid/the-droid-dataset) is an excellent example of a large-scale robotic dataset: + +- **Size**: 1.7TB (RLDS format), 8.7TB (raw data) +- **Structure**: 2048 pre-defined TensorFlow dataset shards +- **Content**: 76,000+ robot manipulation trajectories from Franka Emika Panda robots +- **Scope**: Real-world manipulation tasks across multiple environments and objects +- **Format**: Originally in TensorFlow Records/RLDS format, requiring conversion to LeRobot format +- **Hosting**: Google Cloud Storage with public access via `gsutil` + +The dataset contains diverse manipulation demonstrations with: + +- Multiple camera views (wrist camera, exterior cameras) +- Natural language task descriptions +- Robot proprioceptive state and actions +- Success/failure annotations + +### DROID Features Schema + +```python +DROID_FEATURES = { + # Episode markers + "is_first": {"dtype": "bool", "shape": (1,)}, + "is_last": {"dtype": "bool", "shape": (1,)}, + "is_terminal": {"dtype": "bool", "shape": (1,)}, + + # Language instructions + "language_instruction": {"dtype": "string", "shape": (1,)}, + "language_instruction_2": {"dtype": "string", "shape": (1,)}, + "language_instruction_3": {"dtype": "string", "shape": (1,)}, + + # Robot state + "observation.state.gripper_position": {"dtype": "float32", "shape": (1,)}, + "observation.state.cartesian_position": {"dtype": "float32", "shape": (6,)}, + "observation.state.joint_position": {"dtype": "float32", "shape": (7,)}, + + # Camera observations + "observation.images.wrist_left": {"dtype": "image"}, + "observation.images.exterior_1_left": {"dtype": "image"}, + "observation.images.exterior_2_left": {"dtype": "image"}, + + # Actions + "action.gripper_position": {"dtype": "float32", "shape": (1,)}, + "action.cartesian_position": {"dtype": "float32", "shape": (6,)}, + "action.joint_position": {"dtype": "float32", "shape": (7,)}, + + # Standard LeRobot format + "observation.state": {"dtype": "float32", "shape": (8,)}, # joints + gripper + "action": {"dtype": "float32", "shape": (8,)}, # joints + gripper +} +``` + +## Approach 1: Single Computer Porting + +### Step 1: Install Dependencies + +For DROID specifically: + +```bash +pip install tensorflow +pip install tensorflow_datasets +``` + +For other datasets, install the appropriate readers for your source format. + +### Step 2: Download Raw Data + +Download DROID from Google Cloud Storage using `gsutil`: + +```bash +# Install Google Cloud SDK if not already installed +# https://cloud.google.com/sdk/docs/install + +# Download the full RLDS dataset (1.7TB) +gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 /your/data/ + +# Or download just the 100-episode sample (2GB) for testing +gsutil -m cp -r gs://gresearch/robotics/droid_100 /your/data/ +``` + +> [!WARNING] +> Large datasets require substantial time and storage: +> +> - **Full DROID (1.7TB)**: Several days to download depending on bandwidth +> - **Processing time**: 7+ days for local porting of full dataset +> - **Upload time**: 3+ days to push to Hugging Face Hub +> - **Local storage**: ~400GB for processed LeRobot format + +### Step 3: Port the Dataset + +```bash +python examples/port_datasets/port_droid.py \ + --raw-dir /your/data/droid/1.0.1 \ + --repo-id your_id/droid_1.0.1 \ + --push-to-hub +``` + +### Development and Testing + +For development, you can port a single shard: + +```bash +python examples/port_datasets/port_droid.py \ + --raw-dir /your/data/droid/1.0.1 \ + --repo-id your_id/droid_1.0.1_test \ + --num-shards 2048 \ + --shard-index 0 +``` + +This approach works for smaller datasets or testing, but large datasets require cluster computing. + +## Approach 2: SLURM Cluster Porting (Recommended) + +For large datasets like DROID, parallel processing across multiple nodes dramatically reduces processing time. + +### Step 1: Install Cluster Dependencies + +```bash +pip install datatrove # Hugging Face's distributed processing library +``` + +### Step 2: Configure Your SLURM Environment + +Find your partition information: + +```bash +sinfo --format="%R" # List available partitions +sinfo -N -p your_partition -h -o "%N cpus=%c mem=%m" # Check resources +``` + +Choose a **CPU partition** - no GPU needed for dataset porting. + +### Step 3: Launch Parallel Porting Jobs + +```bash +python examples/port_datasets/slurm_port_shards.py \ + --raw-dir /your/data/droid/1.0.1 \ + --repo-id your_id/droid_1.0.1 \ + --logs-dir /your/logs \ + --job-name port_droid \ + --partition your_partition \ + --workers 2048 \ + --cpus-per-task 8 \ + --mem-per-cpu 1950M +``` + +#### Parameter Guidelines + +- **`--workers`**: Number of parallel jobs (max 2048 for DROID's shard count) +- **`--cpus-per-task`**: 8 CPUs recommended for frame encoding parallelization +- **`--mem-per-cpu`**: ~16GB total RAM (8×1950M) for loading raw frames + +> [!TIP] +> Start with fewer workers (e.g., 100) to test your cluster configuration before launching thousands of jobs. + +### Step 4: Monitor Progress + +Check running jobs: + +```bash +squeue -u $USER +``` + +Monitor overall progress: + +```bash +jobs_status /your/logs +``` + +Inspect individual job logs: + +```bash +less /your/logs/port_droid/slurm_jobs/JOB_ID_WORKER_ID.out +``` + +Debug failed jobs: + +```bash +failed_logs /your/logs/port_droid +``` + +### Step 5: Aggregate Shards + +Once all porting jobs complete: + +```bash +python examples/port_datasets/slurm_aggregate_shards.py \ + --repo-id your_id/droid_1.0.1 \ + --logs-dir /your/logs \ + --job-name aggr_droid \ + --partition your_partition \ + --workers 2048 \ + --cpus-per-task 8 \ + --mem-per-cpu 1950M +``` + +### Step 6: Upload to Hub + +```bash +python examples/port_datasets/slurm_upload.py \ + --repo-id your_id/droid_1.0.1 \ + --logs-dir /your/logs \ + --job-name upload_droid \ + --partition your_partition \ + --workers 50 \ + --cpus-per-task 4 \ + --mem-per-cpu 1950M +``` + +> [!NOTE] +> Upload uses fewer workers (50) since it's network-bound rather than compute-bound. + +## Dataset v3.0 File Structure + +Your completed dataset will have this modern structure: + +``` +dataset/ +├── meta/ +│ ├── episodes/ +│ │ └── chunk-000/ +│ │ └── file-000.parquet # Episode metadata +│ ├── tasks.parquet # Task definitions +│ ├── stats.json # Aggregated statistics +│ └── info.json # Dataset information +├── data/ +│ └── chunk-000/ +│ └── file-000.parquet # Consolidated episode data +└── videos/ + └── camera_key/ + └── chunk-000/ + └── file-000.mp4 # Consolidated video files +``` + +This replaces the old episode-per-file structure with efficient, optimally-sized chunks. + +## Migrating from Dataset v2.1 + +If you have existing datasets in v2.1 format, use the migration tool: + +```bash +python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ + --repo-id your_id/existing_dataset +``` + +This automatically: + +- Converts file structure to v3.0 format +- Migrates metadata from JSON Lines to parquet +- Aggregates statistics and creates per-episode stats +- Updates version information + +## Performance Benefits + +Dataset v3.0 provides significant improvements for large datasets: + +- **Faster loading**: 3-5x reduction in initialization time +- **Memory efficiency**: Better RAM usage through memory mapping +- **Scalable processing**: Handles millions of episodes efficiently +- **Storage optimization**: Reduced file count and improved compression diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index 3d357dd19..ac4a843c7 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -92,11 +92,11 @@ print(dataset.hf_dataset) # LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working # with the latter, like iterating through the dataset. # The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by -# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access +# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access # frame indices associated to the first episode: episode_index = 0 -from_idx = dataset.episode_data_index["from"][episode_index].item() -to_idx = dataset.episode_data_index["to"][episode_index].item() +from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] +to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] # Then we grab all the image frames from the first camera: camera_key = dataset.meta.camera_keys[0] diff --git a/examples/port_datasets/display_error_files.py b/examples/port_datasets/display_error_files.py new file mode 100644 index 000000000..fffab5ff3 --- /dev/null +++ b/examples/port_datasets/display_error_files.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from pathlib import Path + + +def find_missing_workers(completions_dir, world_size): + """Find workers that are not completed and returns their indices.""" + full = list(range(world_size)) + + completed = [] + for path in completions_dir.glob("*"): + if path.name in [".", ".."]: + continue + index = path.name.lstrip("0") + index = 0 if index == "" else int(index) + completed.append(index) + + missing_workers = set(full) - set(completed) + return missing_workers + + +def find_output_files(slurm_dir, worker_indices): + """Find output files associated to worker indices, and return tuples + of (worker index, output file path) + """ + out_files = [] + for path in slurm_dir.glob("*.out"): + _, worker_id = path.name.replace(".out", "").split("_") + worker_id = int(worker_id) + if worker_id in worker_indices: + out_files.append((worker_id, path)) + return out_files + + +def display_error_files(logs_dir, job_name): + executor_path = Path(logs_dir) / job_name / "executor.json" + completions_dir = Path(logs_dir) / job_name / "completions" + + with open(executor_path) as f: + executor = json.load(f) + + missing_workers = find_missing_workers(completions_dir, executor["world_size"]) + + for missing in sorted(missing_workers)[::-1]: + print(missing) + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--logs-dir", + type=str, + help="Path to logs directory for `datatrove`.", + ) + parser.add_argument( + "--job-name", + type=str, + default="port_droid", + help="Job name used in slurm, and name of the directory created inside the provided logs directory.", + ) + + args = parser.parse_args() + + display_error_files(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/examples/port_datasets/port_droid.py b/examples/port_datasets/port_droid.py new file mode 100644 index 000000000..4efb131e4 --- /dev/null +++ b/examples/port_datasets/port_droid.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import time +from pathlib import Path + +import numpy as np +import tensorflow_datasets as tfds + +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds + +DROID_SHARDS = 2048 +DROID_FPS = 15 +DROID_ROBOT_TYPE = "Franka" + +# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema +DROID_FEATURES = { + # true on first step of the episode + "is_first": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + # true on last step of the episode + "is_last": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + # true on last step of the episode if it is a terminal step, True for demos + "is_terminal": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, + # language_instruction is also stored as "task" to follow LeRobot standard + "language_instruction": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "language_instruction_2": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "language_instruction_3": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "observation.state.gripper_position": { + "dtype": "float32", + "shape": (1,), + "names": { + "axes": ["gripper"], + }, + }, + "observation.state.cartesian_position": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "observation.state.joint_position": { + "dtype": "float32", + "shape": (7,), + "names": { + "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"], + }, + }, + # Add this new feature to follow LeRobot standard of using joint position + gripper + "observation.state": { + "dtype": "float32", + "shape": (8,), + "names": { + "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"], + }, + }, + # Initially called wrist_image_left + "observation.images.wrist_left": { + "dtype": "video", + "shape": (180, 320, 3), + "names": [ + "height", + "width", + "channels", + ], + }, + # Initially called exterior_image_1_left + "observation.images.exterior_1_left": { + "dtype": "video", + "shape": (180, 320, 3), + "names": [ + "height", + "width", + "channels", + ], + }, + # Initially called exterior_image_2_left + "observation.images.exterior_2_left": { + "dtype": "video", + "shape": (180, 320, 3), + "names": [ + "height", + "width", + "channels", + ], + }, + "action.gripper_position": { + "dtype": "float32", + "shape": (1,), + "names": { + "axes": ["gripper"], + }, + }, + "action.gripper_velocity": { + "dtype": "float32", + "shape": (1,), + "names": { + "axes": ["gripper"], + }, + }, + "action.cartesian_position": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "action.cartesian_velocity": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "action.joint_position": { + "dtype": "float32", + "shape": (7,), + "names": { + "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"], + }, + }, + "action.joint_velocity": { + "dtype": "float32", + "shape": (7,), + "names": { + "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"], + }, + }, + # This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position] + "action.original": { + "dtype": "float32", + "shape": (7,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"], + }, + }, + # Add this new feature to follow LeRobot standard of using joint position + gripper + "action": { + "dtype": "float32", + "shape": (8,), + "names": { + "axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"], + }, + }, + "discount": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, + "reward": { + "dtype": "float32", + "shape": (1,), + "names": None, + }, + # Meta data that are the same for all frames in the episode + "task_category": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "building": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "collector_id": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "date": { + "dtype": "string", + "shape": (1,), + "names": None, + }, + "camera_extrinsics.wrist_left": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "camera_extrinsics.exterior_1_left": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "camera_extrinsics.exterior_2_left": { + "dtype": "float32", + "shape": (6,), + "names": { + "axes": ["x", "y", "z", "roll", "pitch", "yaw"], + }, + }, + "is_episode_successful": { + "dtype": "bool", + "shape": (1,), + "names": None, + }, +} + + +def is_episode_successful(tf_episode_metadata): + # Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8 + return "/success/" in tf_episode_metadata["file_path"].numpy().decode() + + +def generate_lerobot_frames(tf_episode): + m = tf_episode["episode_metadata"] + frame_meta = { + "task_category": m["building"].numpy().decode(), + "building": m["building"].numpy().decode(), + "collector_id": m["collector_id"].numpy().decode(), + "date": m["date"].numpy().decode(), + "camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(), + "camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(), + "camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(), + "is_episode_successful": np.array([is_episode_successful(m)]), + } + for f in tf_episode["steps"]: + # Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema + frame = { + "is_first": np.array([f["is_first"].numpy()]), + "is_last": np.array([f["is_last"].numpy()]), + "is_terminal": np.array([f["is_terminal"].numpy()]), + "language_instruction": f["language_instruction"].numpy().decode(), + "language_instruction_2": f["language_instruction_2"].numpy().decode(), + "language_instruction_3": f["language_instruction_3"].numpy().decode(), + "observation.state.gripper_position": f["observation"]["gripper_position"].numpy(), + "observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(), + "observation.state.joint_position": f["observation"]["joint_position"].numpy(), + "observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(), + "observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(), + "observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(), + "action.gripper_position": f["action_dict"]["gripper_position"].numpy(), + "action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(), + "action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(), + "action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(), + "action.joint_position": f["action_dict"]["joint_position"].numpy(), + "action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(), + "discount": np.array([f["discount"].numpy()]), + "reward": np.array([f["reward"].numpy()]), + "action.original": f["action"].numpy(), + } + + # language_instruction is also stored as "task" to follow LeRobot standard + frame["task"] = frame["language_instruction"] + + # Add this new feature to follow LeRobot standard of using joint position + gripper + frame["observation.state"] = np.concatenate( + [frame["observation.state.joint_position"], frame["observation.state.gripper_position"]] + ) + frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]]) + + # Meta data that are the same for all frames in the episode + frame.update(frame_meta) + + # Cast fp64 to fp32 + for key in frame: + if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64: + frame[key] = frame[key].astype(np.float32) + + yield frame + + +def port_droid( + raw_dir: Path, + repo_id: str, + push_to_hub: bool = False, + num_shards: int | None = None, + shard_index: int | None = None, +): + dataset_name = raw_dir.parent.name + version = raw_dir.name + data_dir = raw_dir.parent.parent + + builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="") + + if num_shards is not None: + tfds_num_shards = builder.info.splits["train"].num_shards + if tfds_num_shards != DROID_SHARDS: + raise ValueError( + f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}." + ) + if num_shards != tfds_num_shards: + raise ValueError( + f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead." + ) + if shard_index >= tfds_num_shards: + raise ValueError( + f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})." + ) + + raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]") + else: + raw_dataset = builder.as_dataset(split="train") + + lerobot_dataset = LeRobotDataset.create( + repo_id=repo_id, + robot_type=DROID_ROBOT_TYPE, + fps=DROID_FPS, + features=DROID_FEATURES, + ) + + start_time = time.time() + num_episodes = raw_dataset.cardinality().numpy().item() + logging.info(f"Number of episodes {num_episodes}") + + for episode_index, episode in enumerate(raw_dataset): + elapsed_time = time.time() - start_time + d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time) + + logging.info( + f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)" + ) + + for frame in generate_lerobot_frames(episode): + lerobot_dataset.add_frame(frame) + + lerobot_dataset.save_episode() + logging.info("Save_episode") + + if push_to_hub: + lerobot_dataset.push_to_hub( + # Add openx tag, since it belongs to the openx collection of datasets + tags=["openx"], + private=False, + ) + + +def validate_dataset(repo_id): + """Sanity check that ensure meta data can be loaded and all files are present.""" + meta = LeRobotDatasetMetadata(repo_id) + + if meta.total_episodes == 0: + raise ValueError("Number of episodes is 0.") + + for ep_idx in range(meta.total_episodes): + data_path = meta.root / meta.get_data_file_path(ep_idx) + + if not data_path.exists(): + raise ValueError(f"Parquet file is missing in: {data_path}") + + for vid_key in meta.video_keys: + vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key) + if not vid_path.exists(): + raise ValueError(f"Video file is missing in: {vid_path}") + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--raw-dir", + type=Path, + required=True, + help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).", + ) + parser.add_argument( + "--repo-id", + type=str, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True", + ) + parser.add_argument( + "--push-to-hub", + action="store_true", + help="Upload to hub.", + ) + parser.add_argument( + "--num-shards", + type=int, + default=None, + help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.", + ) + parser.add_argument( + "--shard-index", + type=int, + default=None, + help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.", + ) + + args = parser.parse_args() + + port_droid(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/examples/port_datasets/slurm_aggregate_shards.py b/examples/port_datasets/slurm_aggregate_shards.py new file mode 100644 index 000000000..4e1b71a31 --- /dev/null +++ b/examples/port_datasets/slurm_aggregate_shards.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path + +from datatrove.executor import LocalPipelineExecutor +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.base import PipelineStep +from port_datasets.droid_rlds.port_droid import DROID_SHARDS + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.utils.utils import init_logging + + +class AggregateDatasets(PipelineStep): + def __init__( + self, + repo_ids: list[str], + aggregated_repo_id: str, + ): + super().__init__() + self.repo_ids = repo_ids + self.aggr_repo_id = aggregated_repo_id + + def run(self, data=None, rank: int = 0, world_size: int = 1): + init_logging() + + # Since aggregate_datasets already handles parallel processing internally, + # we only need one worker to run the entire aggregation + if rank == 0: + logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}") + aggregate_datasets(self.repo_ids, self.aggr_repo_id) + logging.info("Aggregation complete!") + else: + logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation") + + +def make_aggregate_executor( + repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True +): + kwargs = { + "pipeline": [ + AggregateDatasets(repo_ids, repo_id), + ], + "logging_dir": str(logs_dir / job_name), + } + + if slurm: + # For aggregation, we only need 1 task since aggregate_datasets handles everything + kwargs.update( + { + "job_name": job_name, + "tasks": 1, # Only need 1 task for aggregation + "workers": 1, # Only need 1 worker + "time": "08:00:00", + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } + ) + executor = SlurmPipelineExecutor(**kwargs) + else: + kwargs.update( + { + "tasks": 1, + "workers": 1, + } + ) + executor = LocalPipelineExecutor(**kwargs) + + return executor + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.", + ) + parser.add_argument( + "--logs-dir", + type=Path, + help="Path to logs directory for `datatrove`.", + ) + parser.add_argument( + "--job-name", + type=str, + default="aggr_droid", + help="Job name used in slurm, and name of the directory created inside the provided logs directory.", + ) + parser.add_argument( + "--slurm", + type=int, + default=1, + help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).", + ) + parser.add_argument( + "--workers", + type=int, + default=1, # Changed default to 1 since aggregation doesn't need multiple workers + help="Number of slurm workers. For aggregation, this should be 1.", + ) + parser.add_argument( + "--partition", + type=str, + help="Slurm partition. Ideally a CPU partition. No need for GPU partition.", + ) + parser.add_argument( + "--cpus-per-task", + type=int, + default=8, + help="Number of cpus that each slurm worker will use.", + ) + parser.add_argument( + "--mem-per-cpu", + type=str, + default="1950M", + help="Memory per cpu that each worker will use.", + ) + + args = parser.parse_args() + kwargs = vars(args) + kwargs["slurm"] = kwargs.pop("slurm") == 1 + + repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)] + aggregate_executor = make_aggregate_executor(repo_ids, **kwargs) + aggregate_executor.run() + + +if __name__ == "__main__": + main() diff --git a/examples/port_datasets/slurm_port_shards.py b/examples/port_datasets/slurm_port_shards.py new file mode 100644 index 000000000..3bb4c135c --- /dev/null +++ b/examples/port_datasets/slurm_port_shards.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path + +from datatrove.executor import LocalPipelineExecutor +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.base import PipelineStep +from port_datasets.droid_rlds.port_droid import DROID_SHARDS + + +class PortDroidShards(PipelineStep): + def __init__( + self, + raw_dir: Path | str, + repo_id: str = None, + ): + super().__init__() + self.raw_dir = Path(raw_dir) + self.repo_id = repo_id + + def run(self, data=None, rank: int = 0, world_size: int = 1): + from datasets.utils.tqdm import disable_progress_bars + from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset + + from lerobot.utils.utils import init_logging + + init_logging() + disable_progress_bars() + + shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}" + + try: + validate_dataset(shard_repo_id) + return + except Exception: + pass # nosec B110 - Dataset doesn't exist yet, continue with porting + + port_droid( + self.raw_dir, + shard_repo_id, + push_to_hub=False, + num_shards=world_size, + shard_index=rank, + ) + + validate_dataset(shard_repo_id) + + +def make_port_executor( + raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True +): + kwargs = { + "pipeline": [ + PortDroidShards(raw_dir, repo_id), + ], + "logging_dir": str(logs_dir / job_name), + } + + if slurm: + kwargs.update( + { + "job_name": job_name, + "tasks": DROID_SHARDS, + "workers": workers, + "time": "08:00:00", + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } + ) + executor = SlurmPipelineExecutor(**kwargs) + else: + kwargs.update( + { + "tasks": 1, + "workers": 1, + } + ) + executor = LocalPipelineExecutor(**kwargs) + + return executor + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--raw-dir", + type=Path, + required=True, + help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).", + ) + parser.add_argument( + "--repo-id", + type=str, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.", + ) + parser.add_argument( + "--logs-dir", + type=Path, + help="Path to logs directory for `datatrove`.", + ) + parser.add_argument( + "--job-name", + type=str, + default="port_droid", + help="Job name used in slurm, and name of the directory created inside the provided logs directory.", + ) + parser.add_argument( + "--slurm", + type=int, + default=1, + help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).", + ) + parser.add_argument( + "--workers", + type=int, + default=2048, + help="Number of slurm workers. It should be less than the maximum number of shards.", + ) + parser.add_argument( + "--partition", + type=str, + help="Slurm partition. Ideally a CPU partition. No need for GPU partition.", + ) + parser.add_argument( + "--cpus-per-task", + type=int, + default=8, + help="Number of cpus that each slurm worker will use.", + ) + parser.add_argument( + "--mem-per-cpu", + type=str, + default="1950M", + help="Memory per cpu that each worker will use.", + ) + + args = parser.parse_args() + kwargs = vars(args) + kwargs["slurm"] = kwargs.pop("slurm") == 1 + port_executor = make_port_executor(**kwargs) + port_executor.run() + + +if __name__ == "__main__": + main() diff --git a/examples/port_datasets/slurm_upload.py b/examples/port_datasets/slurm_upload.py new file mode 100644 index 000000000..ade1ef874 --- /dev/null +++ b/examples/port_datasets/slurm_upload.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from pathlib import Path + +from datatrove.executor import LocalPipelineExecutor +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.base import PipelineStep +from huggingface_hub import HfApi +from huggingface_hub.constants import REPOCARD_NAME +from port_datasets.droid_rlds.port_droid import DROID_SHARDS + +from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.utils import create_lerobot_dataset_card +from lerobot.utils.utils import init_logging + + +class UploadDataset(PipelineStep): + def __init__( + self, + repo_id: str, + branch: str | None = None, + revision: str | None = None, + tags: list | None = None, + license: str | None = "apache-2.0", + private: bool = False, + distant_repo_id: str | None = None, + **card_kwargs, + ): + super().__init__() + self.repo_id = repo_id + self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id + self.branch = branch + self.tags = tags + self.license = license + self.private = private + self.card_kwargs = card_kwargs + self.revision = revision if revision else CODEBASE_VERSION + + if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1": + logging.warning( + 'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env ' + "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1" + ) + + self.create_repo() + + def create_repo(self): + logging.info(f"Loading meta data from {self.repo_id}...") + meta = LeRobotDatasetMetadata(self.repo_id) + + logging.info(f"Creating repo {self.distant_repo_id}...") + hub_api = HfApi() + hub_api.create_repo( + repo_id=self.distant_repo_id, + private=self.private, + repo_type="dataset", + exist_ok=True, + ) + if self.branch: + hub_api.create_branch( + repo_id=self.distant_repo_id, + branch=self.branch, + revision=self.revision, + repo_type="dataset", + exist_ok=True, + ) + + if not hub_api.file_exists( + self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch + ): + card = create_lerobot_dataset_card( + tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs + ) + card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch) + + hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + + def list_files_recursively(directory): + base_path = Path(directory) + return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()] + + logging.info(f"Listing all local files from {self.repo_id}...") + self.file_paths = list_files_recursively(meta.root) + self.file_paths = sorted(self.file_paths) + + def create_chunks(self, lst, n): + from itertools import islice + + it = iter(lst) + return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]] + + def create_commits(self, additions): + import logging + import math + import random + import time + + from huggingface_hub import create_commit + from huggingface_hub.utils import HfHubHTTPError + + FILES_BETWEEN_COMMITS = 10 # noqa: N806 + BASE_DELAY = 0.1 # noqa: N806 + MAX_RETRIES = 12 # noqa: N806 + + # Split the files into smaller chunks for faster commit + # and avoiding "A commit has happened since" error + num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS) + chunks = self.create_chunks(additions, num_chunks) + + for chunk in chunks: + retries = 0 + while True: + try: + create_commit( + self.distant_repo_id, + repo_type="dataset", + operations=chunk, + commit_message=f"DataTrove upload ({len(chunk)} files)", + revision=self.branch, + ) + # TODO: every 100 chunks super_squach_commits() + logging.info("create_commit completed!") + break + except HfHubHTTPError as e: + if "A commit has happened since" in e.server_message: + if retries >= MAX_RETRIES: + logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.") + raise e + logging.info("Commit creation race condition issue. Waiting...") + time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2)) + retries += 1 + else: + raise e + + def run(self, data=None, rank: int = 0, world_size: int = 1): + import logging + + from datasets.utils.tqdm import disable_progress_bars + from huggingface_hub import CommitOperationAdd, preupload_lfs_files + + from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata + from lerobot.utils.utils import init_logging + + init_logging() + disable_progress_bars() + + chunks = self.create_chunks(self.file_paths, world_size) + file_paths = chunks[rank] + + if len(file_paths) == 0: + raise ValueError(file_paths) + + logging.info("Pre-uploading LFS files...") + for i, path in enumerate(file_paths): + logging.info(f"{i}: {path}") + + meta = LeRobotDatasetMetadata(self.repo_id) + additions = [ + CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths + ] + preupload_lfs_files( + repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch + ) + + logging.info("Creating commits...") + self.create_commits(additions) + logging.info("Done!") + + +def make_upload_executor( + repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True +): + kwargs = { + "pipeline": [ + UploadDataset(repo_id), + ], + "logging_dir": str(logs_dir / job_name), + } + + if slurm: + kwargs.update( + { + "job_name": job_name, + "tasks": DROID_SHARDS, + "workers": workers, + "time": "08:00:00", + "partition": partition, + "cpus_per_task": cpus_per_task, + "sbatch_args": {"mem-per-cpu": mem_per_cpu}, + } + ) + executor = SlurmPipelineExecutor(**kwargs) + else: + kwargs.update( + { + "tasks": DROID_SHARDS, + "workers": 1, + } + ) + executor = LocalPipelineExecutor(**kwargs) + + return executor + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--repo-id", + type=str, + help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.", + ) + parser.add_argument( + "--logs-dir", + type=Path, + help="Path to logs directory for `datatrove`.", + ) + parser.add_argument( + "--job-name", + type=str, + default="upload_droid", + help="Job name used in slurm, and name of the directory created inside the provided logs directory.", + ) + parser.add_argument( + "--slurm", + type=int, + default=1, + help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).", + ) + parser.add_argument( + "--workers", + type=int, + default=50, + help="Number of slurm workers. It should be less than the maximum number of shards.", + ) + parser.add_argument( + "--partition", + type=str, + help="Slurm partition. Ideally a CPU partition. No need for GPU partition.", + ) + parser.add_argument( + "--cpus-per-task", + type=int, + default=8, + help="Number of cpus that each slurm worker will use.", + ) + parser.add_argument( + "--mem-per-cpu", + type=str, + default="1950M", + help="Memory per cpu that each worker will use.", + ) + + init_logging() + + args = parser.parse_args() + kwargs = vars(args) + kwargs["slurm"] = kwargs.pop("slurm") == 1 + upload_executor = make_upload_executor(**kwargs) + upload_executor.run() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 50cd207e9..7241a78f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,6 @@ dependencies = [ # Support dependencies "deepdiff>=7.0.1,<9.0.0", - "flask>=3.0.3,<4.0.0", "imageio[ffmpeg]>=2.34.0,<3.0.0", "termcolor>=2.4.0,<4.0.0", ] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py new file mode 100644 index 000000000..43d4ee233 --- /dev/null +++ b/src/lerobot/datasets/aggregate.py @@ -0,0 +1,502 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import shutil +from pathlib import Path + +import pandas as pd +import tqdm + +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_PATH, + get_parquet_file_size_in_mb, + get_video_size_in_mb, + to_parquet_with_hf_images, + update_chunk_file_indices, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.video_utils import concatenate_video_files + + +def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): + """Validates that all dataset metadata have consistent properties. + + Ensures all datasets have the same fps, robot_type, and features to guarantee + compatibility when aggregating them into a single dataset. + + Args: + all_metadata: List of LeRobotDatasetMetadata objects to validate. + + Returns: + tuple: A tuple containing (fps, robot_type, features) from the first metadata. + + Raises: + ValueError: If any metadata has different fps, robot_type, or features + than the first metadata in the list. + """ + + fps = all_metadata[0].fps + robot_type = all_metadata[0].robot_type + features = all_metadata[0].features + + for meta in tqdm.tqdm(all_metadata, desc="Validate all meta data"): + if fps != meta.fps: + raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.") + if robot_type != meta.robot_type: + raise ValueError( + f"Same robot_type is expected, but got robot_type={meta.robot_type} instead of {robot_type}." + ) + if features != meta.features: + raise ValueError( + f"Same features is expected, but got features={meta.features} instead of {features}." + ) + + return fps, robot_type, features + + +def update_data_df(df, src_meta, dst_meta): + """Updates a data DataFrame with new indices and task mappings for aggregation. + + Adjusts episode indices, frame indices, and task indices to account for + previously aggregated data in the destination dataset. + + Args: + df: DataFrame containing the data to be updated. + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + + Returns: + pd.DataFrame: Updated DataFrame with adjusted indices. + """ + + def _update(row): + row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] + row["index"] = row["index"] + dst_meta.info["total_frames"] + task = src_meta.tasks.iloc[row["task_index"]].name + row["task_index"] = dst_meta.tasks.loc[task].task_index.item() + return row + + return df.apply(_update, axis=1) + + +def update_meta_data( + df, + dst_meta, + meta_idx, + data_idx, + videos_idx, +): + """Updates metadata DataFrame with new chunk, file, and timestamp indices. + + Adjusts all indices and timestamps to account for previously aggregated + data and videos in the destination dataset. + + Args: + df: DataFrame containing the metadata to be updated. + dst_meta: Destination dataset metadata. + meta_idx: Dictionary containing current metadata chunk and file indices. + data_idx: Dictionary containing current data chunk and file indices. + videos_idx: Dictionary containing current video indices and timestamps. + + Returns: + pd.DataFrame: Updated DataFrame with adjusted indices and timestamps. + """ + + def _update(row): + row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"] + row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"] + row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"] + row["data/file_index"] = row["data/file_index"] + data_idx["file"] + for key, video_idx in videos_idx.items(): + row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"] + row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"] + row[f"videos/{key}/from_timestamp"] = ( + row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + ) + row[f"videos/{key}/to_timestamp"] = ( + row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + ) + + row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"] + row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"] + row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] + return row + + return df.apply(_update, axis=1) + + +def aggregate_datasets( + repo_ids: list[str], + aggr_repo_id: str, + roots: list[Path] | None = None, + aggr_root: Path | None = None, + data_files_size_in_mb: float | None = None, + video_files_size_in_mb: float | None = None, + chunk_size: int | None = None, +): + """Aggregates multiple LeRobot datasets into a single unified dataset. + + This is the main function that orchestrates the aggregation process by: + 1. Loading and validating all source dataset metadata + 2. Creating a new destination dataset with unified tasks + 3. Aggregating videos, data, and metadata from all source datasets + 4. Finalizing the aggregated dataset with proper statistics + + Args: + repo_ids: List of repository IDs for the datasets to aggregate. + aggr_repo_id: Repository ID for the aggregated output dataset. + roots: Optional list of root paths for the source datasets. + aggr_root: Optional root path for the aggregated dataset. + data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB) + video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) + chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + """ + logging.info("Start aggregate_datasets") + + if data_files_size_in_mb is None: + data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB + if video_files_size_in_mb is None: + video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB + if chunk_size is None: + chunk_size = DEFAULT_CHUNK_SIZE + + all_metadata = ( + [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids] + if roots is None + else [ + LeRobotDatasetMetadata(repo_id, root=root) for repo_id, root in zip(repo_ids, roots, strict=False) + ] + ) + fps, robot_type, features = validate_all_metadata(all_metadata) + video_keys = [key for key in features if features[key]["dtype"] == "video"] + + dst_meta = LeRobotDatasetMetadata.create( + repo_id=aggr_repo_id, + fps=fps, + robot_type=robot_type, + features=features, + root=aggr_root, + ) + + logging.info("Find all tasks") + unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique() + dst_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks) + + meta_idx = {"chunk": 0, "file": 0} + data_idx = {"chunk": 0, "file": 0} + videos_idx = { + key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys + } + + dst_meta.episodes = {} + + for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"): + videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size) + data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size) + + meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) + + dst_meta.info["total_episodes"] += src_meta.total_episodes + dst_meta.info["total_frames"] += src_meta.total_frames + + finalize_aggregation(dst_meta, all_metadata) + logging.info("Aggregation complete.") + + +def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size): + """Aggregates video chunks from a source dataset into the destination dataset. + + Handles video file concatenation and rotation based on file size limits. + Creates new video files when size limits are exceeded. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + videos_idx: Dictionary tracking video chunk and file indices. + video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB) + chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE) + + Returns: + dict: Updated videos_idx with current chunk and file indices. + """ + for key, video_idx in videos_idx.items(): + unique_chunk_file_pairs = { + (chunk, file) + for chunk, file in zip( + src_meta.episodes[f"videos/{key}/chunk_index"], + src_meta.episodes[f"videos/{key}/file_index"], + strict=False, + ) + } + unique_chunk_file_pairs = sorted(unique_chunk_file_pairs) + + chunk_idx = video_idx["chunk"] + file_idx = video_idx["file"] + + for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: + src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( + video_key=key, + chunk_index=src_chunk_idx, + file_index=src_file_idx, + ) + + dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( + video_key=key, + chunk_index=chunk_idx, + file_index=file_idx, + ) + + # If a new file is created, we don't want to increment the latest_duration + update_latest_duration = False + + if not dst_path.exists(): + # First write to this destination file + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(str(src_path), str(dst_path)) + continue # not accumulating further, already copied the file in place + + # Check file sizes before appending + src_size = get_video_size_in_mb(src_path) + dst_size = get_video_size_in_mb(dst_path) + + if dst_size + src_size >= video_files_size_in_mb: + # Rotate to a new chunk/file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) + dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( + video_key=key, + chunk_index=chunk_idx, + file_index=file_idx, + ) + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(str(src_path), str(dst_path)) + else: + # Get the timestamps shift for this video + timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"] + + # Append to existing video file + concatenate_video_files( + [dst_path, src_path], + dst_path, + ) + # Update the latest_duration when appending (shifts timestamps!) + update_latest_duration = not update_latest_duration + + # Update the videos_idx with the final chunk and file indices for this key + videos_idx[key]["chunk"] = chunk_idx + videos_idx[key]["file"] = file_idx + + if update_latest_duration: + videos_idx[key]["latest_duration"] += timestamps_shift_s + + return videos_idx + + +def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size): + """Aggregates data chunks from a source dataset into the destination dataset. + + Reads source data files, updates indices to match the aggregated dataset, + and writes them to the destination with proper file rotation. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + data_idx: Dictionary tracking data chunk and file indices. + + Returns: + dict: Updated data_idx with current chunk and file indices. + """ + unique_chunk_file_ids = { + (c, f) + for c, f in zip( + src_meta.episodes["data/chunk_index"], src_meta.episodes["data/file_index"], strict=False + ) + } + + unique_chunk_file_ids = sorted(unique_chunk_file_ids) + + for src_chunk_idx, src_file_idx in unique_chunk_file_ids: + src_path = src_meta.root / DEFAULT_DATA_PATH.format( + chunk_index=src_chunk_idx, file_index=src_file_idx + ) + df = pd.read_parquet(src_path) + df = update_data_df(df, src_meta, dst_meta) + + data_idx = append_or_create_parquet_file( + df, + src_path, + data_idx, + data_files_size_in_mb, + chunk_size, + DEFAULT_DATA_PATH, + contains_images=len(dst_meta.image_keys) > 0, + aggr_root=dst_meta.root, + ) + + return data_idx + + +def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): + """Aggregates metadata from a source dataset into the destination dataset. + + Reads source metadata files, updates all indices and timestamps, + and writes them to the destination with proper file rotation. + + Args: + src_meta: Source dataset metadata. + dst_meta: Destination dataset metadata. + meta_idx: Dictionary tracking metadata chunk and file indices. + data_idx: Dictionary tracking data chunk and file indices. + videos_idx: Dictionary tracking video indices and timestamps. + + Returns: + dict: Updated meta_idx with current chunk and file indices. + """ + chunk_file_ids = { + (c, f) + for c, f in zip( + src_meta.episodes["meta/episodes/chunk_index"], + src_meta.episodes["meta/episodes/file_index"], + strict=False, + ) + } + + chunk_file_ids = sorted(chunk_file_ids) + for chunk_idx, file_idx in chunk_file_ids: + src_path = src_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(src_path) + df = update_meta_data( + df, + dst_meta, + meta_idx, + data_idx, + videos_idx, + ) + + for k in videos_idx: + videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] + + meta_idx = append_or_create_parquet_file( + df, + src_path, + meta_idx, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_CHUNK_SIZE, + DEFAULT_EPISODES_PATH, + contains_images=False, + aggr_root=dst_meta.root, + ) + + return meta_idx + + +def append_or_create_parquet_file( + df: pd.DataFrame, + src_path: Path, + idx: dict[str, int], + max_mb: float, + chunk_size: int, + default_path: str, + contains_images: bool = False, + aggr_root: Path = None, +): + """Appends data to an existing parquet file or creates a new one based on size constraints. + + Manages file rotation when size limits are exceeded to prevent individual files + from becoming too large. Handles both regular parquet files and those containing images. + + Args: + df: DataFrame to write to the parquet file. + src_path: Path to the source file (used for size estimation). + idx: Dictionary containing current 'chunk' and 'file' indices. + max_mb: Maximum allowed file size in MB before rotation. + chunk_size: Maximum number of files per chunk before incrementing chunk index. + default_path: Format string for generating file paths. + contains_images: Whether the data contains images requiring special handling. + aggr_root: Root path for the aggregated dataset. + + Returns: + dict: Updated index dictionary with current chunk and file indices. + """ + dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + + if not dst_path.exists(): + dst_path.parent.mkdir(parents=True, exist_ok=True) + if contains_images: + to_parquet_with_hf_images(df, dst_path) + else: + df.to_parquet(dst_path) + return idx + + src_size = get_parquet_file_size_in_mb(src_path) + dst_size = get_parquet_file_size_in_mb(dst_path) + + if dst_size + src_size >= max_mb: + idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) + new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + new_path.parent.mkdir(parents=True, exist_ok=True) + final_df = df + target_path = new_path + else: + existing_df = pd.read_parquet(dst_path) + final_df = pd.concat([existing_df, df], ignore_index=True) + target_path = dst_path + + if contains_images: + to_parquet_with_hf_images(final_df, target_path) + else: + final_df.to_parquet(target_path) + + return idx + + +def finalize_aggregation(aggr_meta, all_metadata): + """Finalizes the dataset aggregation by writing summary files and statistics. + + Writes the tasks file, info file with total counts and splits, and + aggregated statistics from all source datasets. + + Args: + aggr_meta: Aggregated dataset metadata. + all_metadata: List of all source dataset metadata objects. + """ + logging.info("write tasks") + write_tasks(aggr_meta.tasks, aggr_meta.root) + + logging.info("write info") + aggr_meta.info.update( + { + "total_tasks": len(aggr_meta.tasks), + "total_episodes": sum(m.total_episodes for m in all_metadata), + "total_frames": sum(m.total_frames for m in all_metadata), + "splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"}, + } + ) + write_info(aggr_meta.info, aggr_meta.root) + + logging.info("write stats") + aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata]) + write_stats(aggr_meta.stats, aggr_meta.root) diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py index fae485058..1d600434a 100644 --- a/src/lerobot/datasets/backward_compatibility.py +++ b/src/lerobot/datasets/backward_compatibility.py @@ -14,33 +14,13 @@ import packaging.version -V2_MESSAGE = """ +V30_MESSAGE = """ The dataset you requested ({repo_id}) is in {version} format. -We introduced a new format since v2.0 which is not backward compatible with v1.x. -Please, use our conversion script. Modify the following command with your own task description: +We introduced a new format since v3.0 which is not backward compatible with v2.1. +Please, update your dataset to the new format using this command: ``` -python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \\ - --repo-id {repo_id} \\ - --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\ -``` - -A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the -peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top -cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped -target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the -sweatshirt.", ... - -If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) -or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). -""" - -V21_MESSAGE = """ -The dataset you requested ({repo_id}) is in {version} format. -While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global -stats instead of per-episode stats. Update your dataset stats to the new format using this command: -``` -python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 --repo-id={repo_id} +python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id} ``` If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) @@ -58,7 +38,12 @@ class CompatibilityError(Exception): ... class BackwardCompatibilityError(CompatibilityError): def __init__(self, repo_id: str, version: packaging.version.Version): - message = V2_MESSAGE.format(repo_id=repo_id, version=version) + if version.major == 2 and version.minor == 1: + message = V30_MESSAGE.format(repo_id=repo_id, version=version) + else: + raise NotImplementedError( + "Contact the maintainer on [Discord](https://discord.com/invite/s3KuuzsPFb)." + ) super().__init__(message) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index a869cb920..ceefcf05e 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -14,18 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +import gc import logging import shutil +import tempfile from collections.abc import Callable from pathlib import Path import datasets import numpy as np import packaging.version +import pandas as pd import PIL.Image import torch import torch.utils -from datasets import concatenate_datasets, load_dataset from huggingface_hub import HfApi, snapshot_download from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.errors import RevisionNotFoundError @@ -34,46 +36,52 @@ from lerobot.constants import HF_LEROBOT_HOME from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.datasets.image_writer import AsyncImageWriter, write_image from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, INFO_PATH, - TASKS_PATH, _validate_feature_names, - append_jsonlines, - backward_compatible_episodes_stats, check_delta_timestamps, - check_timestamps_sync, check_version_compatibility, create_empty_dataset_info, create_lerobot_dataset_card, embed_images, + flatten_dict, get_delta_indices, - get_episode_data_index, + get_hf_dataset_cache_dir, + get_hf_dataset_size_in_mb, get_hf_features_from_features, + get_parquet_file_size_in_mb, + get_parquet_num_frames, get_safe_version, + get_video_size_in_mb, hf_transform_to_torch, is_valid_version, load_episodes, - load_episodes_stats, load_info, + load_nested_dataset, load_stats, load_tasks, + to_parquet_with_hf_images, + update_chunk_file_indices, validate_episode_buffer, validate_frame, - write_episode, - write_episode_stats, write_info, write_json, + write_stats, + write_tasks, ) from lerobot.datasets.video_utils import ( VideoFrame, + concatenate_video_files, decode_video_frames, encode_video_frames, get_safe_default_codec, + get_video_duration_in_s, get_video_info, ) -CODEBASE_VERSION = "v2.1" +CODEBASE_VERSION = "v3.0" class LeRobotDatasetMetadata: @@ -103,14 +111,9 @@ class LeRobotDatasetMetadata: def load_metadata(self): self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) - self.tasks, self.task_to_task_index = load_tasks(self.root) + self.tasks = load_tasks(self.root) self.episodes = load_episodes(self.root) - if self._version < packaging.version.parse("v2.1"): - self.stats = load_stats(self.root) - self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) - else: - self.episodes_stats = load_episodes_stats(self.root) - self.stats = aggregate_stats(list(self.episodes_stats.values())) + self.stats = load_stats(self.root) def pull_from_repo( self, @@ -132,18 +135,19 @@ class LeRobotDatasetMetadata: return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: - ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index) + ep = self.episodes[ep_index] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) return Path(fpath) def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: - ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + ep = self.episodes[ep_index] + chunk_idx = ep[f"videos/{vid_key}/chunk_index"] + file_idx = ep[f"videos/{vid_key}/file_index"] + fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) return Path(fpath) - def get_episode_chunk(self, ep_index: int) -> int: - return ep_index // self.chunks_size - @property def data_path(self) -> str: """Formattable string for the parquet files.""" @@ -210,39 +214,115 @@ class LeRobotDatasetMetadata: return self.info["total_tasks"] @property - def total_chunks(self) -> int: - """Total number of chunks (groups of episodes).""" - return self.info["total_chunks"] + def chunks_size(self) -> int: + """Max number of files per chunk.""" + return self.info["chunks_size"] @property - def chunks_size(self) -> int: - """Max number of episodes per chunk.""" - return self.info["chunks_size"] + def data_files_size_in_mb(self) -> int: + """Max size of data file in mega bytes.""" + return self.info["data_files_size_in_mb"] + + @property + def video_files_size_in_mb(self) -> int: + """Max size of video file in mega bytes.""" + return self.info["video_files_size_in_mb"] def get_task_index(self, task: str) -> int | None: """ Given a task in natural language, returns its task_index if the task already exists in the dataset, otherwise return None. """ - return self.task_to_task_index.get(task, None) + if task in self.tasks.index: + return int(self.tasks.loc[task].task_index) + else: + return None - def add_task(self, task: str): + def save_episode_tasks(self, tasks: list[str]): + if len(set(tasks)) != len(tasks): + raise ValueError(f"Tasks are not unique: {tasks}") + + if self.tasks is None: + new_tasks = tasks + task_indices = range(len(tasks)) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks) + else: + new_tasks = [task for task in tasks if task not in self.tasks.index] + new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) + for task_idx, task in zip(new_task_indices, new_tasks, strict=False): + self.tasks.loc[task] = task_idx + + if len(new_tasks) > 0: + # Update on disk + write_tasks(self.tasks, self.root) + + def _save_episode_metadata(self, episode_dict: dict) -> None: + """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + + This function processes episodes metadata from a dictionary, converts it into a Hugging Face dataset, + and saves it as a parquet file. It handles both the creation of new parquet files and the + updating of existing ones based on size constraints. After saving the metadata, it reloads + the Hugging Face dataset to ensure it is up-to-date. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. """ - Given a task in natural language, add it to the dictionary of tasks. - """ - if task in self.task_to_task_index: - raise ValueError(f"The task '{task}' already exists and can't be added twice.") + # Convert buffer into HF Dataset + episode_dict = {key: [value] for key, value in episode_dict.items()} + ep_dataset = datasets.Dataset.from_dict(episode_dict) + ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) + df = pd.DataFrame(ep_dataset) + num_frames = episode_dict["length"][0] - task_index = self.info["total_tasks"] - self.task_to_task_index[task] = task_index - self.tasks[task_index] = task - self.info["total_tasks"] += 1 + if self.episodes is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + df["meta/episodes/chunk_index"] = [chunk_idx] + df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [0] + df["dataset_to_index"] = [num_frames] + else: + # Retrieve information from the latest parquet file + latest_ep = self.episodes[-1] + chunk_idx = latest_ep["meta/episodes/chunk_index"] + file_idx = latest_ep["meta/episodes/file_index"] - task_dict = { - "task_index": task_index, - "task": task, - } - append_jsonlines(task_dict, self.root / TASKS_PATH) + latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + + if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: + # Size limit is reached, prepare new parquet file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + + # Update the existing pandas dataframe with new row + df["meta/episodes/chunk_index"] = [chunk_idx] + df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [latest_ep["dataset_to_index"]] + df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + + if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: + # Size limit wasnt reached, concatenate latest dataframe with new one + latest_df = pd.read_parquet(latest_path) + df = pd.concat([latest_df, df], ignore_index=True) + + # Memort optimization + del latest_df + gc.collect() + + # Write the resulting dataframe from RAM to disk + path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(path, index=False) + + if self.episodes is not None: + # Remove the episodes cache directory, necessary to avoid cache bloat + cached_dir = get_hf_dataset_cache_dir(self.episodes) + if cached_dir is not None: + shutil.rmtree(cached_dir) + + self.episodes = load_episodes(self.root) def save_episode( self, @@ -250,41 +330,91 @@ class LeRobotDatasetMetadata: episode_length: int, episode_tasks: list[str], episode_stats: dict[str, dict], + episode_metadata: dict, ) -> None: - self.info["total_episodes"] += 1 - self.info["total_frames"] += episode_length - - chunk = self.get_episode_chunk(episode_index) - if chunk >= self.total_chunks: - self.info["total_chunks"] += 1 - - self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - self.info["total_videos"] += len(self.video_keys) - - write_info(self.info, self.root) - episode_dict = { "episode_index": episode_index, "tasks": episode_tasks, "length": episode_length, } - self.episodes[episode_index] = episode_dict - write_episode(episode_dict, self.root) + episode_dict.update(episode_metadata) + episode_dict.update(flatten_dict({"stats": episode_stats})) + self._save_episode_metadata(episode_dict) - self.episodes_stats[episode_index] = episode_stats - self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats - write_episode_stats(episode_index, episode_stats, self.root) + # Update info + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + self.info["total_tasks"] = len(self.tasks) + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} - def update_video_info(self) -> None: + write_info(self.info, self.root) + + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) + + def update_video_info(self, video_key: str | None = None) -> None: """ Warning: this function writes info from first episode videos, implicitly assuming that all videos have been encoded the same way. Also, this means it assumes the first episode exists. """ - for key in self.video_keys: + if video_key is not None and video_key not in self.video_keys: + raise ValueError(f"Video key {video_key} not found in dataset") + + video_keys = [video_key] if video_key is not None else self.video_keys + for key in video_keys: if not self.features[key].get("info", None): - video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + video_path = self.root / self.video_path.format( + video_key=video_key, chunk_index=0, file_index=0 + ) self.info["features"][key]["info"] = get_video_info(video_path) + def update_chunk_settings( + self, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, + ) -> None: + """Update chunk and file size settings after dataset creation. + + This allows users to customize storage organization without modifying the constructor. + These settings control how episodes are chunked and how large files can grow before + creating new ones. + + Args: + chunks_size: Maximum number of files per chunk directory. If None, keeps current value. + data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value. + video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value. + """ + if chunks_size is not None: + if chunks_size <= 0: + raise ValueError(f"chunks_size must be positive, got {chunks_size}") + self.info["chunks_size"] = chunks_size + + if data_files_size_in_mb is not None: + if data_files_size_in_mb <= 0: + raise ValueError(f"data_files_size_in_mb must be positive, got {data_files_size_in_mb}") + self.info["data_files_size_in_mb"] = data_files_size_in_mb + + if video_files_size_in_mb is not None: + if video_files_size_in_mb <= 0: + raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}") + self.info["video_files_size_in_mb"] = video_files_size_in_mb + + # Update the info file on disk + write_info(self.info, self.root) + + def get_chunk_settings(self) -> dict[str, int]: + """Get current chunk and file size settings. + + Returns: + Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb. + """ + return { + "chunks_size": self.chunks_size, + "data_files_size_in_mb": self.data_files_size_in_mb, + "video_files_size_in_mb": self.video_files_size_in_mb, + } + def __repr__(self): feature_keys = list(self.features) return ( @@ -313,12 +443,12 @@ class LeRobotDatasetMetadata: obj.root.mkdir(parents=True, exist_ok=False) - # TODO(aliberts, rcadene): implement sanity check for features features = {**features, **DEFAULT_FEATURES} _validate_feature_names(features) - obj.tasks, obj.task_to_task_index = {}, {} - obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} + obj.tasks = None + obj.episodes = None + obj.stats = None obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() @@ -334,7 +464,7 @@ class LeRobotDataset(torch.utils.data.Dataset): root: str | Path | None = None, episodes: list[int] | None = None, image_transforms: Callable | None = None, - delta_timestamps: dict[list[float]] | None = None, + delta_timestamps: dict[str, list[float]] | None = None, tolerance_s: float = 1e-4, revision: str | None = None, force_cache_sync: bool = False, @@ -354,9 +484,9 @@ class LeRobotDataset(torch.utils.data.Dataset): - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download the dataset from that address and load it, pending your dataset is compliant with - codebase_version v2.0. If your dataset has been created before this new format, you will be - prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at - lerobot/datasets/v2/convert_dataset_v1_to_v2.py. + codebase_version v3.0. If your dataset has been created before this new format, you will be + prompted to convert it using our conversion script from v2.1 to v3.0, which you can find at + lerobot/datasets/v30/convert_dataset_v21_to_v30.py. 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty @@ -377,38 +507,47 @@ class LeRobotDataset(torch.utils.data.Dataset): . ├── data │ ├── chunk-000 - │ │ ├── episode_000000.parquet - │ │ ├── episode_000001.parquet - │ │ ├── episode_000002.parquet + │ │ ├── file-000.parquet + │ │ ├── file-001.parquet │ │ └── ... │ ├── chunk-001 - │ │ ├── episode_001000.parquet - │ │ ├── episode_001001.parquet - │ │ ├── episode_001002.parquet + │ │ ├── file-000.parquet + │ │ ├── file-001.parquet │ │ └── ... │ └── ... ├── meta - │ ├── episodes.jsonl + │ ├── episodes + │ │ ├── chunk-000 + │ │ │ ├── file-000.parquet + │ │ │ ├── file-001.parquet + │ │ │ └── ... + │ │ ├── chunk-001 + │ │ │ └── ... + │ │ └── ... │ ├── info.json │ ├── stats.json - │ └── tasks.jsonl + │ └── tasks.parquet └── videos - ├── chunk-000 - │ ├── observation.images.laptop - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + ├── observation.images.laptop + │ ├── chunk-000 + │ │ ├── file-000.mp4 + │ │ ├── file-001.mp4 │ │ └── ... - │ ├── observation.images.phone - │ │ ├── episode_000000.mp4 - │ │ ├── episode_000001.mp4 - │ │ ├── episode_000002.mp4 + │ ├── chunk-001 │ │ └── ... - ├── chunk-001 + │ └── ... + ├── observation.images.phone + │ ├── chunk-000 + │ │ ├── file-000.mp4 + │ │ ├── file-001.mp4 + │ │ └── ... + │ ├── chunk-001 + │ │ └── ... + │ └── ... └── ... - Note that this file-based structure is designed to be as versatile as possible. The files are split by - episodes which allows a more granular control over which episodes one wants to use and download. The + Note that this file-based structure is designed to be as versatile as possible. Multiple episodes are + consolidated into chunked files which improves storage efficiency and loading performance. The structure of the dataset is entirely described in the info.json file, which can be easily downloaded or viewed directly on the hub before downloading any actual data. The type of files used are very simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md @@ -468,29 +607,20 @@ class LeRobotDataset(torch.utils.data.Dataset): self.meta = LeRobotDatasetMetadata( self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync ) - if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): - episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] - self.stats = aggregate_stats(episodes_stats) # Load actual data try: if force_cache_sync: raise FileNotFoundError - assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) self.hf_dataset = self.load_hf_dataset() + # Check if cached dataset contains all requested episodes + if not self._check_cached_episodes_sufficient(): + raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") except (AssertionError, FileNotFoundError, NotADirectoryError): self.revision = get_safe_version(self.repo_id, self.revision) - self.download_episodes(download_videos) + self.download(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) - - # Check timestamps - timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() - episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() - ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} - check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) - # Setup delta_indices if self.delta_timestamps is not None: check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) @@ -566,7 +696,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ignore_patterns=ignore_patterns, ) - def download_episodes(self, download_videos: bool = True) -> None: + def download(self, download_videos: bool = True) -> None: """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present @@ -574,11 +704,10 @@ class LeRobotDataset(torch.utils.data.Dataset): """ # TODO(rcadene, aliberts): implement faster transfer # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads - files = None ignore_patterns = None if download_videos else "videos/" + files = None if self.episodes is not None: files = self.get_episodes_file_paths() - self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) def get_episodes_file_paths(self) -> list[Path]: @@ -591,28 +720,43 @@ class LeRobotDataset(torch.utils.data.Dataset): for ep_idx in episodes ] fpaths += video_files - + # episodes are stored in the same files, so we return unique paths only + fpaths = list(set(fpaths)) return fpaths def load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" - if self.episodes is None: - path = str(self.root / "data") - hf_dataset = load_dataset("parquet", data_dir=path, split="train") - else: - files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] - hf_dataset = load_dataset("parquet", data_files=files, split="train") - - # TODO(aliberts): hf_dataset.set_format("torch") + features = get_hf_features_from_features(self.features) + hf_dataset = load_nested_dataset(self.root / "data", features=features) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset + def _check_cached_episodes_sufficient(self) -> bool: + """Check if the cached dataset contains all requested episodes.""" + if self.hf_dataset is None or len(self.hf_dataset) == 0: + return False + + # Get available episode indices from cached dataset + available_episodes = { + ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx + for ep_idx in self.hf_dataset["episode_index"] + } + + # Determine requested episodes + if self.episodes is None: + # Requesting all episodes - check if we have all episodes from metadata + requested_episodes = set(range(self.meta.total_episodes)) + else: + # Requesting specific episodes + requested_episodes = set(self.episodes) + + # Check if all requested episodes are available in cached data + return requested_episodes.issubset(available_episodes) + def create_hf_dataset(self) -> datasets.Dataset: features = get_hf_features_from_features(self.features) ft_dict = {col: [] for col in features} hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") - - # TODO(aliberts): hf_dataset.set_format("torch") hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @@ -644,15 +788,16 @@ class LeRobotDataset(torch.utils.data.Dataset): return get_hf_features_from_features(self.features) def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: - ep_start = self.episode_data_index["from"][ep_idx] - ep_end = self.episode_data_index["to"][ep_idx] + ep = self.meta.episodes[ep_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] query_indices = { - key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] + key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] + [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx] ) for key, delta_idx in self.delta_indices.items() } @@ -666,7 +811,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_timestamps = {} for key in self.meta.video_keys: if query_indices is not None and key in query_indices: - timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + timestamps = self.hf_dataset[query_indices[key]]["timestamp"] query_timestamps[key] = torch.stack(timestamps).tolist() else: query_timestamps[key] = [current_ts] @@ -675,7 +820,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: return { - key: torch.stack(self.hf_dataset.select(q_idx)[key]) + key: torch.stack(self.hf_dataset[q_idx][key]) for key, q_idx in query_indices.items() if key not in self.meta.video_keys } @@ -686,10 +831,17 @@ class LeRobotDataset(torch.utils.data.Dataset): Segmentation Fault. This probably happens because a memory reference to the video loader is created in the main process and a subprocess fails to access it. """ + ep = self.meta.episodes[ep_idx] item = {} for vid_key, query_ts in query_timestamps.items(): + # Episodes are stored sequentially on a single mp4 to reduce the number of files. + # Thus we load the start timestamp of the episode on this mp4 and, + # shift the query timestamp accordingly. + from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] + shifted_query_ts = [from_timestamp + ts for ts in query_ts] + video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) - frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend) + frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend) item[vid_key] = frames.squeeze(0) return item @@ -727,8 +879,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() - item["task"] = self.meta.tasks[task_idx] - + item["task"] = self.meta.tasks.iloc[task_idx].name return item def __repr__(self): @@ -758,6 +909,9 @@ class LeRobotDataset(torch.utils.data.Dataset): ) return self.root / fpath + def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: + return self._get_image_file_path(episode_index, image_key, frame_index=0).parent + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): @@ -766,7 +920,7 @@ class LeRobotDataset(torch.utils.data.Dataset): else: self.image_writer.save_image(image=image, fpath=fpath) - def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None: + def add_frame(self, frame: dict) -> None: """ This function only adds the frame to the episode_buffer. Apart from images — which are written in a temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method @@ -784,11 +938,10 @@ class LeRobotDataset(torch.utils.data.Dataset): # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - if timestamp is None: - timestamp = frame_index / self.fps + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) - self.episode_buffer["task"].append(task) + self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing # Add frame features to episode_buffer for key in frame: @@ -823,10 +976,7 @@ class LeRobotDataset(torch.utils.data.Dataset): save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to None. """ - if not episode_data: - episode_buffer = self.episode_buffer - else: - episode_buffer = episode_data + episode_buffer = episode_data if episode_data is not None else self.episode_buffer validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) @@ -839,11 +989,8 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) episode_buffer["episode_index"] = np.full((episode_length,), episode_index) - # Add new tasks to the tasks dictionary - for task in episode_tasks: - task_index = self.meta.get_task_index(task) - if task_index is None: - self.meta.add_task(task) + # Update tasks and task indices with new tasks if any + self.meta.save_episode_tasks(episode_tasks) # Given tasks in natural language, find their corresponding task indices episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) @@ -855,72 +1002,234 @@ class LeRobotDataset(torch.utils.data.Dataset): continue episode_buffer[key] = np.stack(episode_buffer[key]) + # Wait for image writer to end, so that episode stats over images can be computed self._wait_image_writer() - self._save_episode_table(episode_buffer, episode_index) ep_stats = compute_episode_stats(episode_buffer, self.features) + ep_metadata = self._save_episode_data(episode_buffer) has_video_keys = len(self.meta.video_keys) > 0 use_batched_encoding = self.batch_encoding_size > 1 if has_video_keys and not use_batched_encoding: - self.encode_episode_videos(episode_index) + for video_key in self.meta.video_keys: + ep_metadata.update(self._save_episode_video(video_key, episode_index)) - # `meta.save_episode` should be executed after encoding the videos - self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + # `meta.save_episode` need to be executed after encoding the videos + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - # Check if we should trigger batch encoding if has_video_keys and use_batched_encoding: + # Check if we should trigger batch encoding self.episodes_since_last_encoding += 1 if self.episodes_since_last_encoding == self.batch_encoding_size: start_ep = self.num_episodes - self.batch_encoding_size end_ep = self.num_episodes - logging.info( - f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}" - ) - self.batch_encode_videos(start_ep, end_ep) + self._batch_save_episode_video(start_ep, end_ep) self.episodes_since_last_encoding = 0 - # Episode data index and timestamp checking - ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) - ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} - check_timestamps_sync( - episode_buffer["timestamp"], - episode_buffer["episode_index"], - ep_data_index_np, - self.fps, - self.tolerance_s, + if not episode_data: + # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) + self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) + + def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None): + """ + Batch save videos for multiple episodes. + + Args: + start_episode: Starting episode index (inclusive) + end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode. + """ + if end_episode is None: + end_episode = self.num_episodes + + logging.info( + f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" ) - # Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes - parquet_files = list(self.root.rglob("*.parquet")) - assert len(parquet_files) == self.num_episodes - video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len( - self.meta.video_keys - ) + chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"] + file_idx = self.meta.episodes[start_episode]["data/file_index"] + episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + episode_df = pd.read_parquet(episode_df_path) - if not episode_data: # Reset the buffer - self.episode_buffer = self.create_episode_buffer() + for ep_idx in range(start_episode, end_episode): + logging.info(f"Encoding videos for episode {ep_idx}") - def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: - episode_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + if ( + self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx + or self.meta.episodes[ep_idx]["data/file_index"] != file_idx + ): + # The current episode is in a new chunk or file. + # Save previous episode dataframe and update the Hugging Face dataset by reloading it. + episode_df.to_parquet(episode_df_path) + self.meta.episodes = load_episodes(self.root) + + # Load new episode dataframe + chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"] + file_idx = self.meta.episodes[ep_idx]["data/file_index"] + episode_df_path = self.root / DEFAULT_EPISODES_PATH.format( + chunk_index=chunk_idx, file_index=file_idx + ) + episode_df = pd.read_parquet(episode_df_path) + + # Save the current episode's video metadata to the dataframe + video_ep_metadata = {} + for video_key in self.meta.video_keys: + video_ep_metadata.update(self._save_episode_video(video_key, ep_idx)) + video_ep_metadata.pop("episode_index") + video_ep_df = pd.DataFrame(video_ep_metadata, index=[ep_idx]).convert_dtypes( + dtype_backend="pyarrow" + ) # allows NaN values along with integers + + episode_df = episode_df.combine_first(video_ep_df) + episode_df.to_parquet(episode_df_path) + self.meta.episodes = load_episodes(self.root) + + def _save_episode_data(self, episode_buffer: dict) -> dict: + """Save episode data to a parquet file and update the Hugging Face dataset of frames data. + + This function processes episodes data from a buffer, converts it into a Hugging Face dataset, + and saves it as a parquet file. It handles both the creation of new parquet files and the + updating of existing ones based on size constraints. After saving the data, it reloads + the Hugging Face dataset to ensure it is up-to-date. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. + """ + # Convert buffer into HF Dataset + ep_dict = {key: episode_buffer[key] for key in self.hf_features} + ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") ep_dataset = embed_images(ep_dataset) - self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) - self.hf_dataset.set_transform(hf_transform_to_torch) - ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) - ep_data_path.parent.mkdir(parents=True, exist_ok=True) - ep_dataset.to_parquet(ep_data_path) + ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) + ep_num_frames = len(ep_dataset) + df = pd.DataFrame(ep_dataset) - def clear_episode_buffer(self) -> None: - episode_index = self.episode_buffer["episode_index"] + if self.meta.episodes is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + latest_num_frames = 0 + else: + # Retrieve information from the latest parquet file + latest_ep = self.meta.episodes[-1] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + latest_num_frames = get_parquet_num_frames(latest_path) + + # Determine if a new parquet file is needed + if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb: + # Size limit is reached, prepare new parquet file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) + latest_num_frames = 0 + else: + # Update the existing parquet file with new rows + latest_df = pd.read_parquet(latest_path) + df = pd.concat([latest_df, df], ignore_index=True) + + # Memort optimization + del latest_df + gc.collect() + + # Write the resulting dataframe from RAM to disk + path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + if len(self.meta.image_keys) > 0: + to_parquet_with_hf_images(df, path) + else: + df.to_parquet(path) + + if self.hf_dataset is not None: + # Remove hf dataset cache directory, necessary to avoid cache bloat + cached_dir = get_hf_dataset_cache_dir(self.hf_dataset) + if cached_dir is not None: + shutil.rmtree(cached_dir) + + self.hf_dataset = self.load_hf_dataset() + + metadata = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": latest_num_frames, + "dataset_to_index": latest_num_frames + ep_num_frames, + } + return metadata + + def _save_episode_video(self, video_key: str, episode_index: int): + # Encode episode frames into a temporary video + ep_path = self._encode_temporary_episode_video(video_key, episode_index) + ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + if self.meta.episodes is None or ( + f"videos/{video_key}/chunk_index" not in self.meta.episodes.column_names + or f"videos/{video_key}/file_index" not in self.meta.episodes.column_names + ): + # Initialize indices for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + latest_duration_in_s = 0.0 + new_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + # Retrieve information from the latest updated video file (possibly several episodes ago) + latest_ep = self.meta.episodes[episode_index - 1] + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] + file_idx = latest_ep[f"videos/{video_key}/file_index"] + + latest_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_video_size_in_mb(latest_path) + latest_duration_in_s = get_video_duration_in_s(latest_path) + + if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: + # Move temporary episode video to a new video file in the dataset + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) + new_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + latest_duration_in_s = 0.0 + else: + # Update latest video file + concatenate_video_files( + [latest_path, ep_path], + latest_path, + ) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + # Update video info (only needed when first episode is encoded since it reads from episode 0) + if episode_index == 0: + self.meta.update_video_info(video_key) + write_info(self.meta.info, self.meta.root) # ensure video info always written properly + + metadata = { + "episode_index": episode_index, + f"videos/{video_key}/chunk_index": chunk_idx, + f"videos/{video_key}/file_index": file_idx, + f"videos/{video_key}/from_timestamp": latest_duration_in_s, + f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self, delete_images: bool = True) -> None: # Clean up image files for the current episode buffer - if self.image_writer is not None: + if delete_images: + # Wait for the async image writer to finish + if self.image_writer is not None: + self._wait_image_writer() + episode_index = self.episode_buffer["episode_index"] + if isinstance(episode_index, np.ndarray): + episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] for cam_key in self.meta.camera_keys: - img_dir = self._get_image_file_path( - episode_index=episode_index, image_key=cam_key, frame_index=0 - ).parent + img_dir = self._get_image_file_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) @@ -941,7 +1250,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def stop_image_writer(self) -> None: """ Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to - remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized. + remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. """ if self.image_writer is not None: self.image_writer.stop() @@ -952,55 +1261,17 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.image_writer is not None: self.image_writer.wait_until_done() - def encode_episode_videos(self, episode_index: int) -> None: + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict: """ Use ffmpeg to convert frames stored as png into mp4 videos. Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. - - This method handles video encoding steps: - - Video encoding via ffmpeg - - Video info updating in metadata - - Raw image cleanup - - Args: - episode_index (int): Index of the episode to encode. """ - for key in self.meta.video_keys: - video_path = self.root / self.meta.get_video_file_path(episode_index, key) - if video_path.is_file(): - # Skip if video is already encoded. Could be the case when resuming data recording. - continue - img_dir = self._get_image_file_path( - episode_index=episode_index, image_key=key, frame_index=0 - ).parent - encode_video_frames(img_dir, video_path, self.fps, overwrite=True) - shutil.rmtree(img_dir) - - # Update video info (only needed when first episode is encoded since it reads from episode 0) - if len(self.meta.video_keys) > 0 and episode_index == 0: - self.meta.update_video_info() - write_info(self.meta.info, self.meta.root) # ensure video info always written properly - - def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None: - """ - Batch encode videos for multiple episodes. - - Args: - start_episode: Starting episode index (inclusive) - end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode - """ - if end_episode is None: - end_episode = self.meta.total_episodes - - logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}") - - # Encode all episodes with cleanup enabled for individual episodes - for ep_idx in range(start_episode, end_episode): - logging.info(f"Encoding videos for episode {ep_idx}") - self.encode_episode_videos(ep_idx) - - logging.info("Batch video encoding completed") + temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4" + img_dir = self._get_image_file_dir(episode_index, video_key) + encode_video_frames(img_dir, temp_path, self.fps, overwrite=True) + shutil.rmtree(img_dir) + return temp_path @classmethod def create( @@ -1046,7 +1317,6 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.image_transforms = None obj.delta_timestamps = None obj.delta_indices = None - obj.episode_data_index = None obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() return obj @@ -1064,7 +1334,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): root: str | Path | None = None, episodes: dict | None = None, image_transforms: Callable | None = None, - delta_timestamps: dict[list[float]] | None = None, + delta_timestamps: dict[str, list[float]] | None = None, tolerances_s: dict | None = None, download_videos: bool = True, video_backend: str | None = None, diff --git a/src/lerobot/datasets/online_buffer.py b/src/lerobot/datasets/online_buffer.py index 79f48f49d..563d800b9 100644 --- a/src/lerobot/datasets/online_buffer.py +++ b/src/lerobot/datasets/online_buffer.py @@ -337,13 +337,11 @@ def compute_sampler_weights( if len(offline_dataset) > 0: offline_data_mask_indices = [] for start_index, end_index in zip( - offline_dataset.episode_data_index["from"], - offline_dataset.episode_data_index["to"], + offline_dataset.meta.episodes["dataset_from_index"], + offline_dataset.meta.episodes["dataset_to_index"], strict=True, ): - offline_data_mask_indices.extend( - range(start_index.item(), end_index.item() - offline_drop_n_last_frames) - ) + offline_data_mask_indices.extend(range(start_index, end_index - offline_drop_n_last_frames)) offline_data_mask = torch.zeros(len(offline_dataset), dtype=torch.bool) offline_data_mask[torch.tensor(offline_data_mask_indices)] = True weights.append( diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 79ac7a4b2..d0bb20c27 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -21,7 +21,8 @@ import torch class EpisodeAwareSampler: def __init__( self, - episode_data_index: dict, + dataset_from_indices: list[int], + dataset_to_indices: list[int], episode_indices_to_use: list | None = None, drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, @@ -30,7 +31,8 @@ class EpisodeAwareSampler: """Sampler that optionally incorporates episode boundary information. Args: - episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode. + dataset_from_indices: List of indices containing the start of each episode in the dataset. + dataset_to_indices: List of indices containing the end of each episode in the dataset. episode_indices_to_use: List of episode indices to use. If None, all episodes are used. Assumes that episodes are indexed from 0 to N-1. drop_n_first_frames: Number of frames to drop from the start of each episode. @@ -39,12 +41,10 @@ class EpisodeAwareSampler: """ indices = [] for episode_idx, (start_index, end_index) in enumerate( - zip(episode_data_index["from"], episode_data_index["to"], strict=True) + zip(dataset_from_indices, dataset_to_indices, strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: - indices.extend( - range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) - ) + indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) self.indices = indices self.shuffle = shuffle diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 078c5351d..2b0d95e17 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -18,42 +18,55 @@ import importlib.resources import json import logging from collections.abc import Iterator -from itertools import accumulate from pathlib import Path from pprint import pformat -from types import SimpleNamespace from typing import Any import datasets -import jsonlines import numpy as np import packaging.version +import pandas +import pandas as pd +import pyarrow.parquet as pq import torch +from datasets import Dataset, concatenate_datasets from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError from PIL import Image as PILImage from torchvision import transforms -from lerobot.configs.types import DictLike, FeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.datasets.backward_compatibility import ( - V21_MESSAGE, + FUTURE_MESSAGE, BackwardCompatibilityError, ForwardCompatibilityError, ) from lerobot.utils.utils import is_valid_numpy_dtype_string -DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk +DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk +DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file +DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file INFO_PATH = "meta/info.json" -EPISODES_PATH = "meta/episodes.jsonl" STATS_PATH = "meta/stats.json" -EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" -TASKS_PATH = "meta/tasks.jsonl" -DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" -DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" -DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" +EPISODES_DIR = "meta/episodes" +DATA_DIR = "data" +VIDEO_DIR = "videos" + +CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" +DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" +DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" +DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" +DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" + +LEGACY_EPISODES_PATH = "meta/episodes.jsonl" +LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" +LEGACY_TASKS_PATH = "meta/tasks.jsonl" +LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" +LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" DATASET_CARD_TEMPLATE = """ --- @@ -74,6 +87,65 @@ DEFAULT_FEATURES = { } +def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: + metadata = pq.read_metadata(parquet_path) + total_uncompressed_size = 0 + for row_group in range(metadata.num_row_groups): + rg_metadata = metadata.row_group(row_group) + for column in range(rg_metadata.num_columns): + col_metadata = rg_metadata.column(column) + total_uncompressed_size += col_metadata.total_uncompressed_size + return total_uncompressed_size / (1024**2) + + +def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: + return hf_ds.data.nbytes // (1024**2) + + +def get_hf_dataset_cache_dir(hf_ds: Dataset) -> Path | None: + if hf_ds.cache_files is None or len(hf_ds.cache_files) == 0: + return None + return Path(hf_ds.cache_files[0]["filename"]).parents[2] + + +def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -> tuple[int, int]: + if file_idx == chunks_size - 1: + file_idx = 0 + chunk_idx += 1 + else: + file_idx += 1 + return chunk_idx, file_idx + + +def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset: + """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet + Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage + Concatenate all pyarrow references to return HF Dataset format + + Args: + pq_dir: Directory containing parquet files + features: Optional features schema to ensure consistent loading of complex types like images + """ + paths = sorted(pq_dir.glob("*/*.parquet")) + if len(paths) == 0: + raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") + + # TODO(rcadene): set num_proc to accelerate conversion to pyarrow + datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] + return concatenate_datasets(datasets) + + +def get_parquet_num_frames(parquet_path: str | Path) -> int: + metadata = pq.read_metadata(parquet_path) + return metadata.num_rows + + +def get_video_size_in_mb(mp4_path: Path) -> float: + file_size_bytes = mp4_path.stat().st_size + file_size_mb = file_size_bytes / (1024**2) + return file_size_mb + + def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. @@ -82,6 +154,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` >>> print(flatten_dict(dct)) {"a/b": 1, "a/c/d": 2, "e": 3} + ``` """ items = [] for k, v in d.items(): @@ -106,23 +179,13 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: return outdict -def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: - split_keys = flattened_key.split(sep) - getter = obj[split_keys[0]] - if len(split_keys) == 1: - return getter - - for key in split_keys[1:]: - getter = getter[key] - - return getter - - def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: serialized_dict = {} for key, value in flatten_dict(stats).items(): if isinstance(value, (torch.Tensor, np.ndarray)): serialized_dict[key] = value.tolist() + elif isinstance(value, list) and isinstance(value[0], (int, float, list)): + serialized_dict[key] = value elif isinstance(value, np.generic): serialized_dict[key] = value.item() elif isinstance(value, (int, float)): @@ -152,24 +215,7 @@ def write_json(data: dict, fpath: Path) -> None: json.dump(data, f, indent=4, ensure_ascii=False) -def load_jsonlines(fpath: Path) -> list[Any]: - with jsonlines.open(fpath, "r") as reader: - return list(reader) - - -def write_jsonlines(data: dict, fpath: Path) -> None: - fpath.parent.mkdir(exist_ok=True, parents=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(data) - - -def append_jsonlines(data: dict, fpath: Path) -> None: - fpath.parent.mkdir(exist_ok=True, parents=True) - with jsonlines.open(fpath, "a") as writer: - writer.write(data) - - -def write_info(info: dict, local_dir: Path): +def write_info(info: dict, local_dir: Path) -> None: write_json(info, local_dir / INFO_PATH) @@ -180,65 +226,68 @@ def load_info(local_dir: Path) -> dict: return info -def write_stats(stats: dict, local_dir: Path): +def write_stats(stats: dict, local_dir: Path) -> None: serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) -def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: +def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) -def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) return cast_stats_to_numpy(stats) -def write_task(task_index: int, task: dict, local_dir: Path): - task_dict = { - "task_index": task_index, - "task": task, - } - append_jsonlines(task_dict, local_dir / TASKS_PATH) +def write_tasks(tasks: pandas.DataFrame, local_dir: Path) -> None: + path = local_dir / DEFAULT_TASKS_PATH + path.parent.mkdir(parents=True, exist_ok=True) + tasks.to_parquet(path) -def load_tasks(local_dir: Path) -> tuple[dict, dict]: - tasks = load_jsonlines(local_dir / TASKS_PATH) - tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} - task_to_task_index = {task: task_index for task_index, task in tasks.items()} - return tasks, task_to_task_index +def load_tasks(local_dir: Path) -> pandas.DataFrame: + tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH) + return tasks -def write_episode(episode: dict, local_dir: Path): - append_jsonlines(episode, local_dir / EPISODES_PATH) +def write_episodes(episodes: Dataset, local_dir: Path) -> None: + """Write episode metadata to a parquet file in the LeRobot v3.0 format. + This function writes episode-level metadata to a single parquet file. + Used primarily during dataset conversion (v2.1 → v3.0) and in test fixtures. + + Args: + episodes: HuggingFace Dataset containing episode metadata + local_dir: Root directory where the dataset will be stored + """ + episode_size_mb = get_hf_dataset_size_in_mb(episodes) + if episode_size_mb > DEFAULT_DATA_FILE_SIZE_IN_MB: + raise NotImplementedError( + f"Episodes dataset is too large ({episode_size_mb} MB) to write to a single file. " + f"The current limit is {DEFAULT_DATA_FILE_SIZE_IN_MB} MB. " + "This function only supports single-file episode metadata. " + ) + + fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0) + fpath.parent.mkdir(parents=True, exist_ok=True) + episodes.to_parquet(fpath) -def load_episodes(local_dir: Path) -> dict: - episodes = load_jsonlines(local_dir / EPISODES_PATH) - return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} - - -def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): - # We wrap episode_stats in a dictionary since `episode_stats["episode_index"]` - # is a dictionary of stats and not an integer. - episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} - append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) - - -def load_episodes_stats(local_dir: Path) -> dict: - episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) - return { - item["episode_index"]: cast_stats_to_numpy(item["stats"]) - for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) - } +def load_episodes(local_dir: Path) -> datasets.Dataset: + episodes = load_nested_dataset(local_dir / EPISODES_DIR) + # Select episode features/columns containing references to episode data and videos + # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.) + # This is to speedup access to these data, instead of having to load episode stats. + episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")]) + return episodes def backward_compatible_episodes_stats( stats: dict[str, dict[str, np.ndarray]], episodes: list[int] -) -> dict[str, dict[str, np.ndarray]]: +) -> dict[int, dict[str, dict[str, np.ndarray]]]: return dict.fromkeys(episodes, stats) @@ -254,7 +303,7 @@ def load_image_as_numpy( return img_array -def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): +def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: """Get a transform function that convert items from Hugging Face dataset (pyarrow) to torch tensors. Importantly, images are converted from PIL, which corresponds to a channel last representation (h w c) of uint8 type, to a torch image representation @@ -299,7 +348,7 @@ def check_version_compatibility( if v_check.major < v_current.major and enforce_breaking_major: raise BackwardCompatibilityError(repo_id, v_check) elif v_check.minor < v_current.minor: - logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check)) + logging.warning(FUTURE_MESSAGE.format(repo_id=repo_id, version=v_check)) def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: @@ -476,6 +525,9 @@ def create_empty_dataset_info( features: dict, use_videos: bool, robot_type: str | None = None, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, ) -> dict: return { "codebase_version": codebase_version, @@ -483,104 +535,17 @@ def create_empty_dataset_info( "total_episodes": 0, "total_frames": 0, "total_tasks": 0, - "total_videos": 0, - "total_chunks": 0, - "chunks_size": DEFAULT_CHUNK_SIZE, + "chunks_size": chunks_size or DEFAULT_CHUNK_SIZE, + "data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB, + "video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB, "fps": fps, "splits": {}, - "data_path": DEFAULT_PARQUET_PATH, + "data_path": DEFAULT_DATA_PATH, "video_path": DEFAULT_VIDEO_PATH if use_videos else None, "features": features, } -def get_episode_data_index( - episode_dicts: dict[dict], episodes: list[int] | None = None -) -> dict[str, torch.Tensor]: - episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} - if episodes is not None: - episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} - - cumulative_lengths = list(accumulate(episode_lengths.values())) - return { - "from": torch.LongTensor([0] + cumulative_lengths[:-1]), - "to": torch.LongTensor(cumulative_lengths), - } - - -def check_timestamps_sync( - timestamps: np.ndarray, - episode_indices: np.ndarray, - episode_data_index: dict[str, np.ndarray], - fps: int, - tolerance_s: float, - raise_value_error: bool = True, -) -> bool: - """ - This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance - to account for possible numerical error. - - Args: - timestamps (np.ndarray): Array of timestamps in seconds. - episode_indices (np.ndarray): Array indicating the episode index for each timestamp. - episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', - which identifies indices for the end of each episode. - fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. - tolerance_s (float): Allowed deviation from the expected (1/fps) difference. - raise_value_error (bool): Whether to raise a ValueError if the check fails. - - Returns: - bool: True if all checked timestamp differences lie within tolerance, False otherwise. - - Raises: - ValueError: If the check fails and `raise_value_error` is True. - """ - if timestamps.shape != episode_indices.shape: - raise ValueError( - "timestamps and episode_indices should have the same shape. " - f"Found {timestamps.shape=} and {episode_indices.shape=}." - ) - - # Consecutive differences - diffs = np.diff(timestamps) - within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s - - # Mask to ignore differences at the boundaries between episodes - mask = np.ones(len(diffs), dtype=bool) - ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode - mask[ignored_diffs] = False - filtered_within_tolerance = within_tolerance[mask] - - # Check if all remaining diffs are within tolerance - if not np.all(filtered_within_tolerance): - # Track original indices before masking - original_indices = np.arange(len(diffs)) - filtered_indices = original_indices[mask] - outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0] - outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] - - outside_tolerances = [] - for idx in outside_tolerance_indices: - entry = { - "timestamps": [timestamps[idx], timestamps[idx + 1]], - "diff": diffs[idx], - "episode_index": episode_indices[idx].item() - if hasattr(episode_indices[idx], "item") - else episode_indices[idx], - } - outside_tolerances.append(entry) - - if raise_value_error: - raise ValueError( - f"""One or several timestamps unexpectedly violate the tolerance inside episode range. - This might be due to synchronization issues during data collection. - \n{pformat(outside_tolerances)}""" - ) - return False - - return True - - def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: @@ -619,7 +584,7 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic return delta_indices -def cycle(iterable): +def cycle(iterable: Any) -> Iterator[Any]: """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. @@ -632,7 +597,7 @@ def cycle(iterable): iterator = iter(iterable) -def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None: +def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None: """Create a branch on a existing Hugging Face repo. Delete the branch if it already exists before creating it. """ @@ -685,76 +650,28 @@ def create_lerobot_dataset_card( ) -class IterableNamespace(SimpleNamespace): - """ - A namespace object that supports both dictionary-like iteration and dot notation access. - Automatically converts nested dictionaries into IterableNamespaces. - - This class extends SimpleNamespace to provide: - - Dictionary-style iteration over keys - - Access to items via both dot notation (obj.key) and brackets (obj["key"]) - - Dictionary-like methods: items(), keys(), values() - - Recursive conversion of nested dictionaries - - Args: - dictionary: Optional dictionary to initialize the namespace - **kwargs: Additional keyword arguments passed to SimpleNamespace - - Examples: - >>> data = {"name": "Alice", "details": {"age": 25}} - >>> ns = IterableNamespace(data) - >>> ns.name - 'Alice' - >>> ns.details.age - 25 - >>> list(ns.keys()) - ['name', 'details'] - >>> for key, value in ns.items(): - ... print(f"{key}: {value}") - name: Alice - details: IterableNamespace(age=25) - """ - - def __init__(self, dictionary: dict[str, Any] = None, **kwargs): - super().__init__(**kwargs) - if dictionary is not None: - for key, value in dictionary.items(): - if isinstance(value, dict): - setattr(self, key, IterableNamespace(value)) - else: - setattr(self, key, value) - - def __iter__(self) -> Iterator[str]: - return iter(vars(self)) - - def __getitem__(self, key: str) -> Any: - return vars(self)[key] - - def items(self): - return vars(self).items() - - def values(self): - return vars(self).values() - - def keys(self): - return vars(self).keys() - - -def validate_frame(frame: dict, features: dict): +def validate_frame(frame: dict, features: dict) -> None: expected_features = set(features) - set(DEFAULT_FEATURES) actual_features = set(frame) - error_message = validate_features_presence(actual_features, expected_features) + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") - common_features = actual_features & expected_features - for name in common_features - {"task"}: + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) if error_message: raise ValueError(error_message) -def validate_features_presence(actual_features: set[str], expected_features: set[str]): +def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - expected_features @@ -769,7 +686,9 @@ def validate_features_presence(actual_features: set[str], expected_features: set return error_message -def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): +def validate_feature_dtype_and_shape( + name: str, feature: dict, value: np.ndarray | PILImage.Image | str +) -> str: expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -784,7 +703,7 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray -): +) -> str: error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -801,7 +720,9 @@ def validate_feature_numpy_array( return error_message -def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): +def validate_feature_image_or_video( + name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image +) -> str: # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -817,13 +738,13 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value: return error_message -def validate_feature_string(name: str, value: str): +def validate_feature_string(name: str, value: str) -> str: if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" -def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") @@ -847,3 +768,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: f"In episode_buffer not in features: {buffer_keys - set(features)}" f"In features not in episode_buffer: {set(features) - buffer_keys}" ) + + +def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: + """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. + This way, it can be loaded by HF dataset and correctly formatted images are returned. + """ + # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only + datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) diff --git a/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py b/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py deleted file mode 100644 index fa99c725e..000000000 --- a/src/lerobot/datasets/v2/batch_convert_dataset_v1_to_v2.py +++ /dev/null @@ -1,884 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2. - -Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in -lerobot/configs/robot/aloha.yaml before running this script. -""" - -import traceback -from pathlib import Path -from textwrap import dedent - -from lerobot import available_datasets -from lerobot.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset -from lerobot.robots.aloha.configuration_aloha import AlohaRobotConfig - -LOCAL_DIR = Path("data/") - -# spellchecker:off -ALOHA_MOBILE_INFO = { - "robot_config": AlohaRobotConfig(), - "license": "mit", - "url": "https://mobile-aloha.github.io/", - "paper": "https://huggingface.co/papers/2401.02117", - "citation_bibtex": dedent(r""" - @inproceedings{fu2024mobile, - author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea}, - title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation}, - booktitle = {arXiv}, - year = {2024}, - }""").lstrip(), -} -ALOHA_STATIC_INFO = { - "robot_config": AlohaRobotConfig(), - "license": "mit", - "url": "https://tonyzhaozh.github.io/aloha/", - "paper": "https://huggingface.co/papers/2304.13705", - "citation_bibtex": dedent(r""" - @article{Zhao2023LearningFB, - title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware}, - author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn}, - journal={RSS}, - year={2023}, - volume={abs/2304.13705}, - url={https://huggingface.co/papers/2304.13705} - }""").lstrip(), -} -PUSHT_INFO = { - "license": "mit", - "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://huggingface.co/papers/2303.04137", - "citation_bibtex": dedent(r""" - @article{chi2024diffusionpolicy, - author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, - title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, - journal = {The International Journal of Robotics Research}, - year = {2024}, - }""").lstrip(), -} -XARM_INFO = { - "license": "mit", - "url": "https://www.nicklashansen.com/td-mpc/", - "paper": "https://huggingface.co/papers/2203.04955", - "citation_bibtex": dedent(r""" - @inproceedings{Hansen2022tdmpc, - title={Temporal Difference Learning for Model Predictive Control}, - author={Nicklas Hansen and Xiaolong Wang and Hao Su}, - booktitle={ICML}, - year={2022} - } - """), -} -UNITREEH_INFO = { - "license": "apache-2.0", -} - -DATASETS = { - "aloha_mobile_cabinet": { - "single_task": "Open the top cabinet, store the pot inside it then close the cabinet.", - **ALOHA_MOBILE_INFO, - }, - "aloha_mobile_chair": { - "single_task": "Push the chairs in front of the desk to place them against it.", - **ALOHA_MOBILE_INFO, - }, - "aloha_mobile_elevator": { - "single_task": "Take the elevator to the 1st floor.", - **ALOHA_MOBILE_INFO, - }, - "aloha_mobile_shrimp": { - "single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.", - **ALOHA_MOBILE_INFO, - }, - "aloha_mobile_wash_pan": { - "single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.", - **ALOHA_MOBILE_INFO, - }, - "aloha_mobile_wipe_wine": { - "single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.", - **ALOHA_MOBILE_INFO, - }, - "aloha_static_battery": { - "single_task": "Place the battery into the slot of the remote controller.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO}, - "aloha_static_coffee": { - "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_coffee_new": { - "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_cups_open": { - "single_task": "Pick up the plastic cup and open its lid.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_fork_pick_up": { - "single_task": "Pick up the fork and place it on the plate.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_pingpong_test": { - "single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_pro_pencil": { - "single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_screw_driver": { - "single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_tape": { - "single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_thread_velcro": { - "single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_towel": { - "single_task": "Pick up a piece of paper towel and place it on the spilled liquid.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_vinh_cup": { - "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_vinh_cup_left": { - "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO}, - "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, - "aloha_sim_insertion_scripted_image": { - "single_task": "Insert the peg into the socket.", - **ALOHA_STATIC_INFO, - }, - "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, - "aloha_sim_insertion_human_image": { - "single_task": "Insert the peg into the socket.", - **ALOHA_STATIC_INFO, - }, - "aloha_sim_transfer_cube_scripted": { - "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_sim_transfer_cube_scripted_image": { - "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_sim_transfer_cube_human": { - "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", - **ALOHA_STATIC_INFO, - }, - "aloha_sim_transfer_cube_human_image": { - "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", - **ALOHA_STATIC_INFO, - }, - "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, - "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, - "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO}, - "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO}, - "unitreeh1_two_robot_greeting": { - "single_task": "Greet the other robot with a high five.", - **UNITREEH_INFO, - }, - "unitreeh1_warehouse": { - "single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", - **UNITREEH_INFO, - }, - "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "umi_cup_in_the_wild": { - "single_task": "Put the cup on the plate.", - "license": "apache-2.0", - }, - "asu_table_top": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1", - "citation_bibtex": dedent(r""" - @inproceedings{zhou2023modularity, - title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation}, - author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni}, - booktitle={Conference on Robot Learning}, - pages={1684--1695}, - year={2023}, - organization={PMLR} - } - @article{zhou2023learning, - title={Learning modular language-conditioned robot policies through attention}, - author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon}, - journal={Autonomous Robots}, - pages={1--21}, - year={2023}, - publisher={Springer} - }""").lstrip(), - }, - "austin_buds_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://ut-austin-rpl.github.io/BUDS-website/", - "paper": "https://huggingface.co/papers/2109.13841", - "citation_bibtex": dedent(r""" - @article{zhu2022bottom, - title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation}, - author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke}, - journal={IEEE Robotics and Automation Letters}, - volume={7}, - number={2}, - pages={4126--4133}, - year={2022}, - publisher={IEEE} - }""").lstrip(), - }, - "austin_sailor_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://ut-austin-rpl.github.io/sailor/", - "paper": "https://huggingface.co/papers/2210.11435", - "citation_bibtex": dedent(r""" - @inproceedings{nasiriany2022sailor, - title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning}, - author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu}, - booktitle={Conference on Robot Learning (CoRL)}, - year={2022} - }""").lstrip(), - }, - "austin_sirius_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://ut-austin-rpl.github.io/sirius/", - "paper": "https://huggingface.co/papers/2211.08416", - "citation_bibtex": dedent(r""" - @inproceedings{liu2022robot, - title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment}, - author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu}, - booktitle = {Robotics: Science and Systems (RSS)}, - year = {2023} - }""").lstrip(), - }, - "berkeley_autolab_ur5": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://sites.google.com/view/berkeley-ur5/home", - "citation_bibtex": dedent(r""" - @misc{BerkeleyUR5Website, - title = {Berkeley {UR5} Demonstration Dataset}, - author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg}, - howpublished = {https://sites.google.com/view/berkeley-ur5/home}, - }""").lstrip(), - }, - "berkeley_cable_routing": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://sites.google.com/view/cablerouting/home", - "paper": "https://huggingface.co/papers/2307.08927", - "citation_bibtex": dedent(r""" - @article{luo2023multistage, - author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine}, - title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning}, - journal = {arXiv pre-print}, - year = {2023}, - url = {https://huggingface.co/papers/2307.08927}, - }""").lstrip(), - }, - "berkeley_fanuc_manipulation": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/berkeley.edu/fanuc-manipulation", - "citation_bibtex": dedent(r""" - @article{fanuc_manipulation2023, - title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot}, - author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi}, - year={2023}, - }""").lstrip(), - }, - "berkeley_gnm_cory_hall": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://huggingface.co/papers/1709.10489", - "citation_bibtex": dedent(r""" - @inproceedings{kahn2018self, - title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation}, - author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey}, - booktitle={2018 IEEE international conference on robotics and automation (ICRA)}, - pages={5129--5136}, - year={2018}, - organization={IEEE} - }""").lstrip(), - }, - "berkeley_gnm_recon": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/view/recon-robot", - "paper": "https://huggingface.co/papers/2104.05859", - "citation_bibtex": dedent(r""" - @inproceedings{shah2021rapid, - title={Rapid Exploration for Open-World Navigation with Latent Goal Models}, - author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine}, - booktitle={5th Annual Conference on Robot Learning }, - year={2021}, - url={https://openreview.net/forum?id=d_SWJhyKfVw} - }""").lstrip(), - }, - "berkeley_gnm_sac_son": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/view/SACSoN-review", - "paper": "https://huggingface.co/papers/2306.01874", - "citation_bibtex": dedent(r""" - @article{hirose2023sacson, - title={SACSoN: Scalable Autonomous Data Collection for Social Navigation}, - author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey}, - journal={arXiv preprint arXiv:2306.01874}, - year={2023} - }""").lstrip(), - }, - "berkeley_mvp": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://huggingface.co/papers/2203.06173", - "citation_bibtex": dedent(r""" - @InProceedings{Radosavovic2022, - title = {Real-World Robot Learning with Masked Visual Pre-training}, - author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell}, - booktitle = {CoRL}, - year = {2022} - }""").lstrip(), - }, - "berkeley_rpt": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://huggingface.co/papers/2306.10007", - "citation_bibtex": dedent(r""" - @article{Radosavovic2023, - title={Robot Learning with Sensorimotor Pre-training}, - author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik}, - year={2023}, - journal={arXiv:2306.10007} - }""").lstrip(), - }, - "cmu_franka_exploration_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://human-world-model.github.io/", - "paper": "https://huggingface.co/papers/2308.10901", - "citation_bibtex": dedent(r""" - @inproceedings{mendonca2023structured, - title={Structured World Models from Human Videos}, - author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak}, - journal={RSS}, - year={2023} - }""").lstrip(), - }, - "cmu_play_fusion": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://play-fusion.github.io/", - "paper": "https://huggingface.co/papers/2312.04549", - "citation_bibtex": dedent(r""" - @inproceedings{chen2023playfusion, - title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play}, - author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak}, - booktitle={CoRL}, - year={2023} - }""").lstrip(), - }, - "cmu_stretch": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://robo-affordances.github.io/", - "paper": "https://huggingface.co/papers/2304.08488", - "citation_bibtex": dedent(r""" - @inproceedings{bahl2023affordances, - title={Affordances from Human Videos as a Versatile Representation for Robotics}, - author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak}, - booktitle={CVPR}, - year={2023} - } - @article{mendonca2023structured, - title={Structured World Models from Human Videos}, - author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak}, - journal={CoRL}, - year={2023} - }""").lstrip(), - }, - "columbia_cairlab_pusht_real": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://huggingface.co/papers/2303.04137", - "citation_bibtex": dedent(r""" - @inproceedings{chi2023diffusionpolicy, - title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, - author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran}, - booktitle={Proceedings of Robotics: Science and Systems (RSS)}, - year={2023} - }""").lstrip(), - }, - "conq_hose_manipulation": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home", - "citation_bibtex": dedent(r""" - @misc{ConqHoseManipData, - author={Peter Mitrano and Dmitry Berenson}, - title={Conq Hose Manipulation Dataset, v1.15.0}, - year={2024}, - howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset} - }""").lstrip(), - }, - "dlr_edan_shared_control": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://ieeexplore.ieee.org/document/9341156", - "citation_bibtex": dedent(r""" - @inproceedings{vogel_edan_2020, - title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities}, - language = {en}, - booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})}, - author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin}, - year = {2020} - } - @inproceedings{quere_shared_2020, - address = {Paris, France}, - title = {Shared {Control} {Templates} for {Assistive} {Robotics}}, - language = {en}, - booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})}, - author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern}, - year = {2020}, - pages = {7}, - }""").lstrip(), - }, - "dlr_sara_grid_clamp": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://www.researchsquare.com/article/rs-3289569/v1", - "citation_bibtex": dedent(r""" - @article{padalkar2023guided, - title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world}, - author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek}, - journal={Research square preprint rs-3289569/v1}, - year={2023} - }""").lstrip(), - }, - "dlr_sara_pour": { - "tasks_col": "language_instruction", - "license": "mit", - "paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf", - "citation_bibtex": dedent(r""" - @inproceedings{padalkar2023guiding, - title={Guiding Reinforcement Learning with Shared Control Templates}, - author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek}, - booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023}, - year={2023}, - organization={IEEE} - }""").lstrip(), - }, - "droid_100": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://droid-dataset.github.io/", - "paper": "https://huggingface.co/papers/2403.12945", - "citation_bibtex": dedent(r""" - @article{khazatsky2024droid, - title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset}, - author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn}, - year = {2024}, - }""").lstrip(), - }, - "fmb": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://functional-manipulation-benchmark.github.io/", - "paper": "https://huggingface.co/papers/2401.08553", - "citation_bibtex": dedent(r""" - @article{luo2024fmb, - title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning}, - author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey}, - journal={arXiv preprint arXiv:2401.08553}, - year={2024} - }""").lstrip(), - }, - "iamlab_cmu_pickup_insert": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://openreview.net/forum?id=WuBv9-IGDUA", - "paper": "https://huggingface.co/papers/2401.14502", - "citation_bibtex": dedent(r""" - @inproceedings{saxena2023multiresolution, - title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models}, - author={Saumya Saxena and Mohit Sharma and Oliver Kroemer}, - booktitle={7th Annual Conference on Robot Learning}, - year={2023}, - url={https://openreview.net/forum?id=WuBv9-IGDUA} - }""").lstrip(), - }, - "imperialcollege_sawyer_wrist_cam": { - "tasks_col": "language_instruction", - "license": "mit", - }, - "jaco_play": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://github.com/clvrai/clvr_jaco_play_dataset", - "citation_bibtex": dedent(r""" - @software{dass2023jacoplay, - author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui - and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.}, - title = {CLVR Jaco Play Dataset}, - url = {https://github.com/clvrai/clvr_jaco_play_dataset}, - version = {1.0.0}, - year = {2023} - }""").lstrip(), - }, - "kaist_nonprehensile": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder", - "citation_bibtex": dedent(r""" - @article{kimpre, - title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer}, - author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon}, - booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, - year={2023}, - organization={IEEE} - }""").lstrip(), - }, - "nyu_door_opening_surprising_effectiveness": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://jyopari.github.io/VINN/", - "paper": "https://huggingface.co/papers/2112.01511", - "citation_bibtex": dedent(r""" - @misc{pari2021surprising, - title={The Surprising Effectiveness of Representation Learning for Visual Imitation}, - author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto}, - year={2021}, - eprint={2112.01511}, - archivePrefix={arXiv}, - primaryClass={cs.RO} - }""").lstrip(), - }, - "nyu_franka_play_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://play-to-policy.github.io/", - "paper": "https://huggingface.co/papers/2210.10047", - "citation_bibtex": dedent(r""" - @article{cui2022play, - title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data}, - author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, - journal = {arXiv preprint arXiv:2210.10047}, - year = {2022} - }""").lstrip(), - }, - "nyu_rot_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://rot-robot.github.io/", - "paper": "https://huggingface.co/papers/2206.15469", - "citation_bibtex": dedent(r""" - @inproceedings{haldar2023watch, - title={Watch and match: Supercharging imitation with regularized optimal transport}, - author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel}, - booktitle={Conference on Robot Learning}, - pages={32--43}, - year={2023}, - organization={PMLR} - }""").lstrip(), - }, - "roboturk": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://roboturk.stanford.edu/dataset_real.html", - "paper": "PAPER", - "citation_bibtex": dedent(r""" - @inproceedings{mandlekar2019scaling, - title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity}, - author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li}, - booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, - pages={1048--1055}, - year={2019}, - organization={IEEE} - }""").lstrip(), - }, - "stanford_hydra_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/view/hydra-il-2023", - "paper": "https://huggingface.co/papers/2306.17237", - "citation_bibtex": dedent(r""" - @article{belkhale2023hydra, - title={HYDRA: Hybrid Robot Actions for Imitation Learning}, - author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa}, - journal={arxiv}, - year={2023} - }""").lstrip(), - }, - "stanford_kuka_multimodal_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://sites.google.com/view/visionandtouch", - "paper": "https://huggingface.co/papers/1810.10191", - "citation_bibtex": dedent(r""" - @inproceedings{lee2019icra, - title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks}, - author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette}, - booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)}, - year={2019}, - url={https://huggingface.co/papers/1810.10191} - }""").lstrip(), - }, - "stanford_robocook": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://hshi74.github.io/robocook/", - "paper": "https://huggingface.co/papers/2306.14447", - "citation_bibtex": dedent(r""" - @article{shi2023robocook, - title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools}, - author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun}, - journal={arXiv preprint arXiv:2306.14447}, - year={2023} - }""").lstrip(), - }, - "taco_play": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "url": "https://www.kaggle.com/datasets/oiermees/taco-robot", - "paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911", - "citation_bibtex": dedent(r""" - @inproceedings{rosete2022tacorl, - author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard}, - title = {Latent Plans for Task Agnostic Offline Reinforcement Learning}, - journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)}, - year = {2022} - } - @inproceedings{mees23hulc2, - title={Grounding Language with Visual Affordances over Unstructured Data}, - author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard}, - booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)}, - year={2023}, - address = {London, UK} - }""").lstrip(), - }, - "tokyo_u_lsmo": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "URL", - "paper": "https://huggingface.co/papers/2107.05842", - "citation_bibtex": dedent(r""" - @Article{Osa22, - author = {Takayuki Osa}, - journal = {The International Journal of Robotics Research}, - title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization}, - year = {2022}, - number = {3}, - pages = {291--311}, - volume = {41}, - }""").lstrip(), - }, - "toto": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://toto-benchmark.org/", - "paper": "https://huggingface.co/papers/2306.00942", - "citation_bibtex": dedent(r""" - @inproceedings{zhou2023train, - author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav}, - booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)}, - title={Train Offline, Test Online: A Real Robot Learning Benchmark}, - year={2023}, - }""").lstrip(), - }, - "ucsd_kitchen_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "citation_bibtex": dedent(r""" - @ARTICLE{ucsd_kitchens, - author = {Ge Yan, Kris Wu, and Xiaolong Wang}, - title = {{ucsd kitchens Dataset}}, - year = {2023}, - month = {August} - }""").lstrip(), - }, - "ucsd_pick_and_place_dataset": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://owmcorl.github.io/#", - "paper": "https://huggingface.co/papers/2310.16029", - "citation_bibtex": dedent(r""" - @preprint{Feng2023Finetuning, - title={Finetuning Offline World Models in the Real World}, - author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang}, - year={2023} - }""").lstrip(), - }, - "uiuc_d3field": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://robopil.github.io/d3fields/", - "paper": "https://huggingface.co/papers/2309.16118", - "citation_bibtex": dedent(r""" - @article{wang2023d3field, - title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation}, - author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu}, - journal={arXiv preprint arXiv:}, - year={2023}, - }""").lstrip(), - }, - "usc_cloth_sim": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://uscresl.github.io/dmfd/", - "paper": "https://huggingface.co/papers/2207.10148", - "citation_bibtex": dedent(r""" - @article{salhotra2022dmfd, - author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.}, - journal={IEEE Robotics and Automation Letters}, - title={Learning Deformable Object Manipulation From Expert Demonstrations}, - year={2022}, - volume={7}, - number={4}, - pages={8775-8782}, - doi={10.1109/LRA.2022.3187843} - }""").lstrip(), - }, - "utaustin_mutex": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://ut-austin-rpl.github.io/MUTEX/", - "paper": "https://huggingface.co/papers/2309.14320", - "citation_bibtex": dedent(r""" - @inproceedings{shah2023mutex, - title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications}, - author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu}, - booktitle={7th Annual Conference on Robot Learning}, - year={2023}, - url={https://openreview.net/forum?id=PwqiqaaEzJ} - }""").lstrip(), - }, - "utokyo_pr2_opening_fridge": { - "tasks_col": "language_instruction", - "license": "mit", - "citation_bibtex": dedent(r""" - @misc{oh2023pr2utokyodatasets, - author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka}, - title={X-Embodiment U-Tokyo PR2 Datasets}, - year={2023}, - url={https://github.com/ojh6404/rlds_dataset_builder}, - }""").lstrip(), - }, - "utokyo_pr2_tabletop_manipulation": { - "tasks_col": "language_instruction", - "license": "mit", - "citation_bibtex": dedent(r""" - @misc{oh2023pr2utokyodatasets, - author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka}, - title={X-Embodiment U-Tokyo PR2 Datasets}, - year={2023}, - url={https://github.com/ojh6404/rlds_dataset_builder}, - }""").lstrip(), - }, - "utokyo_saytap": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://saytap.github.io/", - "paper": "https://huggingface.co/papers/2306.07580", - "citation_bibtex": dedent(r""" - @article{saytap2023, - author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and - Tatsuya Harada}, - title = {SayTap: Language to Quadrupedal Locomotion}, - eprint = {arXiv:2306.07580}, - url = {https://saytap.github.io}, - note = {https://saytap.github.io}, - year = {2023} - }""").lstrip(), - }, - "utokyo_xarm_bimanual": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "citation_bibtex": dedent(r""" - @misc{matsushima2023weblab, - title={Weblab xArm Dataset}, - author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo}, - year={2023}, - }""").lstrip(), - }, - "utokyo_xarm_pick_and_place": { - "tasks_col": "language_instruction", - "license": "cc-by-4.0", - "citation_bibtex": dedent(r""" - @misc{matsushima2023weblab, - title={Weblab xArm Dataset}, - author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo}, - year={2023}, - }""").lstrip(), - }, - "viola": { - "tasks_col": "language_instruction", - "license": "mit", - "url": "https://ut-austin-rpl.github.io/VIOLA/", - "paper": "https://huggingface.co/papers/2210.11339", - "citation_bibtex": dedent(r""" - @article{zhu2022viola, - title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors}, - author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke}, - journal={6th Annual Conference on Robot Learning (CoRL)}, - year={2022} - }""").lstrip(), - }, -} -# spellchecker:on - - -def batch_convert(): - status = {} - logfile = LOCAL_DIR / "conversion_log.txt" - assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets} - for num, (name, kwargs) in enumerate(DATASETS.items()): - repo_id = f"lerobot/{name}" - print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})") - print("---------------------------------------------------------") - try: - convert_dataset(repo_id, LOCAL_DIR, **kwargs) - status = f"{repo_id}: success." - with open(logfile, "a") as file: - file.write(status + "\n") - except Exception: - status = f"{repo_id}: failed\n {traceback.format_exc()}" - with open(logfile, "a") as file: - file.write(status + "\n") - continue - - -if __name__ == "__main__": - batch_convert() diff --git a/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py b/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py deleted file mode 100644 index cddfc4c18..000000000 --- a/src/lerobot/datasets/v2/convert_dataset_v1_to_v2.py +++ /dev/null @@ -1,687 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to -2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English -for each of the task performed in the dataset. This will allow to easily train models with task-conditioning. - -We support 3 different scenarios for these tasks (see instructions below): - 1. Single task dataset: all episodes of your dataset have the same single task. - 2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from - one episode to the next. - 3. Multi task episodes: episodes of your dataset may each contain several different tasks. - - -Can you can also provide a robot config .yaml file (not mandatory) to this script via the option -'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was -recorded with. For now, only Aloha/Koch type robots are supported with this option. - - -# 1. Single task dataset -If your dataset contains a single task, you can simply provide it directly via the CLI with the -'--single-task' option. - -Examples: - -```bash -python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \ - --repo-id lerobot/aloha_sim_insertion_human_image \ - --single-task "Insert the peg into the socket." \ - --robot-config lerobot/configs/robot/aloha.yaml \ - --local-dir data -``` - -```bash -python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \ - --repo-id aliberts/koch_tutorial \ - --single-task "Pick the Lego block and drop it in the box on the right." \ - --robot-config lerobot/configs/robot/koch.yaml \ - --local-dir data -``` - - -# 2. Single task episodes -If your dataset is a multi-task dataset, you have two options to provide the tasks to this script: - -- If your dataset already contains a language instruction column in its parquet file, you can simply provide - this column's name with the '--tasks-col' arg. - - Example: - - ```bash - python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \ - --repo-id lerobot/stanford_kuka_multimodal_dataset \ - --tasks-col "language_instruction" \ - --local-dir data - ``` - -- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the - '--tasks-path' arg. This file should have the following structure where keys correspond to each - episode_index in the dataset, and values are the language instruction for that episode. - - Example: - - ```json - { - "0": "Do something", - "1": "Do something else", - "2": "Do something", - "3": "Go there", - ... - } - ``` - -# 3. Multi task episodes -If you have multiple tasks per episodes, your dataset should contain a language instruction column in its -parquet file, and you must provide this column's name with the '--tasks-col' arg. - -Example: - -```bash -python -m lerobot.datasets.v2.convert_dataset_v1_to_v2 \ - --repo-id lerobot/stanford_kuka_multimodal_dataset \ - --tasks-col "language_instruction" \ - --local-dir data -``` -""" - -import argparse -import contextlib -import filecmp -import json -import logging -import math -import shutil -import subprocess -import tempfile -from pathlib import Path - -import datasets -import pyarrow.compute as pc -import pyarrow.parquet as pq -import torch -from datasets import Dataset -from huggingface_hub import HfApi -from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError -from safetensors.torch import load_file - -from lerobot.datasets.utils import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_PARQUET_PATH, - DEFAULT_VIDEO_PATH, - EPISODES_PATH, - INFO_PATH, - STATS_PATH, - TASKS_PATH, - create_branch, - create_lerobot_dataset_card, - flatten_dict, - get_safe_version, - load_json, - unflatten_dict, - write_json, - write_jsonlines, -) -from lerobot.datasets.video_utils import ( - VideoFrame, # noqa: F401 - get_image_pixel_channels, - get_video_info, -) -from lerobot.robots import RobotConfig - -V16 = "v1.6" -V20 = "v2.0" - -GITATTRIBUTES_REF = "aliberts/gitattributes_reference" -V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4" -V1_INFO_PATH = "meta_data/info.json" -V1_STATS_PATH = "meta_data/stats.safetensors" - - -def parse_robot_config(robot_cfg: RobotConfig) -> tuple[str, dict]: - if robot_cfg.type in ["aloha", "koch"]: - state_names = [ - f"{arm}_{motor}" if len(robot_cfg.follower_arms) > 1 else motor - for arm in robot_cfg.follower_arms - for motor in robot_cfg.follower_arms[arm].motors - ] - action_names = [ - # f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"] - f"{arm}_{motor}" if len(robot_cfg.leader_arms) > 1 else motor - for arm in robot_cfg.leader_arms - for motor in robot_cfg.leader_arms[arm].motors - ] - # elif robot_cfg["robot_type"] == "stretch3": TODO - else: - raise NotImplementedError( - "Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()." - ) - - return { - "robot_type": robot_cfg.type, - "names": { - "observation.state": state_names, - "observation.effort": state_names, - "action": action_names, - }, - } - - -def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None: - safetensor_path = v1_dir / V1_STATS_PATH - stats = load_file(safetensor_path) - serialized_stats = {key: value.tolist() for key, value in stats.items()} - serialized_stats = unflatten_dict(serialized_stats) - - json_path = v2_dir / STATS_PATH - json_path.parent.mkdir(exist_ok=True, parents=True) - with open(json_path, "w") as f: - json.dump(serialized_stats, f, indent=4) - - # Sanity check - with open(json_path) as f: - stats_json = json.load(f) - - stats_json = flatten_dict(stats_json) - stats_json = {key: torch.tensor(value) for key, value in stats_json.items()} - for key in stats: - torch.testing.assert_close(stats_json[key], stats[key]) - - -def get_features_from_hf_dataset( - dataset: Dataset, robot_config: RobotConfig | None = None -) -> dict[str, list]: - robot_config = parse_robot_config(robot_config) - features = {} - for key, ft in dataset.features.items(): - if isinstance(ft, datasets.Value): - dtype = ft.dtype - shape = (1,) - names = None - if isinstance(ft, datasets.Sequence): - assert isinstance(ft.feature, datasets.Value) - dtype = ft.feature.dtype - shape = (ft.length,) - motor_names = ( - robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] - ) - assert len(motor_names) == shape[0] - names = {"motors": motor_names} - elif isinstance(ft, datasets.Image): - dtype = "image" - image = dataset[0][key] # Assuming first row - channels = get_image_pixel_channels(image) - shape = (image.height, image.width, channels) - names = ["height", "width", "channels"] - elif ft._type == "VideoFrame": - dtype = "video" - shape = None # Add shape later - names = ["height", "width", "channels"] - - features[key] = { - "dtype": dtype, - "shape": shape, - "names": names, - } - - return features - - -def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: - df = dataset.to_pandas() - tasks = list(set(tasks_by_episodes.values())) - tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} - episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} - df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) - - features = dataset.features - features["task_index"] = datasets.Value(dtype="int64") - dataset = Dataset.from_pandas(df, features=features, split="train") - return dataset, tasks - - -def add_task_index_from_tasks_col( - dataset: Dataset, tasks_col: str -) -> tuple[Dataset, dict[str, list[str]], list[str]]: - df = dataset.to_pandas() - - # HACK: This is to clean some of the instructions in our version of Open X datasets - prefix_to_clean = "tf.Tensor(b'" - suffix_to_clean = "', shape=(), dtype=string)" - df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) - - # Create task_index col - tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() - tasks = df[tasks_col].unique().tolist() - tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} - df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) - - # Build the dataset back from df - features = dataset.features - features["task_index"] = datasets.Value(dtype="int64") - dataset = Dataset.from_pandas(df, features=features, split="train") - dataset = dataset.remove_columns(tasks_col) - - return dataset, tasks, tasks_by_episode - - -def split_parquet_by_episodes( - dataset: Dataset, - total_episodes: int, - total_chunks: int, - output_dir: Path, -) -> list: - table = dataset.data.table - episode_lengths = [] - for ep_chunk in range(total_chunks): - ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk - ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) - chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) - (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) - for ep_idx in range(ep_chunk_start, ep_chunk_end): - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - episode_lengths.insert(ep_idx, len(ep_table)) - output_file = output_dir / DEFAULT_PARQUET_PATH.format( - episode_chunk=ep_chunk, episode_index=ep_idx - ) - pq.write_table(ep_table, output_file) - - return episode_lengths - - -def move_videos( - repo_id: str, - video_keys: list[str], - total_episodes: int, - total_chunks: int, - work_dir: Path, - clean_gittatributes: Path, - branch: str = "main", -) -> None: - """ - HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git - commands to fetch git lfs video files references to move them into subdirectories without having to - actually download them. - """ - _lfs_clone(repo_id, work_dir, branch) - - videos_moved = False - video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] - if len(video_files) == 0: - video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] - videos_moved = True # Videos have already been moved - - assert len(video_files) == total_episodes * len(video_keys) - - lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files) - - current_gittatributes = work_dir / ".gitattributes" - if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False): - fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes) - - if lfs_untracked_videos: - fix_lfs_video_files_tracking(work_dir, video_files) - - if videos_moved: - return - - video_dirs = sorted(work_dir.glob("videos*/")) - for ep_chunk in range(total_chunks): - ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk - ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) - for vid_key in video_keys: - chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format( - episode_chunk=ep_chunk, video_key=vid_key - ) - (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True) - - for ep_idx in range(ep_chunk_start, ep_chunk_end): - target_path = DEFAULT_VIDEO_PATH.format( - episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx - ) - video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) - if len(video_dirs) == 1: - video_path = video_dirs[0] / video_file - else: - for dir in video_dirs: - if (dir / video_file).is_file(): - video_path = dir / video_file - break - - video_path.rename(work_dir / target_path) - - commit_message = "Move video files into chunk subdirectories" - subprocess.run(["git", "add", "."], cwd=work_dir, check=True) - subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True) - subprocess.run(["git", "push"], cwd=work_dir, check=True) - - -def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: - """ - HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, - there's no other option than to download the actual files and reupload them with lfs tracking. - """ - for i in range(0, len(lfs_untracked_videos), 100): - files = lfs_untracked_videos[i : i + 100] - try: - subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True) - except subprocess.CalledProcessError as e: - print("git rm --cached ERROR:") - print(e.stderr) - subprocess.run(["git", "add", *files], cwd=work_dir, check=True) - - commit_message = "Track video files with git lfs" - subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True) - subprocess.run(["git", "push"], cwd=work_dir, check=True) - - -def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: - shutil.copyfile(clean_gittatributes, current_gittatributes) - subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) - subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) - subprocess.run(["git", "push"], cwd=work_dir, check=True) - - -def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: - subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True) - repo_url = f"https://huggingface.co/datasets/{repo_id}" - env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files - subprocess.run( - ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)], - check=True, - env=env, - ) - - -def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]: - lfs_tracked_files = subprocess.run( - ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True - ) - lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines()) - return [f for f in video_files if f not in lfs_tracked_files] - - -def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: - # Assumes first episode - video_files = [ - DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) - for vid_key in video_keys - ] - hub_api = HfApi() - hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files - ) - videos_info_dict = {} - for vid_key, vid_path in zip(video_keys, video_files, strict=True): - videos_info_dict[vid_key] = get_video_info(local_dir / vid_path) - - return videos_info_dict - - -def convert_dataset( - repo_id: str, - local_dir: Path, - single_task: str | None = None, - tasks_path: Path | None = None, - tasks_col: Path | None = None, - robot_config: RobotConfig | None = None, - test_branch: str | None = None, - **card_kwargs, -): - v1 = get_safe_version(repo_id, V16) - v1x_dir = local_dir / V16 / repo_id - v20_dir = local_dir / V20 / repo_id - v1x_dir.mkdir(parents=True, exist_ok=True) - v20_dir.mkdir(parents=True, exist_ok=True) - - hub_api = HfApi() - hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/" - ) - branch = "main" - if test_branch: - branch = test_branch - create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset") - - metadata_v1 = load_json(v1x_dir / V1_INFO_PATH) - dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train") - features = get_features_from_hf_dataset(dataset, robot_config) - video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"] - - if single_task and "language_instruction" in dataset.column_names: - logging.warning( - "'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.", - ) - single_task = None - tasks_col = "language_instruction" - - # Episodes & chunks - episode_indices = sorted(dataset.unique("episode_index")) - total_episodes = len(episode_indices) - assert episode_indices == list(range(total_episodes)) - total_videos = total_episodes * len(video_keys) - total_chunks = total_episodes // DEFAULT_CHUNK_SIZE - if total_episodes % DEFAULT_CHUNK_SIZE != 0: - total_chunks += 1 - - # Tasks - if single_task: - tasks_by_episodes = dict.fromkeys(episode_indices, single_task) - dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} - elif tasks_path: - tasks_by_episodes = load_json(tasks_path) - tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} - dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} - elif tasks_col: - dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col) - else: - raise ValueError - - assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} - tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] - write_jsonlines(tasks, v20_dir / TASKS_PATH) - features["task_index"] = { - "dtype": "int64", - "shape": (1,), - "names": None, - } - - # Videos - if video_keys: - assert metadata_v1.get("video", False) - dataset = dataset.remove_columns(video_keys) - clean_gitattr = Path( - hub_api.hf_hub_download( - repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes" - ) - ).absolute() - with tempfile.TemporaryDirectory() as tmp_video_dir: - move_videos( - repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch - ) - videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) - for key in video_keys: - features[key]["shape"] = ( - videos_info[key].pop("video.height"), - videos_info[key].pop("video.width"), - videos_info[key].pop("video.channels"), - ) - features[key]["video_info"] = videos_info[key] - assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) - if "encoding" in metadata_v1: - assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] - else: - assert metadata_v1.get("video", 0) == 0 - videos_info = None - - # Split data into 1 parquet file by episode - episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) - - if robot_config is not None: - robot_type = robot_config.type - repo_tags = [robot_type] - else: - robot_type = "unknown" - repo_tags = None - - # Episodes - episodes = [ - {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} - for ep_idx in episode_indices - ] - write_jsonlines(episodes, v20_dir / EPISODES_PATH) - - # Assemble metadata v2.0 - metadata_v2_0 = { - "codebase_version": V20, - "robot_type": robot_type, - "total_episodes": total_episodes, - "total_frames": len(dataset), - "total_tasks": len(tasks), - "total_videos": total_videos, - "total_chunks": total_chunks, - "chunks_size": DEFAULT_CHUNK_SIZE, - "fps": metadata_v1["fps"], - "splits": {"train": f"0:{total_episodes}"}, - "data_path": DEFAULT_PARQUET_PATH, - "video_path": DEFAULT_VIDEO_PATH if video_keys else None, - "features": features, - } - write_json(metadata_v2_0, v20_dir / INFO_PATH) - convert_stats_to_json(v1x_dir, v20_dir) - card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs) - - with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) - - with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) - - with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) - - hub_api.upload_folder( - repo_id=repo_id, - path_in_repo="data", - folder_path=v20_dir / "data", - repo_type="dataset", - revision=branch, - ) - hub_api.upload_folder( - repo_id=repo_id, - path_in_repo="meta", - folder_path=v20_dir / "meta", - repo_type="dataset", - revision=branch, - ) - - card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch) - - if not test_branch: - create_branch(repo_id=repo_id, branch=V20, repo_type="dataset") - - -def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: - if robot_type == "aloha": - raise NotImplementedError # TODO - - elif robot_type == "koch_follower": - from lerobot.robots.koch_follower import KochFollowerConfig - - return KochFollowerConfig(**kwargs) - elif robot_type == "so100_follower": - from lerobot.robots.so100_follower import SO100FollowerConfig - - return SO100FollowerConfig(**kwargs) - elif robot_type == "stretch": - from lerobot.robots.stretch3 import Stretch3RobotConfig - - return Stretch3RobotConfig(**kwargs) - elif robot_type == "lekiwi": - from lerobot.robots.lekiwi import LeKiwiConfig - - return LeKiwiConfig(**kwargs) - else: - raise ValueError(f"Robot type '{robot_type}' is not available.") - - -def main(): - parser = argparse.ArgumentParser() - task_args = parser.add_mutually_exclusive_group(required=True) - - parser.add_argument( - "--repo-id", - type=str, - required=True, - help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", - ) - task_args.add_argument( - "--single-task", - type=str, - help="A short but accurate description of the single task performed in the dataset.", - ) - task_args.add_argument( - "--tasks-col", - type=str, - help="The name of the column containing language instructions", - ) - task_args.add_argument( - "--tasks-path", - type=Path, - help="The path to a .json file containing one language instruction for each episode_index", - ) - parser.add_argument( - "--robot", - type=str, - default=None, - help="Robot config used for the dataset during conversion (e.g. 'koch', 'aloha', 'so100', etc.)", - ) - parser.add_argument( - "--local-dir", - type=Path, - default=None, - help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2", - ) - parser.add_argument( - "--license", - type=str, - default="apache-2.0", - help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.", - ) - parser.add_argument( - "--test-branch", - type=str, - default=None, - help="Repo branch to test your conversion first (e.g. 'v2.0.test')", - ) - - args = parser.parse_args() - if not args.local_dir: - args.local_dir = Path("/tmp/lerobot_dataset_v2") - - if args.robot is not None: - robot_config = make_robot_config(args.robot) - - del args.robot - - convert_dataset(**vars(args), robot_config=robot_config) - - -if __name__ == "__main__": - main() diff --git a/src/lerobot/datasets/v21/_remove_language_instruction.py b/src/lerobot/datasets/v21/_remove_language_instruction.py deleted file mode 100644 index 1f1cb1855..000000000 --- a/src/lerobot/datasets/v21/_remove_language_instruction.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import traceback -from pathlib import Path - -from datasets import get_dataset_config_info -from huggingface_hub import HfApi - -from lerobot import available_datasets -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.datasets.utils import INFO_PATH, write_info -from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings - -LOCAL_DIR = Path("data/") - -hub_api = HfApi() - - -def fix_dataset(repo_id: str) -> str: - if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"): - return f"{repo_id}: skipped (not in {V20})." - - dataset_info = get_dataset_config_info(repo_id, "default") - with SuppressWarnings(): - lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True) - - meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"} - parquet_features = set(dataset_info.features) - - diff_parquet_meta = parquet_features - meta_features - diff_meta_parquet = meta_features - parquet_features - - if diff_parquet_meta: - raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}") - - if not diff_meta_parquet: - return f"{repo_id}: skipped (no diff)" - - if diff_meta_parquet: - logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") - assert diff_meta_parquet == {"language_instruction"} - lerobot_metadata.features.pop("language_instruction") - write_info(lerobot_metadata.info, lerobot_metadata.root) - commit_info = hub_api.upload_file( - path_or_fileobj=lerobot_metadata.root / INFO_PATH, - path_in_repo=INFO_PATH, - repo_id=repo_id, - repo_type="dataset", - revision=V20, - commit_message="Remove 'language_instruction'", - create_pr=True, - ) - return f"{repo_id}: success - PR: {commit_info.pr_url}" - - -def batch_fix(): - status = {} - LOCAL_DIR.mkdir(parents=True, exist_ok=True) - logfile = LOCAL_DIR / "fix_features_v20.txt" - for num, repo_id in enumerate(available_datasets): - print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") - print("---------------------------------------------------------") - try: - status = fix_dataset(repo_id) - except Exception: - status = f"{repo_id}: failed\n {traceback.format_exc()}" - - logging.info(status) - with open(logfile, "a") as file: - file.write(status + "\n") - - -if __name__ == "__main__": - batch_fix() diff --git a/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py deleted file mode 100644 index b4f1c36c4..000000000 --- a/src/lerobot/datasets/v21/batch_convert_dataset_v20_to_v21.py +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1. -""" - -import traceback -from pathlib import Path - -from huggingface_hub import HfApi - -from lerobot import available_datasets -from lerobot.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset - -LOCAL_DIR = Path("data/") - - -def batch_convert(): - status = {} - LOCAL_DIR.mkdir(parents=True, exist_ok=True) - logfile = LOCAL_DIR / "conversion_log_v21.txt" - hub_api = HfApi() - for num, repo_id in enumerate(available_datasets): - print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") - print("---------------------------------------------------------") - try: - if hub_api.revision_exists(repo_id, V21, repo_type="dataset"): - status = f"{repo_id}: success (already in {V21})." - else: - convert_dataset(repo_id) - status = f"{repo_id}: success." - except Exception: - status = f"{repo_id}: failed\n {traceback.format_exc()}" - - with open(logfile, "a") as file: - file.write(status + "\n") - - -if __name__ == "__main__": - batch_convert() diff --git a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py deleted file mode 100644 index 4ebc1086a..000000000 --- a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to -2.1. It will: - -- Generate per-episodes stats and writes them in `episodes_stats.jsonl` -- Check consistency between these new stats and the old ones. -- Remove the deprecated `stats.json`. -- Update codebase_version in `info.json`. -- Push this new version to the hub on the 'main' branch and tags it with "v2.1". - -Usage: - -```bash -python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \ - --repo-id=aliberts/koch_tutorial -``` - -""" - -import argparse -import logging - -from huggingface_hub import HfApi - -from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats - -V20 = "v2.0" -V21 = "v2.1" - - -class SuppressWarnings: - def __enter__(self): - self.previous_level = logging.getLogger().getEffectiveLevel() - logging.getLogger().setLevel(logging.ERROR) - - def __exit__(self, exc_type, exc_val, exc_tb): - logging.getLogger().setLevel(self.previous_level) - - -def convert_dataset( - repo_id: str, - branch: str | None = None, - num_workers: int = 4, -): - with SuppressWarnings(): - dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) - - if (dataset.root / EPISODES_STATS_PATH).is_file(): - (dataset.root / EPISODES_STATS_PATH).unlink() - - convert_stats(dataset, num_workers=num_workers) - ref_stats = load_stats(dataset.root) - check_aggregate_stats(dataset, ref_stats) - - dataset.meta.info["codebase_version"] = CODEBASE_VERSION - write_info(dataset.meta.info, dataset.root) - - dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/") - - # delete old stats.json file - if (dataset.root / STATS_PATH).is_file: - (dataset.root / STATS_PATH).unlink() - - hub_api = HfApi() - if hub_api.file_exists( - repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" - ): - hub_api.delete_file( - path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" - ) - - hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--repo-id", - type=str, - required=True, - help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " - "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", - ) - parser.add_argument( - "--branch", - type=str, - default=None, - help="Repo branch to push your dataset. Defaults to the main branch.", - ) - parser.add_argument( - "--num-workers", - type=int, - default=4, - help="Number of workers for parallelizing stats compute. Defaults to 4.", - ) - - args = parser.parse_args() - convert_dataset(**vars(args)) diff --git a/src/lerobot/datasets/v21/convert_stats.py b/src/lerobot/datasets/v21/convert_stats.py deleted file mode 100644 index 462781c15..000000000 --- a/src/lerobot/datasets/v21/convert_stats.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from concurrent.futures import ThreadPoolExecutor, as_completed - -import numpy as np -from tqdm import tqdm - -from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import write_episode_stats - - -def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: - ep_len = dataset.meta.episodes[episode_index]["length"] - sampled_indices = sample_indices(ep_len) - query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) - video_frames = dataset._query_videos(query_timestamps, episode_index) - return video_frames[ft_key].numpy() - - -def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): - ep_start_idx = dataset.episode_data_index["from"][ep_idx] - ep_end_idx = dataset.episode_data_index["to"][ep_idx] - ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) - - ep_stats = {} - for key, ft in dataset.features.items(): - if ft["dtype"] == "video": - # We sample only for videos - ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key) - else: - ep_ft_data = np.array(ep_data[key]) - - axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 - keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 - ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) - - if ft["dtype"] in ["image", "video"]: # remove batch dim - ep_stats[key] = { - k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() - } - - dataset.meta.episodes_stats[ep_idx] = ep_stats - - -def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): - assert dataset.episodes is None - print("Computing episodes stats") - total_episodes = dataset.meta.total_episodes - if num_workers > 0: - with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = { - executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx - for ep_idx in range(total_episodes) - } - for future in tqdm(as_completed(futures), total=total_episodes): - future.result() - else: - for ep_idx in tqdm(range(total_episodes)): - convert_episode_stats(dataset, ep_idx) - - for ep_idx in tqdm(range(total_episodes)): - write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) - - -def check_aggregate_stats( - dataset: LeRobotDataset, - reference_stats: dict[str, dict[str, np.ndarray]], - video_rtol_atol: tuple[float] = (1e-2, 1e-2), - default_rtol_atol: tuple[float] = (5e-6, 6e-5), -): - """Verifies that the aggregated stats from episodes_stats are close to reference stats.""" - agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values())) - for key, ft in dataset.features.items(): - # These values might need some fine-tuning - if ft["dtype"] == "video": - # to account for image sub-sampling - rtol, atol = video_rtol_atol - else: - rtol, atol = default_rtol_atol - - for stat, val in agg_stats[key].items(): - if key in reference_stats and stat in reference_stats[key]: - err_msg = f"feature='{key}' stats='{stat}'" - np.testing.assert_allclose( - val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg - ) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py new file mode 100644 index 000000000..96bdc1897 --- /dev/null +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -0,0 +1,500 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.1 to +3.0. It will: + +- Generate per-episodes stats and writes them in `episodes_stats.jsonl` +- Check consistency between these new stats and the old ones. +- Remove the deprecated `stats.json`. +- Update codebase_version in `info.json`. +- Push this new version to the hub on the 'main' branch and tags it with "v3.0". + +Usage: + +```bash +python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ + --repo-id=lerobot/pusht +``` + +""" + +import argparse +import shutil +from pathlib import Path +from typing import Any + +import jsonlines +import pandas as pd +import pyarrow as pa +import tqdm +from datasets import Dataset, Features, Image +from huggingface_hub import HfApi, snapshot_download +from requests import HTTPError + +from lerobot.constants import HF_LEROBOT_HOME +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_PATH, + LEGACY_EPISODES_PATH, + LEGACY_EPISODES_STATS_PATH, + LEGACY_TASKS_PATH, + cast_stats_to_numpy, + flatten_dict, + get_parquet_file_size_in_mb, + get_parquet_num_frames, + get_video_size_in_mb, + load_info, + update_chunk_file_indices, + write_episodes, + write_info, + write_stats, + write_tasks, +) +from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s + +V21 = "v2.1" + + +""" +------------------------- +OLD +data/chunk-000/episode_000000.parquet + +NEW +data/chunk-000/file_000.parquet +------------------------- +OLD +videos/chunk-000/CAMERA/episode_000000.mp4 + +NEW +videos/chunk-000/file_000.mp4 +------------------------- +OLD +episodes.jsonl +{"episode_index": 1, "tasks": ["Put the blue block in the green bowl"], "length": 266} + +NEW +meta/episodes/chunk-000/episodes_000.parquet +episode_index | video_chunk_index | video_file_index | data_chunk_index | data_file_index | tasks | length +------------------------- +OLD +tasks.jsonl +{"task_index": 1, "task": "Put the blue block in the green bowl"} + +NEW +meta/tasks/chunk-000/file_000.parquet +task_index | task +------------------------- +OLD +episodes_stats.jsonl + +NEW +meta/episodes_stats/chunk-000/file_000.parquet +episode_index | mean | std | min | max +------------------------- +UPDATE +meta/info.json +------------------------- +""" + + +def load_jsonlines(fpath: Path) -> list[Any]: + with jsonlines.open(fpath, "r") as reader: + return list(reader) + + +def legacy_load_episodes(local_dir: Path) -> dict: + episodes = load_jsonlines(local_dir / LEGACY_EPISODES_PATH) + return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} + + +def legacy_load_episodes_stats(local_dir: Path) -> dict: + episodes_stats = load_jsonlines(local_dir / LEGACY_EPISODES_STATS_PATH) + return { + item["episode_index"]: cast_stats_to_numpy(item["stats"]) + for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) + } + + +def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]: + tasks = load_jsonlines(local_dir / LEGACY_TASKS_PATH) + tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + task_to_task_index = {task: task_index for task_index, task in tasks.items()} + return tasks, task_to_task_index + + +def convert_tasks(root, new_root): + tasks, _ = legacy_load_tasks(root) + task_indices = tasks.keys() + task_strings = tasks.values() + df_tasks = pd.DataFrame({"task_index": task_indices}, index=task_strings) + write_tasks(df_tasks, new_root) + + +def concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys): + # TODO(rcadene): to save RAM use Dataset.from_parquet(file) and concatenate_datasets + dataframes = [pd.read_parquet(file) for file in paths_to_cat] + # Concatenate all DataFrames along rows + concatenated_df = pd.concat(dataframes, ignore_index=True) + + path = new_root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + if len(image_keys) > 0: + schema = pa.Schema.from_pandas(concatenated_df) + features = Features.from_arrow_schema(schema) + for key in image_keys: + features[key] = Image() + schema = features.arrow_schema + else: + schema = None + + concatenated_df.to_parquet(path, index=False, schema=schema) + + +def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): + data_dir = root / "data" + ep_paths = sorted(data_dir.glob("*/*.parquet")) + + image_keys = get_image_keys(root) + + ep_idx = 0 + chunk_idx = 0 + file_idx = 0 + size_in_mb = 0 + num_frames = 0 + paths_to_cat = [] + episodes_metadata = [] + for ep_path in ep_paths: + ep_size_in_mb = get_parquet_file_size_in_mb(ep_path) + ep_num_frames = get_parquet_num_frames(ep_path) + ep_metadata = { + "episode_index": ep_idx, + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": num_frames, + "dataset_to_index": num_frames + ep_num_frames, + } + size_in_mb += ep_size_in_mb + num_frames += ep_num_frames + episodes_metadata.append(ep_metadata) + ep_idx += 1 + + if size_in_mb < data_file_size_in_mb: + paths_to_cat.append(ep_path) + continue + + if paths_to_cat: + concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys) + + # Reset for the next file + size_in_mb = ep_size_in_mb + num_frames = ep_num_frames + paths_to_cat = [ep_path] + + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + # Write remaining data if any + if paths_to_cat: + concat_data_files(paths_to_cat, new_root, chunk_idx, file_idx, image_keys) + + return episodes_metadata + + +def get_video_keys(root): + info = load_info(root) + features = info["features"] + video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"] + return video_keys + + +def get_image_keys(root): + info = load_info(root) + features = info["features"] + image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"] + return image_keys + + +def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int): + video_keys = get_video_keys(root) + if len(video_keys) == 0: + return None + + video_keys = sorted(video_keys) + + eps_metadata_per_cam = [] + for camera in video_keys: + eps_metadata = convert_videos_of_camera(root, new_root, camera, video_file_size_in_mb) + eps_metadata_per_cam.append(eps_metadata) + + num_eps_per_cam = [len(eps_cam_map) for eps_cam_map in eps_metadata_per_cam] + if len(set(num_eps_per_cam)) != 1: + raise ValueError(f"All cams dont have same number of episodes ({num_eps_per_cam}).") + + episods_metadata = [] + num_cameras = len(video_keys) + num_episodes = num_eps_per_cam[0] + for ep_idx in range(num_episodes): + # Sanity check + ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)] + ep_ids += [ep_idx] + if len(set(ep_ids)) != 1: + raise ValueError(f"All episode indices need to match ({ep_ids}).") + + ep_dict = {} + for cam_idx in range(num_cameras): + ep_dict.update(eps_metadata_per_cam[cam_idx][ep_idx]) + episods_metadata.append(ep_dict) + + return episods_metadata + + +def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_file_size_in_mb: int): + # Access old paths to mp4 + videos_dir = root / "videos" + ep_paths = sorted(videos_dir.glob(f"*/{video_key}/*.mp4")) + + ep_idx = 0 + chunk_idx = 0 + file_idx = 0 + size_in_mb = 0 + duration_in_s = 0.0 + paths_to_cat = [] + episodes_metadata = [] + for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"): + ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + # Check if adding this episode would exceed the limit + if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0: + # Size limit would be exceeded, save current accumulation WITHOUT this episode + concatenate_video_files( + paths_to_cat, + new_root + / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx), + ) + + # Update episodes metadata for the file we just saved + for i, _ in enumerate(paths_to_cat): + past_ep_idx = ep_idx - len(paths_to_cat) + i + episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx + episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx + + # Move to next file and start fresh with current episode + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + size_in_mb = 0 + duration_in_s = 0.0 + paths_to_cat = [] + + # Add current episode metadata + ep_metadata = { + "episode_index": ep_idx, + f"videos/{video_key}/chunk_index": chunk_idx, # Will be updated when file is saved + f"videos/{video_key}/file_index": file_idx, # Will be updated when file is saved + f"videos/{video_key}/from_timestamp": duration_in_s, + f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s, + } + episodes_metadata.append(ep_metadata) + + # Add current episode to accumulation + paths_to_cat.append(ep_path) + size_in_mb += ep_size_in_mb + duration_in_s += ep_duration_in_s + ep_idx += 1 + + # Write remaining videos if any + if paths_to_cat: + concatenate_video_files( + paths_to_cat, + new_root + / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx), + ) + + # Update episodes metadata for the final file + for i, _ in enumerate(paths_to_cat): + past_ep_idx = ep_idx - len(paths_to_cat) + i + episodes_metadata[past_ep_idx][f"videos/{video_key}/chunk_index"] = chunk_idx + episodes_metadata[past_ep_idx][f"videos/{video_key}/file_index"] = file_idx + + return episodes_metadata + + +def generate_episode_metadata_dict( + episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None +): + num_episodes = len(episodes_metadata) + episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values()) + episodes_stats_vals = list(episodes_stats.values()) + episodes_stats_keys = list(episodes_stats.keys()) + + for i in range(num_episodes): + ep_legacy_metadata = episodes_legacy_metadata_vals[i] + ep_metadata = episodes_metadata[i] + ep_stats = episodes_stats_vals[i] + + ep_ids_set = { + ep_legacy_metadata["episode_index"], + ep_metadata["episode_index"], + episodes_stats_keys[i], + } + + if episodes_videos is None: + ep_video = {} + else: + ep_video = episodes_videos[i] + ep_ids_set.add(ep_video["episode_index"]) + + if len(ep_ids_set) != 1: + raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).") + + ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})} + ep_dict["meta/episodes/chunk_index"] = 0 + ep_dict["meta/episodes/file_index"] = 0 + yield ep_dict + + +def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None): + episodes_legacy_metadata = legacy_load_episodes(root) + episodes_stats = legacy_load_episodes_stats(root) + + num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)} + if episodes_video_metadata is not None: + num_eps_set.add(len(episodes_video_metadata)) + + if len(num_eps_set) != 1: + raise ValueError(f"Number of episodes is not the same ({num_eps_set}).") + + ds_episodes = Dataset.from_generator( + lambda: generate_episode_metadata_dict( + episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata + ) + ) + write_episodes(ds_episodes, new_root) + + stats = aggregate_stats(list(episodes_stats.values())) + write_stats(stats, new_root) + + +def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb): + info = load_info(root) + info["codebase_version"] = "v3.0" + del info["total_chunks"] + del info["total_videos"] + info["data_files_size_in_mb"] = data_file_size_in_mb + info["video_files_size_in_mb"] = video_file_size_in_mb + info["data_path"] = DEFAULT_DATA_PATH + info["video_path"] = DEFAULT_VIDEO_PATH + info["fps"] = float(info["fps"]) + for key in info["features"]: + if info["features"][key]["dtype"] == "video": + # already has fps in video_info + continue + info["features"][key]["fps"] = info["fps"] + write_info(info, new_root) + + +def convert_dataset( + repo_id: str, + branch: str | None = None, + data_file_size_in_mb: int | None = None, + video_file_size_in_mb: int | None = None, +): + root = HF_LEROBOT_HOME / repo_id + old_root = HF_LEROBOT_HOME / f"{repo_id}_old" + new_root = HF_LEROBOT_HOME / f"{repo_id}_v30" + + if data_file_size_in_mb is None: + data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB + if video_file_size_in_mb is None: + video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB + + if old_root.is_dir() and root.is_dir(): + shutil.rmtree(str(root)) + shutil.move(str(old_root), str(root)) + + if new_root.is_dir(): + shutil.rmtree(new_root) + + snapshot_download( + repo_id, + repo_type="dataset", + revision=V21, + local_dir=root, + ) + + convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb) + convert_tasks(root, new_root) + episodes_metadata = convert_data(root, new_root, data_file_size_in_mb) + episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb) + convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata) + + shutil.move(str(root), str(old_root)) + shutil.move(str(new_root), str(root)) + + hub_api = HfApi() + try: + hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + except HTTPError as e: + print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})") + pass + hub_api.delete_files( + delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"], + repo_id=repo_id, + revision=branch, + repo_type="dataset", + ) + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + LeRobotDataset(repo_id).push_to_hub() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " + "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + ) + parser.add_argument( + "--branch", + type=str, + default=None, + help="Repo branch to push your dataset. Defaults to the main branch.", + ) + parser.add_argument( + "--data-file-size-in-mb", + type=int, + default=None, + help="File size in MB. Defaults to 100 for data and 500 for videos.", + ) + parser.add_argument( + "--video-file-size-in-mb", + type=int, + default=None, + help="File size in MB. Defaults to 100 for data and 500 for videos.", + ) + + args = parser.parse_args() + convert_dataset(**vars(args)) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index b05edf6bd..9d7df8d61 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -17,6 +17,7 @@ import glob import importlib import logging import shutil +import tempfile import warnings from dataclasses import dataclass, field from pathlib import Path @@ -263,7 +264,11 @@ def encode_video_frames( video_path = Path(video_path) imgs_dir = Path(imgs_dir) - video_path.parent.mkdir(parents=True, exist_ok=overwrite) + if video_path.exists() and not overwrite: + logging.warning(f"Video file already exists: {video_path}. Skipping encoding.") + return + + video_path.parent.mkdir(parents=True, exist_ok=True) # Encoders/pixel formats incompatibility check if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p": @@ -273,9 +278,9 @@ def encode_video_frames( pix_fmt = "yuv420p" # Get input frames - template = "frame_" + ("[0-9]" * 6) + ".png" + template = "frame-" + ("[0-9]" * 6) + ".png" input_list = sorted( - glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("_")[-1].split(".")[0]) + glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0]) ) # Define video output frame size (assuming all input frames are the same size) @@ -300,7 +305,7 @@ def encode_video_frames( # Set logging level if log_level is not None: - # "While less efficient, it is generally preferable to modify logging with Python’s logging" + # "While less efficient, it is generally preferable to modify logging with Python's logging" logging.getLogger("libav").setLevel(log_level) # Create and open output file (overwrite by default) @@ -331,6 +336,89 @@ def encode_video_frames( raise OSError(f"Video encoding did not work. File not found: {video_path}.") +def concatenate_video_files( + input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True +): + """ + Concatenate multiple video files into a single video file using pyav. + + This function takes a list of video input file paths and concatenates them into a single + output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast + concatenation without re-encoding. + + Args: + input_video_paths: Ordered list of input video file paths to concatenate. + output_video_path: Path to the output video file. + overwrite: Whether to overwrite the output video file if it already exists. Default is True. + + Note: + - Creates a temporary directory for intermediate files that is cleaned up after use. + - Uses ffmpeg's concat demuxer which requires all input videos to have the same + codec, resolution, and frame rate for proper concatenation. + """ + + output_video_path = Path(output_video_path) + + if output_video_path.exists() and not overwrite: + logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") + return + + output_video_path.parent.mkdir(parents=True, exist_ok=True) + + if len(input_video_paths) == 0: + raise FileNotFoundError("No input video paths provided.") + + # Create a temporary .ffconcat file to list the input video paths + with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file: + tmp_concatenate_file.write("ffconcat version 1.0\n") + for input_path in input_video_paths: + tmp_concatenate_file.write(f"file '{str(input_path)}'\n") + tmp_concatenate_file.flush() + tmp_concatenate_path = tmp_concatenate_file.name + + # Create input and output containers + input_container = av.open( + tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"} + ) # safe = 0 allows absolute paths as well as relative paths + + tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name + output_container = av.open( + tmp_output_video_path, mode="w", options={"movflags": "faststart"} + ) # faststart is to move the metadata to the beginning of the file to speed up loading + + # Replicate input streams in output container + stream_map = {} + for input_stream in input_container.streams: + if input_stream.type in ("video", "audio", "subtitle"): # only copy compatible streams + stream_map[input_stream.index] = output_container.add_stream_from_template( + template=input_stream, opaque=True + ) + stream_map[ + input_stream.index + ].time_base = ( + input_stream.time_base + ) # set the time base to the input stream time base (missing in the codec context) + + # Demux + remux packets (no re-encode) + for packet in input_container.demux(): + # Skip packets from un-mapped streams + if packet.stream.index not in stream_map: + continue + + # Skip demux flushing packets + if packet.dts is None: + continue + + output_stream = stream_map[packet.stream.index] + packet.stream = output_stream + output_container.mux(packet) + + input_container.close() + output_container.close() + shutil.move(tmp_output_video_path, output_video_path) + Path(tmp_concatenate_path).unlink() + + @dataclass class VideoFrame: # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo @@ -454,6 +542,28 @@ def get_image_pixel_channels(image: Image): raise ValueError("Unknown format") +def get_video_duration_in_s(video_path: Path | str) -> float: + """ + Get the duration of a video file in seconds using PyAV. + + Args: + video_path: Path to the video file. + + Returns: + Duration of the video in seconds. + """ + with av.open(str(video_path)) as container: + # Get the first video stream + video_stream = container.streams.video[0] + # Calculate duration: stream.duration * stream.time_base gives duration in seconds + if video_stream.duration is not None: + duration = float(video_stream.duration * video_stream.time_base) + else: + # Fallback to container duration if stream duration is not available + duration = float(container.duration / av.time_base) + return duration + + class VideoEncodingManager: """ Context manager that ensures proper video encoding and data cleanup even if exceptions occur. @@ -487,7 +597,7 @@ class VideoEncodingManager: f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " f"from episode {start_ep} to {end_ep - 1}" ) - self.dataset.batch_encode_videos(start_ep, end_ep) + self.dataset._batch_save_episode_video(start_ep, end_ep) # Clean up episode images if recording was interrupted if exc_type is not None: diff --git a/src/lerobot/record.py b/src/lerobot/record.py index de397bb84..f39a05fb5 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -279,8 +279,8 @@ def record_loop( if dataset is not None: action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") - frame = {**observation_frame, **action_frame} - dataset.add_frame(frame, task=single_task) + frame = {**observation_frame, **action_frame, "task": single_task} + dataset.add_frame(frame) if display_data: log_rerun_data(observation, action) diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index 603aa93ea..cd76d114e 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -93,11 +93,15 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - actions = dataset.hf_dataset.select_columns("action") + + # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 + episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode) + actions = episode_frames.select_columns("action") + robot.connect() log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): + for idx in range(len(episode_frames)): start_episode_t = time.perf_counter() action_array = actions[idx]["action"] diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index bbc9f7223..5b57d61f5 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -115,11 +115,11 @@ If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you c echo ${HF_USER}/aloha_test ``` -If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: +If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun): ```bash -python -m lerobot.scripts.visualize_dataset_html \ - --repo-id ${HF_USER}/aloha_test +python -m lerobot.scripts.visualize_dataset \ + --repo-id ${HF_USER}/aloha_test --episode 0 ``` ## Replay an episode diff --git a/src/lerobot/scripts/rl/crop_dataset_roi.py b/src/lerobot/scripts/rl/crop_dataset_roi.py index 69904b740..c4318c415 100644 --- a/src/lerobot/scripts/rl/crop_dataset_roi.py +++ b/src/lerobot/scripts/rl/crop_dataset_roi.py @@ -226,7 +226,8 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( value = value.unsqueeze(0) new_frame[key] = value - new_dataset.add_frame(new_frame, task=task) + new_frame["task"] = task + new_dataset.add_frame(new_frame) if frame["episode_index"].item() != prev_episode_index: # Save the episode diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index c8be6b7dd..046be03e8 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -2129,7 +2129,8 @@ def record_dataset(env, policy, cfg): frame["complementary_info.discrete_penalty"] = torch.tensor( [info.get("discrete_penalty", 0.0)], dtype=torch.float32 ) - dataset.add_frame(frame, task=cfg.task) + frame["task"] = cfg.task + dataset.add_frame(frame) # Maintain consistent timing if cfg.fps: diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 235352cd8..ba3db6075 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -166,7 +166,8 @@ def train(cfg: TrainPipelineConfig): if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False sampler = EpisodeAwareSampler( - dataset.episode_data_index, + dataset.meta.episodes["dataset_from_index"], + dataset.meta.episodes["dataset_to_index"], drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, ) diff --git a/src/lerobot/scripts/visualize_dataset.py b/src/lerobot/scripts/visualize_dataset.py index 51ead0dd1..dda12594a 100644 --- a/src/lerobot/scripts/visualize_dataset.py +++ b/src/lerobot/scripts/visualize_dataset.py @@ -79,8 +79,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset class EpisodeSampler(torch.utils.data.Sampler): def __init__(self, dataset: LeRobotDataset, episode_index: int): - from_idx = dataset.episode_data_index["from"][episode_index].item() - to_idx = dataset.episode_data_index["to"][episode_index].item() + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] self.frame_ids = range(from_idx, to_idx) def __iter__(self) -> Iterator: @@ -283,7 +283,7 @@ def main(): tolerance_s = kwargs.pop("tolerance_s") logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) + dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) visualize_dataset(dataset, **vars(args)) diff --git a/src/lerobot/scripts/visualize_dataset_html.py b/src/lerobot/scripts/visualize_dataset_html.py deleted file mode 100644 index a722da603..000000000 --- a/src/lerobot/scripts/visualize_dataset_html.py +++ /dev/null @@ -1,482 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Visualize data of **all** frames of any episode of a dataset of type LeRobotDataset. - -Note: The last frame of the episode doesnt always correspond to a final state. -That's because our datasets are composed of transition from state to state up to -the antepenultimate state associated to the ultimate action to arrive in the final state. -However, there might not be a transition from a final state to another state. - -Note: This script aims to visualize the data used to train the neural networks. -~What you see is what you get~. When visualizing image modality, it is often expected to observe -lossly compression artifacts since these images have been decoded from compressed mp4 videos to -save disk space. The compression factor applied has been tuned to not affect success rate. - -Example of usage: - -- Visualize data stored on a local machine: -```bash -local$ python -m lerobot.scripts.visualize_dataset_html \ - --repo-id lerobot/pusht - -local$ open http://localhost:9090 -``` - -- Visualize data stored on a distant machine with a local viewer: -```bash -distant$ python -m lerobot.scripts.visualize_dataset_html \ - --repo-id lerobot/pusht - -local$ ssh -L 9090:localhost:9090 distant # create a ssh tunnel -local$ open http://localhost:9090 -``` - -- Select episodes to visualize: -```bash -python -m lerobot.scripts.visualize_dataset_html \ - --repo-id lerobot/pusht \ - --episodes 7 3 5 1 4 -``` -""" - -import argparse -import csv -import json -import logging -import re -import shutil -import tempfile -from io import StringIO -from pathlib import Path - -import numpy as np -import pandas as pd -import requests -from flask import Flask, redirect, render_template, request, url_for - -from lerobot import available_datasets -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import IterableNamespace -from lerobot.utils.utils import init_logging - - -def run_server( - dataset: LeRobotDataset | IterableNamespace | None, - episodes: list[int] | None, - host: str, - port: str, - static_folder: Path, - template_folder: Path, -): - app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve()) - app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache - - @app.route("/") - def hommepage(dataset=dataset): - if dataset: - dataset_namespace, dataset_name = dataset.repo_id.split("/") - return redirect( - url_for( - "show_episode", - dataset_namespace=dataset_namespace, - dataset_name=dataset_name, - episode_id=0, - ) - ) - - dataset_param, episode_param = None, None - all_params = request.args - if "dataset" in all_params: - dataset_param = all_params["dataset"] - if "episode" in all_params: - episode_param = int(all_params["episode"]) - - if dataset_param: - dataset_namespace, dataset_name = dataset_param.split("/") - return redirect( - url_for( - "show_episode", - dataset_namespace=dataset_namespace, - dataset_name=dataset_name, - episode_id=episode_param if episode_param is not None else 0, - ) - ) - - featured_datasets = [ - "lerobot/aloha_static_cups_open", - "lerobot/columbia_cairlab_pusht_real", - "lerobot/taco_play", - ] - return render_template( - "visualize_dataset_homepage.html", - featured_datasets=featured_datasets, - lerobot_datasets=available_datasets, - ) - - @app.route("//") - def show_first_episode(dataset_namespace, dataset_name): - first_episode_id = 0 - return redirect( - url_for( - "show_episode", - dataset_namespace=dataset_namespace, - dataset_name=dataset_name, - episode_id=first_episode_id, - ) - ) - - @app.route("///episode_") - def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes): - repo_id = f"{dataset_namespace}/{dataset_name}" - try: - if dataset is None: - dataset = get_dataset_info(repo_id) - except FileNotFoundError: - return ( - "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461", - 400, - ) - dataset_version = ( - str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version - ) - match = re.search(r"v(\d+)\.", dataset_version) - if match: - major_version = int(match.group(1)) - if major_version < 2: - return "Make sure to convert your LeRobotDataset to v2 & above." - - episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id) - dataset_info = { - "repo_id": f"{dataset_namespace}/{dataset_name}", - "num_samples": dataset.num_frames - if isinstance(dataset, LeRobotDataset) - else dataset.total_frames, - "num_episodes": dataset.num_episodes - if isinstance(dataset, LeRobotDataset) - else dataset.total_episodes, - "fps": dataset.fps, - } - if isinstance(dataset, LeRobotDataset): - video_paths = [ - dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys - ] - videos_info = [ - { - "url": url_for("static", filename=str(video_path).replace("\\", "/")), - "filename": video_path.parent.name, - } - for video_path in video_paths - ] - tasks = dataset.meta.episodes[episode_id]["tasks"] - else: - video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"] - videos_info = [ - { - "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/" - + dataset.video_path.format( - episode_chunk=int(episode_id) // dataset.chunks_size, - video_key=video_key, - episode_index=episode_id, - ), - "filename": video_key, - } - for video_key in video_keys - ] - - response = requests.get( - f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5 - ) - response.raise_for_status() - # Split into lines and parse each line as JSON - tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()] - - filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id] - tasks = filtered_tasks_jsonl[0]["tasks"] - - videos_info[0]["language_instruction"] = tasks - - if episodes is None: - episodes = list( - range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes) - ) - - return render_template( - "visualize_dataset_template.html", - episode_id=episode_id, - episodes=episodes, - dataset_info=dataset_info, - videos_info=videos_info, - episode_data_csv_str=episode_data_csv_str, - columns=columns, - ignored_columns=ignored_columns, - ) - - app.run(host=host, port=port) - - -def get_ep_csv_fname(episode_id: int): - ep_csv_fname = f"episode_{episode_id}.csv" - return ep_csv_fname - - -def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index): - """Get a csv str containing timeseries data of an episode (e.g. state and action). - This file will be loaded by Dygraph javascript to plot data in real time.""" - columns = [] - - selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]] - selected_columns.remove("timestamp") - - ignored_columns = [] - for column_name in selected_columns: - shape = dataset.features[column_name]["shape"] - shape_dim = len(shape) - if shape_dim > 1: - selected_columns.remove(column_name) - ignored_columns.append(column_name) - - # init header of csv with state and action names - header = ["timestamp"] - - for column_name in selected_columns: - dim_state = ( - dataset.meta.shapes[column_name][0] - if isinstance(dataset, LeRobotDataset) - else dataset.features[column_name].shape[0] - ) - - if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]: - column_names = dataset.features[column_name]["names"] - while not isinstance(column_names, list): - column_names = list(column_names.values())[0] - else: - column_names = [f"{column_name}_{i}" for i in range(dim_state)] - columns.append({"key": column_name, "value": column_names}) - - header += column_names - - selected_columns.insert(0, "timestamp") - - if isinstance(dataset, LeRobotDataset): - from_idx = dataset.episode_data_index["from"][episode_index] - to_idx = dataset.episode_data_index["to"][episode_index] - data = ( - dataset.hf_dataset.select(range(from_idx, to_idx)) - .select_columns(selected_columns) - .with_format("pandas") - ) - else: - repo_id = dataset.repo_id - - url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format( - episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index - ) - df = pd.read_parquet(url) - data = df[selected_columns] # Select specific columns - - rows = np.hstack( - ( - np.expand_dims(data["timestamp"], axis=1), - *[np.vstack(data[col]) for col in selected_columns[1:]], - ) - ).tolist() - - # Convert data to CSV string - csv_buffer = StringIO() - csv_writer = csv.writer(csv_buffer) - # Write header - csv_writer.writerow(header) - # Write data rows - csv_writer.writerows(rows) - csv_string = csv_buffer.getvalue() - - return csv_string, columns, ignored_columns - - -def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]: - # get first frame of episode (hack to get video_path of the episode) - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() - return [ - dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] - for key in dataset.meta.video_keys - ] - - -def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]: - # check if the dataset has language instructions - if "language_instruction" not in dataset.features: - return None - - # get first frame index - first_frame_idx = dataset.episode_data_index["from"][ep_index].item() - - language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"] - # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored - # with the tf.tensor appearing in the string - return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)") - - -def get_dataset_info(repo_id: str) -> IterableNamespace: - response = requests.get( - f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5 - ) - response.raise_for_status() # Raises an HTTPError for bad responses - dataset_info = response.json() - dataset_info["repo_id"] = repo_id - return IterableNamespace(dataset_info) - - -def visualize_dataset_html( - dataset: LeRobotDataset | None, - episodes: list[int] | None = None, - output_dir: Path | None = None, - serve: bool = True, - host: str = "127.0.0.1", - port: int = 9090, - force_override: bool = False, -) -> Path | None: - init_logging() - - template_dir = Path(__file__).resolve().parent.parent / "templates" - - if output_dir is None: - # Create a temporary directory that will be automatically cleaned up - output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_") - - output_dir = Path(output_dir) - if output_dir.exists(): - if force_override: - shutil.rmtree(output_dir) - else: - logging.info(f"Output directory already exists. Loading from it: '{output_dir}'") - - output_dir.mkdir(parents=True, exist_ok=True) - - static_dir = output_dir / "static" - static_dir.mkdir(parents=True, exist_ok=True) - - if dataset is None: - if serve: - run_server( - dataset=None, - episodes=None, - host=host, - port=port, - static_folder=static_dir, - template_folder=template_dir, - ) - else: - # Create a simlink from the dataset video folder containing mp4 files to the output directory - # so that the http server can get access to the mp4 files. - if isinstance(dataset, LeRobotDataset): - ln_videos_dir = static_dir / "videos" - if not ln_videos_dir.exists(): - ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix()) - - if serve: - run_server(dataset, episodes, host, port, static_dir, template_dir) - - -def main(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--repo-id", - type=str, - default=None, - help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", - ) - parser.add_argument( - "--root", - type=Path, - default=None, - help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.", - ) - parser.add_argument( - "--load-from-hf-hub", - type=int, - default=0, - help="Load videos and parquet files from HF Hub rather than local system.", - ) - parser.add_argument( - "--episodes", - type=int, - nargs="*", - default=None, - help="Episode indices to visualize (e.g. `0 1 5 6` to load episodes of index 0, 1, 5 and 6). By default loads all episodes.", - ) - parser.add_argument( - "--output-dir", - type=Path, - default=None, - help="Directory path to write html files and kickoff a web server. By default write them to 'outputs/visualize_dataset/REPO_ID'.", - ) - parser.add_argument( - "--serve", - type=int, - default=1, - help="Launch web server.", - ) - parser.add_argument( - "--host", - type=str, - default="127.0.0.1", - help="Web host used by the http server.", - ) - parser.add_argument( - "--port", - type=int, - default=9090, - help="Web port used by the http server.", - ) - parser.add_argument( - "--force-override", - type=int, - default=0, - help="Delete the output directory if it exists already.", - ) - - parser.add_argument( - "--tolerance-s", - type=float, - default=1e-4, - help=( - "Tolerance in seconds used to ensure data timestamps respect the dataset fps value" - "This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument" - "If not given, defaults to 1e-4." - ), - ) - - args = parser.parse_args() - kwargs = vars(args) - repo_id = kwargs.pop("repo_id") - load_from_hf_hub = kwargs.pop("load_from_hf_hub") - root = kwargs.pop("root") - tolerance_s = kwargs.pop("tolerance_s") - - dataset = None - if repo_id: - dataset = ( - LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s) - if not load_from_hf_hub - else get_dataset_info(repo_id) - ) - - visualize_dataset_html(dataset, **vars(args)) - - -if __name__ == "__main__": - main() diff --git a/src/lerobot/templates/visualize_dataset_homepage.html b/src/lerobot/templates/visualize_dataset_homepage.html deleted file mode 100644 index 19613afb5..000000000 --- a/src/lerobot/templates/visualize_dataset_homepage.html +++ /dev/null @@ -1,68 +0,0 @@ - - - - - - Interactive Video Background Page - - - - -

- -
-
-
-
-

LeRobot Dataset Visualizer

- - create & train your own robots - -

-
-

Example Datasets:

-
    - {% for dataset in featured_datasets %} -
  • {{ dataset }}
  • - {% endfor %} -
-
-
-
- - -
- -
- More example datasets -
    - {% for dataset in lerobot_datasets %} -
  • {{ dataset }}
  • - {% endfor %} -
-
-
- - diff --git a/src/lerobot/templates/visualize_dataset_template.html b/src/lerobot/templates/visualize_dataset_template.html deleted file mode 100644 index cf9d40f1d..000000000 --- a/src/lerobot/templates/visualize_dataset_template.html +++ /dev/null @@ -1,546 +0,0 @@ - - - - - - - - - - - {{ dataset_info.repo_id }} episode {{ episode_id }} - - - - - - - -
- - -

{{ dataset_info.repo_id }}

-
- -
    -
  • - Number of samples/frames: {{ dataset_info.num_samples }} -
  • -
  • - Number of episodes: {{ dataset_info.num_episodes }} -
  • -
  • - Frames per second: {{ dataset_info.fps }} -
  • -
- -

Episodes:

- - - - -
- -
- -
- -
- -
- - - - - -
-

- Episode {{ episode_id }} -

- - - - - -
-
- filter videos -
🔽
-
- -
-
- -
-
-
- -
- {% for video_info in videos_info %} -
-

{{ video_info.filename }}

- -
- {% endfor %} -
- - - {% if videos_info[0].language_instruction %} -

- Language Instruction: {{ videos_info[0].language_instruction }} -

- {% endif %} - - - - - -
- - - - - - -
0:00 / - 0:00 -
-
- - -
-
-
-
-

- Time: 0.00s -

-
- -
- - - - - - - - - - -
- - - - {% if ignored_columns|length > 0 %} -
- Columns {{ ignored_columns }} are NOT shown since the visualizer currently does not support 2D or 3D data. -
- {% endif %} -
- -
-
- - - - - - - - - diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/utils/buffer.py index d9ffa899c..c65801896 100644 --- a/src/lerobot/utils/buffer.py +++ b/src/lerobot/utils/buffer.py @@ -565,10 +565,7 @@ class ReplayBuffer: lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) # Convert transitions into episodes and frames - episode_index = 0 - lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) - frame_idx_in_episode = 0 for idx in range(self.size): actual_idx = (self.position - self.size + idx) % self.capacity @@ -582,6 +579,7 @@ class ReplayBuffer: frame_dict["action"] = self.actions[actual_idx].cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + frame_dict["task"] = task_name # Add complementary_info if available if self.has_complementary_info: @@ -597,19 +595,11 @@ class ReplayBuffer: frame_dict[f"complementary_info.{key}"] = val # Add to the dataset's buffer - lerobot_dataset.add_frame(frame_dict, task=task_name) - - # Move to next frame - frame_idx_in_episode += 1 + lerobot_dataset.add_frame(frame_dict) # If we reached an episode boundary, call save_episode, reset counters if self.dones[actual_idx] or self.truncateds[actual_idx]: lerobot_dataset.save_episode() - episode_index += 1 - frame_idx_in_episode = 0 - lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( - episode_index=episode_index - ) # Save any remaining frames in the buffer if lerobot_dataset.episode_buffer["size"] > 0: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 6e13646b0..107606fda 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -274,6 +274,16 @@ def move_cursor_up(lines): print(f"\033[{lines}A", end="") +def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): + days = int(elapsed_time_s // (24 * 3600)) + elapsed_time_s %= 24 * 3600 + hours = int(elapsed_time_s // 3600) + elapsed_time_s %= 3600 + minutes = int(elapsed_time_s // 60) + seconds = elapsed_time_s % 60 + return days, hours, minutes, seconds + + class TimerManager: """ Lightweight utility to measure elapsed time. diff --git a/tests/artifacts/datasets/save_dataset_to_safetensors.py b/tests/artifacts/datasets/save_dataset_to_safetensors.py index 419961b20..3df42f35c 100644 --- a/tests/artifacts/datasets/save_dataset_to_safetensors.py +++ b/tests/artifacts/datasets/save_dataset_to_safetensors.py @@ -47,38 +47,22 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"): ) # save 2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() + i = dataset.meta.episodes["dataset_from_index"][0] save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + (dataset.meta.episodes["dataset_to_index"][0] - dataset.meta.episodes["dataset_from_index"][0]) / 2 + ) save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors") # save 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() + i = dataset.meta.episodes["dataset_to_index"][0] save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors") save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors") - # TODO(rcadene): Enable testing on second and last episode - # We currently cant because our test dataset only contains the first episode - - # # save 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() - # save_file(dataset[i], repo_dir / f"frame_{i}.safetensors") - # save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors") - - # # save 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() - # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") - # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") - - # # save 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() - # save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors") - # save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors") - if __name__ == "__main__": for dataset in [ diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py new file mode 100644 index 000000000..4f316f80e --- /dev/null +++ b/tests/datasets/test_aggregate.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import torch + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from tests.fixtures.constants import DUMMY_REPO_ID + + +def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames): + """Test that total number of episodes and frames are correctly aggregated.""" + assert aggr_ds.num_episodes == expected_episodes, ( + f"Expected {expected_episodes} episodes, got {aggr_ds.num_episodes}" + ) + assert aggr_ds.num_frames == expected_frames, ( + f"Expected {expected_frames} frames, got {aggr_ds.num_frames}" + ) + + +def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1): + """Test that the content of both datasets is preserved correctly in the aggregated dataset.""" + keys_to_ignore = ["episode_index", "index", "timestamp"] + + # Test first part of dataset corresponds to ds_0, check first item (index 0) matches ds_0[0] + aggr_first_item = aggr_ds[0] + ds_0_first_item = ds_0[0] + + # Compare all keys except episode_index and index which should be updated + for key in ds_0_first_item: + if key not in keys_to_ignore: + # Handle both tensor and non-tensor data + if torch.is_tensor(aggr_first_item[key]) and torch.is_tensor(ds_0_first_item[key]): + assert torch.allclose(aggr_first_item[key], ds_0_first_item[key], atol=1e-6), ( + f"First item key '{key}' doesn't match between aggregated and ds_0" + ) + else: + assert aggr_first_item[key] == ds_0_first_item[key], ( + f"First item key '{key}' doesn't match between aggregated and ds_0" + ) + + # Check last item of ds_0 part (index len(ds_0)-1) matches ds_0[-1] + aggr_ds_0_last_item = aggr_ds[len(ds_0) - 1] + ds_0_last_item = ds_0[-1] + + for key in ds_0_last_item: + if key not in keys_to_ignore: + # Handle both tensor and non-tensor data + if torch.is_tensor(aggr_ds_0_last_item[key]) and torch.is_tensor(ds_0_last_item[key]): + assert torch.allclose(aggr_ds_0_last_item[key], ds_0_last_item[key], atol=1e-6), ( + f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0" + ) + else: + assert aggr_ds_0_last_item[key] == ds_0_last_item[key], ( + f"Last ds_0 item key '{key}' doesn't match between aggregated and ds_0" + ) + + # Test second part of dataset corresponds to ds_1 + # Check first item of ds_1 part (index len(ds_0)) matches ds_1[0] + aggr_ds_1_first_item = aggr_ds[len(ds_0)] + ds_1_first_item = ds_1[0] + + for key in ds_1_first_item: + if key not in keys_to_ignore: + # Handle both tensor and non-tensor data + if torch.is_tensor(aggr_ds_1_first_item[key]) and torch.is_tensor(ds_1_first_item[key]): + assert torch.allclose(aggr_ds_1_first_item[key], ds_1_first_item[key], atol=1e-6), ( + f"First ds_1 item key '{key}' doesn't match between aggregated and ds_1" + ) + else: + assert aggr_ds_1_first_item[key] == ds_1_first_item[key], ( + f"First ds_1 item key '{key}' doesn't match between aggregated and ds_1" + ) + + # Check last item matches ds_1[-1] + aggr_last_item = aggr_ds[-1] + ds_1_last_item = ds_1[-1] + + for key in ds_1_last_item: + if key not in keys_to_ignore: + # Handle both tensor and non-tensor data + if torch.is_tensor(aggr_last_item[key]) and torch.is_tensor(ds_1_last_item[key]): + assert torch.allclose(aggr_last_item[key], ds_1_last_item[key], atol=1e-6), ( + f"Last item key '{key}' doesn't match between aggregated and ds_1" + ) + else: + assert aggr_last_item[key] == ds_1_last_item[key], ( + f"Last item key '{key}' doesn't match between aggregated and ds_1" + ) + + +def assert_metadata_consistency(aggr_ds, ds_0, ds_1): + """Test that metadata is correctly aggregated.""" + # Test basic info + assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets" + assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], ( + "Robot type should be the same" + ) + + # Test features are the same + assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same" + + # Test tasks aggregation + expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index) + actual_tasks = set(aggr_ds.meta.tasks.index) + assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}" + + +def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1): + """Test that episode indices are correctly updated after aggregation.""" + # ds_0 episodes should have episode_index 0 to ds_0.num_episodes-1 + for i in range(len(ds_0)): + assert aggr_ds[i]["episode_index"] < ds_0.num_episodes, ( + f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be < {ds_0.num_episodes}" + ) + + def ds1_episodes_condition(ep_idx): + return (ep_idx >= ds_0.num_episodes) and (ep_idx < ds_0.num_episodes + ds_1.num_episodes) + + # ds_1 episodes should have episode_index ds_0.num_episodes to total_episodes-1 + for i in range(len(ds_0), len(ds_0) + len(ds_1)): + expected_min_episode_idx = ds_0.num_episodes + assert ds1_episodes_condition(aggr_ds[i]["episode_index"]), ( + f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be >= {expected_min_episode_idx}" + ) + + +def assert_video_frames_integrity(aggr_ds, ds_0, ds_1): + """Test that video frames are correctly preserved and frame indices are updated.""" + + def visual_frames_equal(frame1, frame2): + return torch.allclose(frame1, frame2) + + video_keys = list( + filter( + lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video", + aggr_ds.meta.info["features"].keys(), + ) + ) + + # Test the section corresponding to the first dataset (ds_0) + for i in range(len(ds_0)): + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in video_keys: + assert visual_frames_equal(aggr_ds[i][key], ds_0[i][key]), ( + f"Visual frames at position {i} should be equal between aggregated and ds_0" + ) + + # Test the section corresponding to the second dataset (ds_1) + for i in range(len(ds_0), len(ds_0) + len(ds_1)): + # The frame index in the aggregated dataset should also match its position. + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in video_keys: + assert visual_frames_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), ( + f"Visual frames at position {i} should be equal between aggregated and ds_1" + ) + + +def assert_dataset_iteration_works(aggr_ds): + """Test that we can iterate through the entire dataset without errors.""" + for _ in aggr_ds: + pass + + +def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): + """Test basic aggregation functionality with standard parameters.""" + ds_0_num_frames = 400 + ds_1_num_frames = 800 + ds_0_num_episodes = 10 + ds_1_num_episodes = 25 + + # Create two datasets with different number of frames and episodes + ds_0 = lerobot_dataset_factory( + root=tmp_path / "test_0", + repo_id=f"{DUMMY_REPO_ID}_0", + total_episodes=ds_0_num_episodes, + total_frames=ds_0_num_frames, + ) + ds_1 = lerobot_dataset_factory( + root=tmp_path / "test_1", + repo_id=f"{DUMMY_REPO_ID}_1", + total_episodes=ds_1_num_episodes, + total_frames=ds_1_num_frames, + ) + + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_aggr", + aggr_root=tmp_path / "test_aggr", + ) + + # Mock the revision to prevent Hub calls during dataset loading + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "test_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr") + + # Run all assertion functions + expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes + expected_total_frames = ds_0.num_frames + ds_1.num_frames + + assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames) + assert_dataset_content_integrity(aggr_ds, ds_0, ds_1) + assert_metadata_consistency(aggr_ds, ds_0, ds_1) + assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) + assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_dataset_iteration_works(aggr_ds) + + +def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): + """Test aggregation with small file size limits to force file rotation/sharding.""" + ds_0_num_episodes = ds_1_num_episodes = 10 + ds_0_num_frames = ds_1_num_frames = 400 + + ds_0 = lerobot_dataset_factory( + root=tmp_path / "small_0", + repo_id=f"{DUMMY_REPO_ID}_small_0", + total_episodes=ds_0_num_episodes, + total_frames=ds_0_num_frames, + ) + ds_1 = lerobot_dataset_factory( + root=tmp_path / "small_1", + repo_id=f"{DUMMY_REPO_ID}_small_1", + total_episodes=ds_1_num_episodes, + total_frames=ds_1_num_frames, + ) + + # Use the new configurable parameters to force file rotation + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_small_aggr", + aggr_root=tmp_path / "small_aggr", + # Tiny file size to trigger new file instantiation + data_files_size_in_mb=0.01, + video_files_size_in_mb=0.1, + ) + + # Mock the revision to prevent Hub calls during dataset loading + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "small_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr") + + # Verify aggregation worked correctly despite file size constraints + expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes + expected_total_frames = ds_0_num_frames + ds_1_num_frames + + assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames) + assert_dataset_content_integrity(aggr_ds, ds_0, ds_1) + assert_metadata_consistency(aggr_ds, ds_0, ds_1) + assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) + assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_dataset_iteration_works(aggr_ds) + + # Check that multiple files were actually created due to small size limits + data_dir = tmp_path / "small_aggr" / "data" + video_dir = tmp_path / "small_aggr" / "videos" + + if data_dir.exists(): + parquet_files = list(data_dir.rglob("*.parquet")) + assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files" + + if video_dir.exists(): + video_files = list(video_dir.rglob("*.mp4")) + assert len(video_files) > 1, "Small file size limits should create multiple video files" diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d3b78ddcc..2eca82346 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -13,10 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import re -from copy import deepcopy from itertools import chain from pathlib import Path @@ -37,13 +35,19 @@ from lerobot.datasets.lerobot_dataset import ( MultiLeRobotDataset, ) from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, create_branch, - flatten_dict, - unflatten_dict, + get_hf_features_from_features, + hf_transform_to_torch, + hw_to_dataset_features, ) from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config +from lerobot.robots import make_robot_from_config from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID +from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -69,12 +73,17 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): objects have the same sets of attributes defined. """ # Instantiate both ways - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + robot = make_robot_from_config(MockRobotConfig()) + action_features = hw_to_dataset_features(robot.action_features, "action", True) + obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create) + dataset_create = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create + ) root_init = tmp_path / "init" - dataset_init = lerobot_dataset_factory(root=root_init) + dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1) init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) @@ -99,13 +108,41 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory): assert dataset.num_frames == len(dataset) +# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create +# and test the small resulting function that validates the features +def test_dataset_feature_with_forward_slash_raises_error(): + # make sure dir does not exist + from lerobot.constants import HF_LEROBOT_HOME + + dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" + # make sure does not exist + if dataset_dir.exists(): + dataset_dir.rmdir() + + with pytest.raises(ValueError): + LeRobotDataset.create( + repo_id="lerobot/test/with/slash", + fps=30, + features={"a/b": {"dtype": "float32", "shape": 2, "names": None}}, + ) + + +def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" + ): + dataset.add_frame({"state": torch.randn(1)}) + + def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" ): - dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task") + dataset.add_frame({"task": "Dummy task"}) def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): @@ -114,7 +151,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" ): - dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task") + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): @@ -123,7 +160,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" ): - dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): @@ -133,7 +170,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): ValueError, match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), ): - dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): @@ -145,7 +182,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": 1.0}, task="Dummy task") + dataset.add_frame({"state": 1.0, "task": "Dummy task"}) def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -155,7 +192,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact ValueError, match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), ): - dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task") + dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -167,13 +204,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task") + dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"}) def test_add_frame(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) dataset.save_episode() assert len(dataset) == 1 @@ -185,7 +222,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2]) @@ -194,7 +231,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -203,7 +240,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -212,7 +249,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -221,7 +258,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task") + dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -230,7 +267,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task") + dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["state"].ndim == 0 @@ -239,7 +276,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task") + dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["caption"] == "Dummy caption" @@ -254,7 +291,7 @@ def test_add_frame_image_wrong_shape(image_dataset): ), ): c, h, w = DUMMY_CHW - dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task") + dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"}) def test_add_frame_image_wrong_range(image_dataset): @@ -267,14 +304,14 @@ def test_add_frame_image_wrong_range(image_dataset): Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`. """ dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task") + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) with pytest.raises(FileNotFoundError): dataset.save_episode() def test_add_frame_image(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task") + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -282,7 +319,7 @@ def test_add_frame_image(image_dataset): def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task") + dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -291,7 +328,7 @@ def test_add_frame_image_h_w_c(image_dataset): def test_add_frame_image_uint8(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": image}, task="Dummy task") + dataset.add_frame({"image": image, "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -300,7 +337,7 @@ def test_add_frame_image_uint8(image_dataset): def test_add_frame_image_pil(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task") + dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -319,6 +356,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): # - [ ] test push_to_hub # - [ ] test smaller methods +# TODO(rcadene): +# - [ ] fix code so that old test_factory + backward pass +# - [ ] write new unit tests to test save_episode + getitem +# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos) +# - [ ] +# - [ ] remove old tests + @pytest.mark.parametrize( "env_name, repo_id, policy_name", @@ -338,9 +382,8 @@ def test_factory(env_name, repo_id, policy_name): # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=repo_id, episodes=[0]), env=make_env_config(env_name), - policy=make_policy_config(policy_name, push_to_hub=False), + policy=make_policy_config(policy_name), ) - cfg.validate() dataset = make_dataset(cfg) delta_timestamps = dataset.delta_timestamps @@ -427,30 +470,6 @@ def test_multidataset_frames(): assert torch.equal(sub_dataset_item[k], dataset_item[k]) -# TODO(aliberts): Move to more appropriate location -def test_flatten_unflatten_dict(): - d = { - "obs": { - "min": 0, - "max": 1, - "mean": 2, - "std": 3, - }, - "action": { - "min": 4, - "max": 5, - "mean": 6, - "std": 7, - }, - } - - original_d = deepcopy(d) - d = unflatten_dict(flatten_dict(d)) - - # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" - - @pytest.mark.parametrize( "repo_id", [ @@ -497,38 +516,22 @@ def test_backward_compatibility(repo_id): ) # test2 first frames of first episode - i = dataset.episode_data_index["from"][0].item() + i = dataset.meta.episodes[0]["dataset_from_index"] load_and_compare(i) load_and_compare(i + 1) # test 2 frames at the middle of first episode - i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2) + i = int( + (dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2 + ) load_and_compare(i) load_and_compare(i + 1) # test 2 last frames of first episode - i = dataset.episode_data_index["to"][0].item() + i = dataset.meta.episodes[0]["dataset_to_index"] load_and_compare(i - 2) load_and_compare(i - 1) - # TODO(rcadene): Enable testing on second and last episode - # We currently cant because our test dataset only contains the first episode - - # # test 2 first frames of second episode - # i = dataset.episode_data_index["from"][1].item() - # load_and_compare(i) - # load_and_compare(i + 1) - - # # test 2 last frames of second episode - # i = dataset.episode_data_index["to"][1].item() - # load_and_compare(i - 2) - # load_and_compare(i - 1) - - # # test 2 last frames of last episode - # i = dataset.episode_data_index["to"][-1].item() - # load_and_compare(i - 2) - # load_and_compare(i - 1) - @pytest.mark.skip("Requires internet access") def test_create_branch(): @@ -556,18 +559,499 @@ def test_create_branch(): api.delete_repo(repo_id, repo_type=repo_type) -def test_dataset_feature_with_forward_slash_raises_error(): - # make sure dir does not exist - from lerobot.constants import HF_LEROBOT_HOME +def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): + """Test the _check_cached_episodes_sufficient method of LeRobotDataset.""" + # Create a dataset with 5 episodes (0-4) + dataset = lerobot_dataset_factory( + root=tmp_path / "test", + total_episodes=5, + total_frames=200, + use_videos=False, + ) - dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" - # make sure does not exist - if dataset_dir.exists(): - dataset_dir.rmdir() + # Test hf_dataset is None + dataset.hf_dataset = None + assert dataset._check_cached_episodes_sufficient() is False - with pytest.raises(ValueError): - LeRobotDataset.create( - repo_id="lerobot/test/with/slash", - fps=30, - features={"a/b": {"dtype": "float32", "shape": 2, "names": None}}, + # Test hf_dataset is empty + import datasets + + empty_features = get_hf_features_from_features(dataset.features) + dataset.hf_dataset = datasets.Dataset.from_dict( + {key: [] for key in empty_features}, features=empty_features + ) + dataset.hf_dataset.set_transform(hf_transform_to_torch) + assert dataset._check_cached_episodes_sufficient() is False + + # Restore the original dataset for remaining tests + dataset.hf_dataset = dataset.load_hf_dataset() + + # Test all episodes requested (self.episodes = None) and all are available + dataset.episodes = None + assert dataset._check_cached_episodes_sufficient() is True + + # Test specific episodes requested that are all available + dataset.episodes = [0, 2, 4] + assert dataset._check_cached_episodes_sufficient() is True + + # Test request episodes that don't exist in the cached dataset + # Create a dataset with only episodes 0, 1, 2 + limited_dataset = lerobot_dataset_factory( + root=tmp_path / "limited", + total_episodes=3, + total_frames=120, + use_videos=False, + ) + + # Request episodes that include non-existent ones + limited_dataset.episodes = [0, 1, 2, 3, 4] + assert limited_dataset._check_cached_episodes_sufficient() is False + + # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4) + # First create the full dataset structure + sparse_dataset = lerobot_dataset_factory( + root=tmp_path / "sparse", + total_episodes=5, + total_frames=200, + use_videos=False, + ) + + # Manually filter hf_dataset to only include episodes 0, 2, 4 + episode_indices = sparse_dataset.hf_dataset["episode_index"] + mask = torch.zeros(len(episode_indices), dtype=torch.bool) + for ep in [0, 2, 4]: + mask |= torch.tensor(episode_indices) == ep + + # Create a filtered dataset + filtered_data = {} + # Find image keys by checking features + image_keys = [key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"] + + for key in sparse_dataset.hf_dataset.column_names: + values = sparse_dataset.hf_dataset[key] + # Filter values based on mask + filtered_values = [val for i, val in enumerate(values) if mask[i]] + + # Convert float32 image tensors back to uint8 numpy arrays for HuggingFace dataset + if key in image_keys and len(filtered_values) > 0: + # Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC) + filtered_values = [ + (val.permute(1, 2, 0).numpy() * 255).astype(np.uint8) for val in filtered_values + ] + + filtered_data[key] = filtered_values + + sparse_dataset.hf_dataset = datasets.Dataset.from_dict( + filtered_data, features=get_hf_features_from_features(sparse_dataset.features) + ) + sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch) + + # Test requesting all episodes when only some are cached + sparse_dataset.episodes = None + assert sparse_dataset._check_cached_episodes_sufficient() is False + + # Test requesting only the available episodes + sparse_dataset.episodes = [0, 2, 4] + assert sparse_dataset._check_cached_episodes_sufficient() is True + + # Test requesting a mix of available and unavailable episodes + sparse_dataset.episodes = [0, 1, 2] + assert sparse_dataset._check_cached_episodes_sufficient() is False + + +def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): + """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.""" + features = { + "observation.state": { + "dtype": "float32", + "shape": (6,), + "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], + }, + "action": { + "dtype": "float32", + "shape": (6,), + "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], + }, + } + + # Create dataset with default chunk settings + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + + # Test initial default values + initial_settings = dataset.meta.get_chunk_settings() + assert initial_settings["chunks_size"] == DEFAULT_CHUNK_SIZE + assert initial_settings["data_files_size_in_mb"] == DEFAULT_DATA_FILE_SIZE_IN_MB + assert initial_settings["video_files_size_in_mb"] == DEFAULT_VIDEO_FILE_SIZE_IN_MB + + # Test updating all settings at once + new_chunks_size = 2000 + new_data_size = 200 + new_video_size = 1000 + + dataset.meta.update_chunk_settings( + chunks_size=new_chunks_size, + data_files_size_in_mb=new_data_size, + video_files_size_in_mb=new_video_size, + ) + + # Verify settings were updated + updated_settings = dataset.meta.get_chunk_settings() + assert updated_settings["chunks_size"] == new_chunks_size + assert updated_settings["data_files_size_in_mb"] == new_data_size + assert updated_settings["video_files_size_in_mb"] == new_video_size + + # Test updating individual settings + dataset.meta.update_chunk_settings(chunks_size=1500) + settings_after_partial = dataset.meta.get_chunk_settings() + assert settings_after_partial["chunks_size"] == 1500 + assert settings_after_partial["data_files_size_in_mb"] == new_data_size + assert settings_after_partial["video_files_size_in_mb"] == new_video_size + + # Test updating only data file size + dataset.meta.update_chunk_settings(data_files_size_in_mb=150) + settings_after_data = dataset.meta.get_chunk_settings() + assert settings_after_data["chunks_size"] == 1500 + assert settings_after_data["data_files_size_in_mb"] == 150 + assert settings_after_data["video_files_size_in_mb"] == new_video_size + + # Test updating only video file size + dataset.meta.update_chunk_settings(video_files_size_in_mb=800) + settings_after_video = dataset.meta.get_chunk_settings() + assert settings_after_video["chunks_size"] == 1500 + assert settings_after_video["data_files_size_in_mb"] == 150 + assert settings_after_video["video_files_size_in_mb"] == 800 + + # Test that settings persist in the info file + info_path = dataset.root / "meta" / "info.json" + assert info_path.exists() + + # Verify the underlying metadata properties + assert dataset.meta.chunks_size == 1500 + assert dataset.meta.data_files_size_in_mb == 150 + assert dataset.meta.video_files_size_in_mb == 800 + + # Test error handling for invalid values + with pytest.raises(ValueError, match="chunks_size must be positive"): + dataset.meta.update_chunk_settings(chunks_size=0) + + with pytest.raises(ValueError, match="chunks_size must be positive"): + dataset.meta.update_chunk_settings(chunks_size=-100) + + with pytest.raises(ValueError, match="data_files_size_in_mb must be positive"): + dataset.meta.update_chunk_settings(data_files_size_in_mb=0) + + with pytest.raises(ValueError, match="data_files_size_in_mb must be positive"): + dataset.meta.update_chunk_settings(data_files_size_in_mb=-50) + + with pytest.raises(ValueError, match="video_files_size_in_mb must be positive"): + dataset.meta.update_chunk_settings(video_files_size_in_mb=0) + + with pytest.raises(ValueError, match="video_files_size_in_mb must be positive"): + dataset.meta.update_chunk_settings(video_files_size_in_mb=-200) + + # Test calling with None values (should not change anything) + settings_before_none = dataset.meta.get_chunk_settings() + dataset.meta.update_chunk_settings( + chunks_size=None, data_files_size_in_mb=None, video_files_size_in_mb=None + ) + settings_after_none = dataset.meta.get_chunk_settings() + assert settings_before_none == settings_after_none + + # Test metadata direct access + meta_settings = dataset.meta.get_chunk_settings() + assert meta_settings == dataset.meta.get_chunk_settings() + + # Test updating via metadata directly + dataset.meta.update_chunk_settings(chunks_size=3000) + assert dataset.meta.get_chunk_settings()["chunks_size"] == 3000 + + +def test_update_chunk_settings_video_dataset(tmp_path): + """Test update_chunk_settings with a video dataset to ensure video-specific logic works.""" + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (480, 640, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + } + + # Create video dataset + dataset = LeRobotDataset.create( + repo_id=DUMMY_REPO_ID, fps=30, features=features, root=tmp_path / "video_test", use_videos=True + ) + + # Test that video-specific settings work + original_video_size = dataset.meta.get_chunk_settings()["video_files_size_in_mb"] + new_video_size = original_video_size * 2 + + dataset.meta.update_chunk_settings(video_files_size_in_mb=new_video_size) + assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == new_video_size + assert dataset.meta.video_files_size_in_mb == new_video_size + + +def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory): + """Test that all frames have correct episode indices across multiple episodes.""" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Create 3 episodes with different lengths + num_episodes = 3 + frames_per_episode = [10, 15, 8] + + for episode_idx in range(num_episodes): + for _ in range(frames_per_episode[episode_idx]): + dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"}) + dataset.save_episode() + + # Load the dataset and check episode indices + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + # Check specific frames across episode boundaries + cumulative = 0 + for ep_idx, ep_length in enumerate(frames_per_episode): + # Check start, middle, and end of each episode + start_frame = cumulative + middle_frame = cumulative + ep_length // 2 + end_frame = cumulative + ep_length - 1 + + for frame_idx in [start_frame, middle_frame, end_frame]: + frame_data = loaded_dataset[frame_idx] + actual_ep_idx = frame_data["episode_index"].item() + assert actual_ep_idx == ep_idx, ( + f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}" + ) + + cumulative += ep_length + + # Check episode index distribution + all_episode_indices = [loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset))] + from collections import Counter + + distribution = Counter(all_episode_indices) + expected_dist = {i: frames_per_episode[i] for i in range(num_episodes)} + + assert dict(distribution) == expected_dist, ( + f"Episode distribution {dict(distribution)} != expected {expected_dist}" + ) + + +def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory): + """Test episode metadata consistency across multiple episodes.""" + features = { + "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}, + } + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + num_episodes = 4 + frames_per_episode = [20, 35, 10, 25] + tasks = ["pick", "place", "pick", "place"] + + for episode_idx in range(num_episodes): + for _ in range(frames_per_episode[episode_idx]): + dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]}) + dataset.save_episode() + + # Load and validate episode metadata + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + assert loaded_dataset.meta.total_episodes == num_episodes + assert loaded_dataset.meta.total_frames == sum(frames_per_episode) + + cumulative_frames = 0 + for episode_idx in range(num_episodes): + episode_metadata = loaded_dataset.meta.episodes[episode_idx] + + # Check basic episode properties + assert episode_metadata["episode_index"] == episode_idx + assert episode_metadata["length"] == frames_per_episode[episode_idx] + assert episode_metadata["tasks"] == [tasks[episode_idx]] + + # Check dataset indices + expected_from = cumulative_frames + expected_to = cumulative_frames + frames_per_episode[episode_idx] + + assert episode_metadata["dataset_from_index"] == expected_from + assert episode_metadata["dataset_to_index"] == expected_to + + cumulative_frames += frames_per_episode[episode_idx] + + +def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory): + """Test that episodes have no gaps or overlaps in their data indices.""" + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + num_episodes = 5 + frames_per_episode = [12, 8, 20, 15, 5] + + for episode_idx in range(num_episodes): + for _ in range(frames_per_episode[episode_idx]): + dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"}) + dataset.save_episode() + + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + # Check data consistency - no gaps or overlaps + cumulative_check = 0 + for episode_idx in range(num_episodes): + episode_metadata = loaded_dataset.meta.episodes[episode_idx] + from_idx = episode_metadata["dataset_from_index"] + to_idx = episode_metadata["dataset_to_index"] + + # Check that episode starts exactly where previous ended + assert from_idx == cumulative_check, ( + f"Episode {episode_idx} starts at {from_idx}, expected {cumulative_check}" ) + + # Check that episode length matches expected + actual_length = to_idx - from_idx + expected_length = frames_per_episode[episode_idx] + assert actual_length == expected_length, ( + f"Episode {episode_idx} length {actual_length} != expected {expected_length}" + ) + + cumulative_check = to_idx + + # Final check: last episode should end at total frames + expected_total_frames = sum(frames_per_episode) + assert cumulative_check == expected_total_frames, ( + f"Final frame count {cumulative_check} != expected {expected_total_frames}" + ) + + +def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory): + """Test that statistics are properly computed and stored for all features.""" + features = { + "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]}, + "action": {"dtype": "float32", "shape": (1,), "names": ["force"]}, + } + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Create controlled data to verify statistics + num_episodes = 2 + frames_per_episode = [10, 10] + + # Use deterministic data for predictable statistics + torch.manual_seed(42) + for episode_idx in range(num_episodes): + for frame_idx in range(frames_per_episode[episode_idx]): + state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32) + action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32) + dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"}) + dataset.save_episode() + + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + # Check that statistics exist for all features + assert loaded_dataset.meta.stats is not None, "No statistics found" + + for feature_name in features.keys(): + assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'" + + feature_stats = loaded_dataset.meta.stats[feature_name] + expected_stats = ["min", "max", "mean", "std", "count"] + + for stat_key in expected_stats: + assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'" + + stat_value = feature_stats[stat_key] + # Basic sanity checks + if stat_key == "count": + assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'" + elif stat_key in ["min", "max", "mean", "std"]: + # Check that statistics are reasonable (not NaN, proper shapes) + if hasattr(stat_value, "shape"): + expected_shape = features[feature_name]["shape"] + assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], ( + f"Wrong shape for {stat_key} of '{feature_name}'" + ) + # Check no NaN values + if hasattr(stat_value, "__iter__"): + assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'" + else: + assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'" + + +def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory): + """Test frame indices and episode transitions at episode boundaries.""" + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + num_episodes = 3 + frames_per_episode = [7, 12, 5] + + for episode_idx in range(num_episodes): + for frame_idx in range(frames_per_episode[episode_idx]): + dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"}) + dataset.save_episode() + + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + # Test episode boundaries + cumulative = 0 + for ep_idx, ep_length in enumerate(frames_per_episode): + if ep_idx > 0: + # Check last frame of previous episode + prev_frame = loaded_dataset[cumulative - 1] + assert prev_frame["episode_index"].item() == ep_idx - 1 + + # Check first frame of current episode + if cumulative < len(loaded_dataset): + curr_frame = loaded_dataset[cumulative] + assert curr_frame["episode_index"].item() == ep_idx + + # Check frame_index within episode + for i in range(ep_length): + if cumulative + i < len(loaded_dataset): + frame = loaded_dataset[cumulative + i] + assert frame["frame_index"].item() == i, f"Frame {cumulative + i} has wrong frame_index" + assert frame["episode_index"].item() == ep_idx, ( + f"Frame {cumulative + i} has wrong episode_index" + ) + + cumulative += ep_length + + +def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory): + """Test that tasks are properly indexed and retrievable.""" + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Use multiple tasks, including repeated ones + tasks = ["pick", "place", "pick", "navigate", "place"] + unique_tasks = list(set(tasks)) # ["pick", "place", "navigate"] + frames_per_episode = [5, 8, 3, 10, 6] + + for episode_idx, task in enumerate(tasks): + for _ in range(frames_per_episode[episode_idx]): + dataset.add_frame({"state": torch.randn(1), "task": task}) + dataset.save_episode() + + loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) + + # Check that all unique tasks are in the tasks metadata + stored_tasks = set(loaded_dataset.meta.tasks.index) + assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}" + + # Check that task indices are consistent + cumulative = 0 + for episode_idx, expected_task in enumerate(tasks): + episode_metadata = loaded_dataset.meta.episodes[episode_idx] + assert episode_metadata["tasks"] == [expected_task] + + # Check frames in this episode have correct task + for i in range(frames_per_episode[episode_idx]): + frame = loaded_dataset[cumulative + i] + assert frame["task"] == expected_task, f"Frame {cumulative + i} has wrong task" + + # Check task_index consistency + expected_task_index = loaded_dataset.meta.get_task_index(expected_task) + assert frame["task_index"].item() == expected_task_index + + cumulative += frames_per_episode[episode_idx] + + # Check total number of tasks + assert loaded_dataset.meta.total_tasks == len(unique_tasks) diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 786b90ce2..72f69bc72 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -11,83 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from itertools import accumulate - -import datasets -import numpy as np -import pyarrow.compute as pc import pytest -import torch from lerobot.datasets.utils import ( check_delta_timestamps, - check_timestamps_sync, get_delta_indices, ) from tests.fixtures.constants import DUMMY_MOTOR_FEATURES -def calculate_total_episode( - hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True -) -> dict[str, torch.Tensor]: - episode_indices = sorted(hf_dataset.unique("episode_index")) - total_episodes = len(episode_indices) - if raise_if_not_contiguous and episode_indices != list(range(total_episodes)): - raise ValueError("episode_index values are not sorted and contiguous.") - return total_episodes - - -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]: - episode_lengths = [] - table = hf_dataset.data.table - total_episodes = calculate_total_episode(hf_dataset) - for ep_idx in range(total_episodes): - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - episode_lengths.insert(ep_idx, len(ep_table)) - - cumulative_lengths = list(accumulate(episode_lengths)) - return { - "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64), - "to": np.array(cumulative_lengths, dtype=np.int64), - } - - -@pytest.fixture(scope="module") -def synced_timestamps_factory(hf_dataset_factory): - def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - hf_dataset = hf_dataset_factory(fps=fps) - timestamps = torch.stack(hf_dataset["timestamp"]).numpy() - episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() - episode_data_index = calculate_episode_data_index(hf_dataset) - return timestamps, episode_indices, episode_data_index - - return _create_synced_timestamps - - -@pytest.fixture(scope="module") -def unsynced_timestamps_factory(synced_timestamps_factory): - def _create_unsynced_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_unsynced_timestamps - - -@pytest.fixture(scope="module") -def slightly_off_timestamps_factory(synced_timestamps_factory): - def _create_slightly_off_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_slightly_off_timestamps - - @pytest.fixture(scope="module") def valid_delta_timestamps_factory(): def _create_valid_delta_timestamps( @@ -136,78 +68,6 @@ def delta_indices_factory(): return _delta_indices -def test_check_timestamps_sync_synced(synced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - with pytest.raises(ValueError): - check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - - -def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - raise_value_error=False, - ) - assert result is False - - -def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_timestamps_sync_single_timestamp(): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx = np.array([0.0]), np.array([0]) - episode_data_index = {"to": np.array([1]), "from": np.array([0])} - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=episode_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): fps = 30 tolerance_s = 1e-4 diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 94576a3e2..fd7a6e380 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -32,7 +32,7 @@ def test_drop_n_first_frames(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1) assert sampler.indices == [1, 4, 5] assert len(sampler) == 3 assert list(sampler) == [1, 4, 5] @@ -48,7 +48,7 @@ def test_drop_n_last_frames(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1) assert sampler.indices == [0, 3, 4] assert len(sampler) == 3 assert list(sampler) == [0, 3, 4] @@ -64,7 +64,9 @@ def test_episode_indices_to_use(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2]) + sampler = EpisodeAwareSampler( + episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2] + ) assert sampler.indices == [0, 1, 3, 4, 5] assert len(sampler) == 5 assert list(sampler) == [0, 1, 3, 4, 5] @@ -80,11 +82,11 @@ def test_shuffle(): ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) - sampler = EpisodeAwareSampler(episode_data_index, shuffle=False) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False) assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert list(sampler) == [0, 1, 2, 3, 4, 5] - sampler = EpisodeAwareSampler(episode_data_index, shuffle=True) + sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True) assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert set(sampler) == {0, 1, 2, 3, 4, 5} diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index ba16874d0..91d661b3c 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -14,12 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +from copy import deepcopy + import torch from datasets import Dataset from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.datasets.utils import ( + create_lerobot_dataset_card, + flatten_dict, + hf_transform_to_torch, + unflatten_dict, +) def test_default_parameters(): @@ -53,3 +61,26 @@ def test_calculate_episode_data_index(): episode_data_index = calculate_episode_data_index(dataset) assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_flatten_unflatten_dict(): + d = { + "obs": { + "min": 0, + "max": 1, + "mean": 2, + "std": 3, + }, + "action": { + "min": 4, + "max": 5, + "mean": 6, + "std": 7, + }, + } + + original_d = deepcopy(d) + d = unflatten_dict(flatten_dict(d)) + + # test equality between nested dicts + assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index d69a4634f..0af499364 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = { }, } DUMMY_CAMERA_FEATURES = { - "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, - "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, + "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None}, + "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None}, } DEFAULT_FPS = 30 DUMMY_VIDEO_INFO = { diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 047db3393..c33fdcb72 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import random +import shutil from functools import partial from pathlib import Path from typing import Protocol @@ -19,19 +20,25 @@ from unittest.mock import patch import datasets import numpy as np +import pandas as pd import PIL.Image import pytest import torch +from datasets import Dataset from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, DEFAULT_FEATURES, - DEFAULT_PARQUET_PATH, + DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, + flatten_dict, get_hf_features_from_features, hf_transform_to_torch, ) +from lerobot.datasets.video_utils import encode_video_frames from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, @@ -46,10 +53,9 @@ class LeRobotDatasetFactory(Protocol): def __call__(self, *args, **kwargs) -> LeRobotDataset: ... -def get_task_index(task_dicts: dict, task: str) -> int: - tasks = {d["task_index"]: d["task"] for d in task_dicts.values()} - task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} - return task_to_task_index[task] +def get_task_index(tasks: datasets.Dataset, task: str) -> int: + task_idx = tasks.loc[task].task_index.item() + return task_idx @pytest.fixture(scope="session") @@ -62,15 +68,49 @@ def img_tensor_factory(): @pytest.fixture(scope="session") def img_array_factory(): - def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray: - if np.issubdtype(dtype, np.unsignedinteger): - # Int array in [0, 255] range - img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) - elif np.issubdtype(dtype, np.floating): - # Float array in [0, 1] range - img_array = np.random.rand(height, width, channels).astype(dtype) + def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray: + if content is None: + # Original random noise behavior + if np.issubdtype(dtype, np.unsignedinteger): + # Int array in [0, 255] range + img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype) + elif np.issubdtype(dtype, np.floating): + # Float array in [0, 1] range + img_array = np.random.rand(height, width, channels).astype(dtype) + else: + raise ValueError(dtype) else: - raise ValueError(dtype) + # Create image with text content using OpenCV + import cv2 + + # Create white background + img_array = np.ones((height, width, channels), dtype=np.uint8) * 255 + + # Font settings + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = max(0.5, height / 200) # Scale font with image size + font_color = (0, 0, 0) # Black text + thickness = max(1, int(height / 100)) + + # Get text size to center it + text_size = cv2.getTextSize(content, font, font_scale, thickness)[0] + text_x = (width - text_size[0]) // 2 + text_y = (height + text_size[1]) // 2 + + # Put text on image + cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness) + + # Handle single channel case + if channels == 1: + img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY) + img_array = img_array[:, :, np.newaxis] + + # Convert to target dtype + if np.issubdtype(dtype, np.floating): + img_array = img_array.astype(dtype) / 255.0 + else: + img_array = img_array.astype(dtype) + return img_array return _create_img_array @@ -117,9 +157,10 @@ def info_factory(features_factory): total_frames: int = 0, total_tasks: int = 0, total_videos: int = 0, - total_chunks: int = 0, chunks_size: int = DEFAULT_CHUNK_SIZE, - data_path: str = DEFAULT_PARQUET_PATH, + data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, + video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB, + data_path: str = DEFAULT_DATA_PATH, video_path: str = DEFAULT_VIDEO_PATH, motor_features: dict = DUMMY_MOTOR_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES, @@ -133,8 +174,9 @@ def info_factory(features_factory): "total_frames": total_frames, "total_tasks": total_tasks, "total_videos": total_videos, - "total_chunks": total_chunks, "chunks_size": chunks_size, + "data_files_size_in_mb": data_files_size_in_mb, + "video_files_size_in_mb": video_files_size_in_mb, "fps": fps, "splits": {}, "data_path": data_path, @@ -175,41 +217,26 @@ def stats_factory(): return _create_stats -@pytest.fixture(scope="session") -def episodes_stats_factory(stats_factory): - def _create_episodes_stats( - features: dict[str], - total_episodes: int = 3, - ) -> dict: - episodes_stats = {} - for episode_index in range(total_episodes): - episodes_stats[episode_index] = { - "episode_index": episode_index, - "stats": stats_factory(features), - } - return episodes_stats - - return _create_episodes_stats - - @pytest.fixture(scope="session") def tasks_factory(): - def _create_tasks(total_tasks: int = 3) -> int: - tasks = {} - for task_index in range(total_tasks): - task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."} - tasks[task_index] = task_dict - return tasks + def _create_tasks(total_tasks: int = 3) -> pd.DataFrame: + ids = list(range(total_tasks)) + tasks = [f"Perform action {i}." for i in ids] + df = pd.DataFrame({"task_index": ids}, index=tasks) + return df return _create_tasks @pytest.fixture(scope="session") -def episodes_factory(tasks_factory): +def episodes_factory(tasks_factory, stats_factory): def _create_episodes( + features: dict[str], + fps: int = DEFAULT_FPS, total_episodes: int = 3, total_frames: int = 400, - tasks: dict | None = None, + video_keys: list[str] | None = None, + tasks: pd.DataFrame | None = None, multi_task: bool = False, ): if total_episodes <= 0 or total_frames <= 0: @@ -217,66 +244,142 @@ def episodes_factory(tasks_factory): if total_frames < total_episodes: raise ValueError("total_length must be greater than or equal to num_episodes.") - if not tasks: + if tasks is None: min_tasks = 2 if multi_task else 1 total_tasks = random.randint(min_tasks, total_episodes) tasks = tasks_factory(total_tasks) - if total_episodes < len(tasks) and not multi_task: + num_tasks_available = len(tasks) + + if total_episodes < num_tasks_available and not multi_task: raise ValueError("The number of tasks should be less than the number of episodes.") # Generate random lengths that sum up to total_length lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() - tasks_list = [task_dict["task"] for task_dict in tasks.values()] - num_tasks_available = len(tasks_list) + # Create empty dictionaries with all keys + d = { + "episode_index": [], + "meta/episodes/chunk_index": [], + "meta/episodes/file_index": [], + "data/chunk_index": [], + "data/file_index": [], + "dataset_from_index": [], + "dataset_to_index": [], + "tasks": [], + "length": [], + } + if video_keys is not None: + for video_key in video_keys: + d[f"videos/{video_key}/chunk_index"] = [] + d[f"videos/{video_key}/file_index"] = [] + d[f"videos/{video_key}/from_timestamp"] = [] + d[f"videos/{video_key}/to_timestamp"] = [] - episodes = {} - remaining_tasks = tasks_list.copy() + for stats_key in flatten_dict({"stats": stats_factory(features)}): + d[stats_key] = [] + + num_frames = 0 + remaining_tasks = list(tasks.index) for ep_idx in range(total_episodes): num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 - tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list + tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index) episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))) if remaining_tasks: for task in episode_tasks: remaining_tasks.remove(task) - episodes[ep_idx] = { - "episode_index": ep_idx, - "tasks": episode_tasks, - "length": lengths[ep_idx], - } + d["episode_index"].append(ep_idx) + # TODO(rcadene): remove heuristic of only one file + d["meta/episodes/chunk_index"].append(0) + d["meta/episodes/file_index"].append(0) + d["data/chunk_index"].append(0) + d["data/file_index"].append(0) + d["dataset_from_index"].append(num_frames) + d["dataset_to_index"].append(num_frames + lengths[ep_idx]) + d["tasks"].append(episode_tasks) + d["length"].append(lengths[ep_idx]) - return episodes + if video_keys is not None: + for video_key in video_keys: + d[f"videos/{video_key}/chunk_index"].append(0) + d[f"videos/{video_key}/file_index"].append(0) + d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps) + d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps) + + # Add stats columns like "stats/action/max" + for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items(): + d[stats_key].append(stats) + + num_frames += lengths[ep_idx] + + return Dataset.from_dict(d) return _create_episodes +@pytest.fixture(scope="session") +def create_videos(info_factory, img_array_factory): + def _create_video_directory( + root: Path, + info: dict | None = None, + total_episodes: int = 3, + total_frames: int = 150, + total_tasks: int = 1, + ): + if info is None: + info = info_factory( + total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks + ) + + video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"} + for key, ft in video_feats.items(): + # create and save images with identifiable content + tmp_dir = root / "tmp_images" + tmp_dir.mkdir(parents=True, exist_ok=True) + for frame_index in range(info["total_frames"]): + content = f"{key}-{frame_index}" + img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content) + pil_img = PIL.Image.fromarray(img) + path = tmp_dir / f"frame-{frame_index:06d}.png" + pil_img.save(path) + + video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0) + video_path.parent.mkdir(parents=True, exist_ok=True) + # Use the global fps from info, not video-specific fps which might not exist + encode_video_frames(tmp_dir, video_path, fps=info["fps"]) + shutil.rmtree(tmp_dir) + + return _create_video_directory + + @pytest.fixture(scope="session") def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def _create_hf_dataset( features: dict | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, + tasks: pd.DataFrame | None = None, + episodes: datasets.Dataset | None = None, fps: int = DEFAULT_FPS, ) -> datasets.Dataset: - if not tasks: + if tasks is None: tasks = tasks_factory() - if not episodes: - episodes = episodes_factory() - if not features: + if features is None: features = features_factory() + if episodes is None: + episodes = episodes_factory(features, fps) timestamp_col = np.array([], dtype=np.float32) frame_index_col = np.array([], dtype=np.int64) episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) - for ep_dict in episodes.values(): + for ep_dict in episodes: timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( (episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int)) ) + # Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata. + # TODO(rcadene): assign the tasks of the episode per chunks of frames ep_task_index = get_task_index(tasks, ep_dict["tasks"][0]) task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))) @@ -286,8 +389,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar for key, ft in features.items(): if ft["dtype"] == "image": robot_cols[key] = [ - img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0]) - for _ in range(len(index_col)) + img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}") + for i in range(len(index_col)) ] elif ft["shape"][0] > 1 and ft["dtype"] != "video": robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"]) @@ -314,7 +417,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar def lerobot_dataset_metadata_factory( info_factory, stats_factory, - episodes_stats_factory, tasks_factory, episodes_factory, mock_snapshot_download_factory, @@ -324,29 +426,29 @@ def lerobot_dataset_metadata_factory( repo_id: str = DUMMY_REPO_ID, info: dict | None = None, stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, + tasks: pd.DataFrame | None = None, + episodes: datasets.Dataset | None = None, ) -> LeRobotDatasetMetadata: - if not info: + if info is None: info = info_factory() - if not stats: + if stats is None: stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory( - features=info["features"], total_episodes=info["total_episodes"] - ) - if not tasks: + if tasks is None: tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episodes: + if episodes is None: + video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + features=info["features"], + fps=info["fps"], + total_episodes=info["total_episodes"], + total_frames=info["total_frames"], + video_keys=video_keys, + tasks=tasks, ) mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, - episodes_stats=episodes_stats, tasks=tasks, episodes=episodes, ) @@ -366,7 +468,6 @@ def lerobot_dataset_metadata_factory( def lerobot_dataset_factory( info_factory, stats_factory, - episodes_stats_factory, tasks_factory, episodes_factory, hf_dataset_factory, @@ -380,50 +481,63 @@ def lerobot_dataset_factory( total_frames: int = 150, total_tasks: int = 1, multi_task: bool = False, + use_videos: bool = True, info: dict | None = None, stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episode_dicts: list[dict] | None = None, + tasks: pd.DataFrame | None = None, + episodes_metadata: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None, + data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, + chunks_size: int = DEFAULT_CHUNK_SIZE, **kwargs, ) -> LeRobotDataset: - if not info: + # Instantiate objects + if info is None: info = info_factory( - total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks + total_episodes=total_episodes, + total_frames=total_frames, + total_tasks=total_tasks, + use_videos=use_videos, + data_files_size_in_mb=data_files_size_in_mb, + chunks_size=chunks_size, ) - if not stats: + if stats is None: stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) - if not tasks: + if tasks is None: tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episode_dicts: - episode_dicts = episodes_factory( + if episodes_metadata is None: + video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"] + episodes_metadata = episodes_factory( + features=info["features"], + fps=info["fps"], total_episodes=info["total_episodes"], total_frames=info["total_frames"], + video_keys=video_keys, tasks=tasks, multi_task=multi_task, ) - if not hf_dataset: - hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"]) + if hf_dataset is None: + hf_dataset = hf_dataset_factory( + features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"] + ) + # Write data on disk mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, - episodes_stats=episodes_stats, tasks=tasks, - episodes=episode_dicts, + episodes=episodes_metadata, hf_dataset=hf_dataset, + data_files_size_in_mb=data_files_size_in_mb, + chunks_size=chunks_size, ) mock_metadata = lerobot_dataset_metadata_factory( root=root, repo_id=repo_id, info=info, stats=stats, - episodes_stats=episodes_stats, tasks=tasks, - episodes=episode_dicts, + episodes=episodes_metadata, ) with ( patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index e0553f77e..11f3fa94a 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -11,137 +11,166 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json +import logging from pathlib import Path import datasets -import jsonlines -import pyarrow.compute as pc -import pyarrow.parquet as pq +import numpy as np +import pandas as pd import pytest +from datasets import Dataset from lerobot.datasets.utils import ( - EPISODES_PATH, - EPISODES_STATS_PATH, - INFO_PATH, - STATS_PATH, - TASKS_PATH, + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + get_hf_dataset_size_in_mb, + update_chunk_file_indices, + write_episodes, + write_info, + write_stats, + write_tasks, ) +def write_hf_dataset( + hf_dataset: Dataset, + local_dir: Path, + data_file_size_mb: float | None = None, + chunk_size: int | None = None, +): + """ + Writes a Hugging Face Dataset to one or more Parquet files in a structured directory format. + + If the dataset size is within `DEFAULT_DATA_FILE_SIZE_IN_MB`, it's saved as a single file. + Otherwise, the dataset is split into multiple smaller Parquet files, each not exceeding the size limit. + The file and chunk indices are managed to organize the output files in a hierarchical structure, + e.g., `data/chunk-000/file-000.parquet`, `data/chunk-000/file-001.parquet`, etc. + This function ensures that episodes are not split across multiple files. + + Args: + hf_dataset (Dataset): The Hugging Face Dataset to be written to disk. + local_dir (Path): The root directory where the dataset files will be stored. + data_file_size_mb (float, optional): Maximal size for the parquet data file, in MB. Defaults to DEFAULT_DATA_FILE_SIZE_IN_MB. + chunk_size (int, optional): Maximal number of files within a chunk folder before creating another one. Defaults to DEFAULT_CHUNK_SIZE. + """ + if data_file_size_mb is None: + data_file_size_mb = DEFAULT_DATA_FILE_SIZE_IN_MB + if chunk_size is None: + chunk_size = DEFAULT_CHUNK_SIZE + + dataset_size_in_mb = get_hf_dataset_size_in_mb(hf_dataset) + + if dataset_size_in_mb <= data_file_size_mb: + # If the dataset is small enough, write it to a single file. + path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0) + path.parent.mkdir(parents=True, exist_ok=True) + hf_dataset.to_parquet(path) + return + + # If the dataset is too large, split it into smaller chunks, keeping episodes whole. + episode_indices = np.array(hf_dataset["episode_index"]) + episode_boundaries = np.where(np.diff(episode_indices) != 0)[0] + 1 + episode_starts = np.concatenate(([0], episode_boundaries)) + episode_ends = np.concatenate((episode_boundaries, [len(hf_dataset)])) + + num_episodes = len(episode_starts) + current_episode_idx = 0 + chunk_idx, file_idx = 0, 0 + + while current_episode_idx < num_episodes: + shard_start_row = episode_starts[current_episode_idx] + shard_end_row = episode_ends[current_episode_idx] + next_episode_to_try_idx = current_episode_idx + 1 + + while next_episode_to_try_idx < num_episodes: + potential_shard_end_row = episode_ends[next_episode_to_try_idx] + dataset_shard_candidate = hf_dataset.select(range(shard_start_row, potential_shard_end_row)) + shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard_candidate) + + if shard_size_mb > data_file_size_mb: + break + else: + shard_end_row = potential_shard_end_row + next_episode_to_try_idx += 1 + + dataset_shard = hf_dataset.select(range(shard_start_row, shard_end_row)) + + if ( + shard_start_row == episode_starts[current_episode_idx] + and shard_end_row == episode_ends[current_episode_idx] + ): + shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard) + if shard_size_mb > data_file_size_mb: + logging.warning( + f"Episode with index {hf_dataset[shard_start_row.item()]['episode_index']} has size {shard_size_mb:.2f}MB, " + f"which is larger than data_file_size_mb ({data_file_size_mb}MB). " + "Writing it to a separate shard anyway to preserve episode integrity." + ) + + # Define the path for the current shard and ensure the directory exists. + path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + # Write the shard to a Parquet file. + dataset_shard.to_parquet(path) + + # Update chunk and file indices for the next iteration. + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) + current_episode_idx = next_episode_to_try_idx + + @pytest.fixture(scope="session") -def info_path(info_factory): - def _create_info_json_file(dir: Path, info: dict | None = None) -> Path: - if not info: +def create_info(info_factory): + def _create_info(dir: Path, info: dict | None = None): + if info is None: info = info_factory() - fpath = dir / INFO_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with open(fpath, "w") as f: - json.dump(info, f, indent=4, ensure_ascii=False) - return fpath + write_info(info, dir) - return _create_info_json_file + return _create_info @pytest.fixture(scope="session") -def stats_path(stats_factory): - def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path: - if not stats: +def create_stats(stats_factory): + def _create_stats(dir: Path, stats: dict | None = None): + if stats is None: stats = stats_factory() - fpath = dir / STATS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with open(fpath, "w") as f: - json.dump(stats, f, indent=4, ensure_ascii=False) - return fpath + write_stats(stats, dir) - return _create_stats_json_file + return _create_stats @pytest.fixture(scope="session") -def episodes_stats_path(episodes_stats_factory): - def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: - if not episodes_stats: - episodes_stats = episodes_stats_factory() - fpath = dir / EPISODES_STATS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes_stats.values()) - return fpath - - return _create_episodes_stats_jsonl_file - - -@pytest.fixture(scope="session") -def tasks_path(tasks_factory): - def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: - if not tasks: +def create_tasks(tasks_factory): + def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None): + if tasks is None: tasks = tasks_factory() - fpath = dir / TASKS_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(tasks.values()) - return fpath + write_tasks(tasks, dir) - return _create_tasks_jsonl_file + return _create_tasks @pytest.fixture(scope="session") -def episode_path(episodes_factory): - def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path: - if not episodes: +def create_episodes(episodes_factory): + def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None): + if episodes is None: + # TODO(rcadene): add features, fps as arguments episodes = episodes_factory() - fpath = dir / EPISODES_PATH - fpath.parent.mkdir(parents=True, exist_ok=True) - with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes.values()) - return fpath + write_episodes(episodes, dir) - return _create_episodes_jsonl_file + return _create_episodes @pytest.fixture(scope="session") -def single_episode_parquet_path(hf_dataset_factory, info_factory): - def _create_single_episode_parquet( - dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None - ) -> Path: - if not info: - info = info_factory() +def create_hf_dataset(hf_dataset_factory): + def _create_hf_dataset( + dir: Path, + hf_dataset: datasets.Dataset | None = None, + data_file_size_in_mb: float | None = None, + chunk_size: int | None = None, + ): if hf_dataset is None: hf_dataset = hf_dataset_factory() + write_hf_dataset(hf_dataset, dir, data_file_size_in_mb, chunk_size) - data_path = info["data_path"] - chunks_size = info["chunks_size"] - ep_chunk = ep_idx // chunks_size - fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) - fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_dataset.data.table - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - pq.write_table(ep_table, fpath) - return fpath - - return _create_single_episode_parquet - - -@pytest.fixture(scope="session") -def multi_episode_parquet_path(hf_dataset_factory, info_factory): - def _create_multi_episode_parquet( - dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None - ) -> Path: - if not info: - info = info_factory() - if hf_dataset is None: - hf_dataset = hf_dataset_factory() - - data_path = info["data_path"] - chunks_size = info["chunks_size"] - total_episodes = info["total_episodes"] - for ep_idx in range(total_episodes): - ep_chunk = ep_idx // chunks_size - fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx) - fpath.parent.mkdir(parents=True, exist_ok=True) - table = hf_dataset.data.table - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - pq.write_table(ep_table, fpath) - return dir / "data" - - return _create_multi_episode_parquet + return _create_hf_dataset diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index f7c5f5b04..4333b91a3 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -14,15 +14,19 @@ from pathlib import Path import datasets +import pandas as pd import pytest from huggingface_hub.utils import filter_repo_objects from lerobot.datasets.utils import ( - EPISODES_PATH, - EPISODES_STATS_PATH, + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + DEFAULT_TASKS_PATH, + DEFAULT_VIDEO_PATH, INFO_PATH, STATS_PATH, - TASKS_PATH, ) from tests.fixtures.constants import LEROBOT_TEST_DIR @@ -30,17 +34,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR @pytest.fixture(scope="session") def mock_snapshot_download_factory( info_factory, - info_path, + create_info, stats_factory, - stats_path, - episodes_stats_factory, - episodes_stats_path, + create_stats, tasks_factory, - tasks_path, + create_tasks, episodes_factory, - episode_path, - single_episode_parquet_path, + create_episodes, hf_dataset_factory, + create_hf_dataset, + create_videos, ): """ This factory allows to patch snapshot_download such that when called, it will create expected files rather @@ -50,82 +53,93 @@ def mock_snapshot_download_factory( def _mock_snapshot_download_func( info: dict | None = None, stats: dict | None = None, - episodes_stats: list[dict] | None = None, - tasks: list[dict] | None = None, - episodes: list[dict] | None = None, + tasks: pd.DataFrame | None = None, + episodes: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None, + data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB, + chunks_size: int = DEFAULT_CHUNK_SIZE, ): - if not info: - info = info_factory() - if not stats: + if info is None: + info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size) + if stats is None: stats = stats_factory(features=info["features"]) - if not episodes_stats: - episodes_stats = episodes_stats_factory( - features=info["features"], total_episodes=info["total_episodes"] - ) - if not tasks: + if tasks is None: tasks = tasks_factory(total_tasks=info["total_tasks"]) - if not episodes: + if episodes is None: episodes = episodes_factory( - total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks + features=info["features"], + fps=info["fps"], + total_episodes=info["total_episodes"], + total_frames=info["total_frames"], + tasks=tasks, ) - if not hf_dataset: + if hf_dataset is None: hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"]) - def _extract_episode_index_from_path(fpath: str) -> int: - path = Path(fpath) - if path.suffix == ".parquet" and path.stem.startswith("episode_"): - episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0 - return episode_index - else: - return None - def _mock_snapshot_download( - repo_id: str, + repo_id: str, # TODO(rcadene): repo_id should be used no? local_dir: str | Path | None = None, allow_patterns: str | list[str] | None = None, ignore_patterns: str | list[str] | None = None, *args, **kwargs, ) -> str: - if not local_dir: + if local_dir is None: local_dir = LEROBOT_TEST_DIR # List all possible files - all_files = [] - meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH] - all_files.extend(meta_files) + all_files = [ + INFO_PATH, + STATS_PATH, + # TODO(rcadene): remove naive chunk 0 file 0 ? + DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0), + DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0), + DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0), + ] - data_files = [] - for episode_dict in episodes.values(): - ep_idx = episode_dict["episode_index"] - ep_chunk = ep_idx // info["chunks_size"] - data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) - data_files.append(data_path) - all_files.extend(data_files) + video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"] + for key in video_keys: + all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)) allowed_files = filter_repo_objects( all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns ) - # Create allowed files + request_info = False + request_tasks = False + request_episodes = False + request_stats = False + request_data = False + request_videos = False for rel_path in allowed_files: - if rel_path.startswith("data/"): - episode_index = _extract_episode_index_from_path(rel_path) - if episode_index is not None: - _ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info) - if rel_path == INFO_PATH: - _ = info_path(local_dir, info) - elif rel_path == STATS_PATH: - _ = stats_path(local_dir, stats) - elif rel_path == EPISODES_STATS_PATH: - _ = episodes_stats_path(local_dir, episodes_stats) - elif rel_path == TASKS_PATH: - _ = tasks_path(local_dir, tasks) - elif rel_path == EPISODES_PATH: - _ = episode_path(local_dir, episodes) + if rel_path.startswith("meta/info.json"): + request_info = True + elif rel_path.startswith("meta/stats"): + request_stats = True + elif rel_path.startswith("meta/tasks"): + request_tasks = True + elif rel_path.startswith("meta/episodes"): + request_episodes = True + elif rel_path.startswith("data/"): + request_data = True + elif rel_path.startswith("videos/"): + request_videos = True else: - pass + raise ValueError(f"{rel_path} not supported.") + + if request_info: + create_info(local_dir, info) + if request_stats: + create_stats(local_dir, stats) + if request_tasks: + create_tasks(local_dir, tasks) + if request_episodes: + create_episodes(local_dir, episodes) + if request_data: + create_hf_dataset(local_dir, hf_dataset, data_files_size_in_mb, chunks_size) + if request_videos: + create_videos(root=local_dir, info=info) + return str(local_dir) return _mock_snapshot_download diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index da7573d7c..ef2d4ecd8 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -71,7 +71,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p }, } info = info_factory( - total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features + total_episodes=1, + total_frames=1, + total_tasks=1, + camera_features=camera_features, + motor_features=motor_features, ) ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info) return ds_meta @@ -140,7 +144,6 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive, and for now we add tests as we see fit. """ - train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index e45688c14..374f98129 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import patch + from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay @@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path): assert dataset.meta.total_tasks == 1 cfg.resume = True - dataset = record(cfg) + # Mock the revision to prevent Hub calls during resume + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "record") + dataset = record(cfg) assert dataset.meta.total_episodes == dataset.num_episodes == 2 assert dataset.meta.total_frames == dataset.num_frames == 6 @@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path): ) record(record_cfg) - replay(replay_cfg) + + # Mock the revision to prevent Hub calls during replay + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "record_and_replay") + replay(replay_cfg) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index a53d7ba8c..8781c5c0d 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -384,7 +384,7 @@ def test_to_lerobot_dataset(tmp_path): elif feature == "next.done": assert torch.equal(value, buffer.dones[i]) elif feature == "observation.image": - # Tenssor -> numpy is not precise, so we have some diff there + # Tensor -> numpy is not precise, so we have some diff there # TODO: Check and fix it torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) elif feature == "observation.state": From 33cad37054c2b594ceba57463e8f11ee374fa93c Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Mon, 15 Sep 2025 14:08:01 +0200 Subject: [PATCH 086/158] Add Streaming Dataset (#1613) Co-authored-by: Michel Aractingi --- examples/5_train_with_streaming.py | 116 +++++ src/lerobot/configs/default.py | 1 + src/lerobot/constants.py | 5 + src/lerobot/datasets/factory.py | 30 +- src/lerobot/datasets/lerobot_dataset.py | 4 + src/lerobot/datasets/streaming_dataset.py | 535 ++++++++++++++++++++++ src/lerobot/datasets/utils.py | 234 +++++++++- src/lerobot/datasets/video_utils.py | 84 +++- src/lerobot/scripts/train.py | 6 +- tests/datasets/test_streaming.py | 391 ++++++++++++++++ 10 files changed, 1380 insertions(+), 26 deletions(-) create mode 100644 examples/5_train_with_streaming.py create mode 100644 src/lerobot/datasets/streaming_dataset.py create mode 100644 tests/datasets/test_streaming.py diff --git a/examples/5_train_with_streaming.py b/examples/5_train_with_streaming.py new file mode 100644 index 000000000..17818410d --- /dev/null +++ b/examples/5_train_with_streaming.py @@ -0,0 +1,116 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This script demonstrates how to train a Diffusion Policy on the PushT environment, +using a dataset processed in streaming mode. + +Once you have trained a model with this script, you can try to evaluate it on +examples/2_evaluate_pretrained_policy.py +""" + +from pathlib import Path + +import torch + +from lerobot.configs.types import FeatureType +from lerobot.constants import ACTION +from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.policies.act.modeling_act import ACTPolicy + + +def main(): + # Create a directory to store the training checkpoint. + output_directory = Path("outputs/train/example_streaming_dataset") + output_directory.mkdir(parents=True, exist_ok=True) + + # Selects the "best" device available + device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("mps") + if torch.backends.mps.is_available() + else torch.device("cpu") + ) + print(f"Using device: {device}") + + training_steps = 10 + log_freq = 1 + + dataset_id = ( + "aractingi/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (: + ) + dataset_metadata = LeRobotDatasetMetadata(dataset_id) + features = dataset_to_policy_features(dataset_metadata.features) + output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + input_features = {key: ft for key, ft in features.items() if key not in output_features} + + # We can now instantiate our policy with this config and the dataset stats. + cfg = ACTConfig(input_features=input_features, output_features=output_features) + policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats) + policy.train() + policy.to(device) + + # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy. + # Here, we use delta-timestamps to only provide ground truth actions for supervision + delta_timestamps = { + ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)], + } + + # Instantiating the training dataset in streaming mode allows to not consume up memory as the data is fetched + # iteratively rather than being load into memory all at once. Retrieved frames are shuffled across epochs + dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3) + + optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=4, + batch_size=16, + pin_memory=device.type != "cpu", + drop_last=True, + prefetch_factor=2, # loads batches with multiprocessing while policy trains + ) + + # Run training loop. + step = 0 + done = False + while not done: + for batch in dataloader: + batch = { + k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v) + for k, v in batch.items() + } + batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + # batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + loss, _ = policy.forward(batch) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if step % log_freq == 0: + print(f"step: {step} loss: {loss.item():.3f}") + step += 1 + if step >= training_steps: + done = True + break + + # Save a policy checkpoint. + policy.save_pretrained(output_directory) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 53cfe58e7..1bc2b8d16 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -37,6 +37,7 @@ class DatasetConfig: revision: str | None = None use_imagenet_stats: bool = True video_backend: str = field(default_factory=get_safe_default_codec) + streaming: bool = False @dataclass diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 30777239e..382435a9f 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -52,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu # calibration dir default_calibration_path = HF_LEROBOT_HOME / "calibration" HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser() + + +# streaming datasets +LOOKBACK_BACKTRACKTABLE = 100 +LOOKAHEAD_BACKTRACKTABLE = 100 diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index e06650bc9..a71e978bc 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import ( LeRobotDatasetMetadata, MultiLeRobotDataset, ) +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms IMAGENET_STATS = { @@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision ) delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) - dataset = LeRobotDataset( - cfg.dataset.repo_id, - root=cfg.dataset.root, - episodes=cfg.dataset.episodes, - delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - revision=cfg.dataset.revision, - video_backend=cfg.dataset.video_backend, - ) + if not cfg.dataset.streaming: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=cfg.dataset.episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + revision=cfg.dataset.revision, + video_backend=cfg.dataset.video_backend, + ) + else: + dataset = StreamingLeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=cfg.dataset.episodes, + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + revision=cfg.dataset.revision, + max_num_shards=cfg.num_workers, + ) else: raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") dataset = MultiLeRobotDataset( diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index ceefcf05e..9cd4b6bff 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -129,6 +129,10 @@ class LeRobotDatasetMetadata: ignore_patterns=ignore_patterns, ) + @property + def url_root(self) -> str: + return f"hf://datasets/{self.repo_id}" + @property def _version(self) -> packaging.version.Version: """Codebase version used to create this dataset.""" diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py new file mode 100644 index 000000000..e354c4060 --- /dev/null +++ b/src/lerobot/datasets/streaming_dataset.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable, Generator, Iterator +from pathlib import Path + +import datasets +import numpy as np +import torch +from datasets import load_dataset + +from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE +from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + Backtrackable, + LookAheadError, + LookBackError, + check_version_compatibility, + find_float_index, + get_delta_indices, + is_float_in_list, + item_to_torch, + safe_shard, +) +from lerobot.datasets.video_utils import ( + VideoDecoderCache, + decode_video_frames_torchcodec, +) + + +class StreamingLeRobotDataset(torch.utils.data.IterableDataset): + """LeRobotDataset with streaming capabilities. + + This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed + rather than loaded entirely into memory. This is especially useful for large datasets that may + not fit in memory or when you want to quickly explore a dataset without downloading it completely. + + The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent + items, allowing us to access previous frames for delta timestamps without loading the entire + dataset into memory. + + Example: + Basic usage: + ```python + from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset + + # Create a streaming dataset with delta timestamps + delta_timestamps = { + "observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current + "action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future + } + + dataset = StreamingLeRobotDataset( + repo_id="your-dataset-repo-id", + delta_timestamps=delta_timestamps, + streaming=True, + buffer_size=1000, + ) + + # Iterate over the dataset + for i, item in enumerate(dataset): + print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}") + # item will contain stacked frames according to delta_timestamps + if i >= 10: + break + ``` + """ + + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + episodes: list[int] | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + streaming: bool = True, + buffer_size: int = 1000, + max_num_shards: int = 16, + seed: int = 42, + rng: np.random.Generator | None = None, + shuffle: bool = True, + ): + """Initialize a StreamingLeRobotDataset. + + Args: + repo_id (str): This is the repo id that will be used to fetch the dataset. + root (Path | None, optional): Local directory to use for downloading/writing files. + episodes (list[int] | None, optional): If specified, this will only load episodes specified by + their episode_index in this list. + image_transforms (Callable | None, optional): Transform to apply to image data. + tolerance_s (float, optional): Tolerance in seconds for timestamp matching. + revision (str, optional): Git revision id (branch name, tag, or commit hash). + force_cache_sync (bool, optional): Flag to sync and refresh local files first. + streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True. + buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000. + max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16. + seed (int, optional): Reproducibility random seed. + rng (np.random.Generator | None, optional): Random number generator. + shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True. + """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.streaming_from_local = root is not None + + self.image_transforms = image_transforms + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.seed = seed + self.rng = rng if rng is not None else np.random.default_rng(seed) + self.shuffle = shuffle + + self.streaming = streaming + self.buffer_size = buffer_size + + # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) + self.video_decoder_cache = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + # Check version + check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + + self.delta_timestamps = None + self.delta_indices = None + + if delta_timestamps is not None: + self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid + self.delta_timestamps = delta_timestamps + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + + self.hf_dataset: datasets.IterableDataset = load_dataset( + self.repo_id if not self.streaming_from_local else str(self.root), + split="train", + streaming=self.streaming, + data_files="data/*/*.parquet", + revision=self.revision, + ) + + self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) + + @property + def num_frames(self): + return self.meta.total_frames + + @property + def num_episodes(self): + return self.meta.total_episodes + + @property + def fps(self): + return self.meta.fps + + @staticmethod + def _iter_random_indices( + rng: np.random.Generator, buffer_size: int, random_batch_size=100 + ) -> Iterator[int]: + while True: + yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size)) + + @staticmethod + def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]: + while True: + yield rng.choice(elements) + + # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading. + # The current sequential iteration is a bottleneck. A producer-consumer pattern + # could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding) + # in parallel, feeding a queue from which this iterator will yield processed items. + def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: + if self.video_decoder_cache is None: + self.video_decoder_cache = VideoDecoderCache() + + # keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions + rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng + + buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size) + + idx_to_backtrack_dataset = { + idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards)) + for idx in range(self.num_shards) + } + + # This buffer is populated while iterating on the dataset's shards + # the logic is to add 2 levels of randomness: + # (1) sample one shard at random from the ones available, and + # (2) sample one frame from the shard sampled at (1) + frames_buffer = [] + while available_shards := list(idx_to_backtrack_dataset.keys()): + shard_key = next(self._infinite_generator_over_elements(rng, available_shards)) + backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on + + try: + for frame in self.make_frame(backtrack_dataset): + if len(frames_buffer) == self.buffer_size: + i = next(buffer_indices_generator) # samples a element from the buffer + yield frames_buffer[i] + frames_buffer[i] = frame + else: + frames_buffer.append(frame) + break # random shard sampled, switch shard + except ( + RuntimeError, + StopIteration, + ): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7 + del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard + + # Once shards are all exhausted, shuffle the buffer and yield the remaining frames + rng.shuffle(frames_buffer) + yield from frames_buffer + + def _get_window_steps( + self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False + ) -> tuple[int, int]: + if delta_timestamps is None: + return 1, 1 + + if not dynamic_bounds: + # Fix the windows + lookback = LOOKBACK_BACKTRACKTABLE + lookahead = LOOKAHEAD_BACKTRACKTABLE + else: + # Dynamically adjust the windows based on the given delta_timesteps + all_timestamps = sum(delta_timestamps.values(), []) + lookback = min(all_timestamps) * self.fps + lookahead = max(all_timestamps) * self.fps + + # When lookback is >=0 it means no negative timesteps have been provided + lookback = 0 if lookback >= 0 else (lookback * -1) + + return lookback, lookahead + + def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable: + lookback, lookahead = self._get_window_steps(self.delta_timestamps) + return Backtrackable(dataset, history=lookback, lookahead=lookahead) + + def _make_timestamps_from_indices( + self, start_ts: float, indices: dict[str, list[int]] | None = None + ) -> dict[str, list[float]]: + if indices is not None: + return { + key: ( + start_ts + torch.tensor(indices[key]) / self.fps + ).tolist() # NOTE: why not delta_timestamps directly? + for key in self.delta_timestamps + } + else: + return dict.fromkeys(self.meta.video_keys, [start_ts]) + + def _make_padding_camera_frame(self, camera_key: str): + """Variable-shape padding frame for given camera keys, given in (H, W, C)""" + return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1) + + def _get_video_frame_padding_mask( + self, + video_frames: dict[str, torch.Tensor], + query_timestamps: dict[str, list[float]], + original_timestamps: dict[str, list[float]], + ) -> dict[str, torch.BoolTensor]: + padding_mask = {} + + for video_key, timestamps in original_timestamps.items(): + if video_key not in video_frames: + continue # only padding on video keys that are available + frames = [] + mask = [] + padding_frame = self._make_padding_camera_frame(video_key) + for ts in timestamps: + if is_float_in_list(ts, query_timestamps[video_key]): + idx = find_float_index(ts, query_timestamps[video_key]) + frames.append(video_frames[video_key][idx, :]) + mask.append(False) + else: + frames.append(padding_frame) + mask.append(True) + + padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask) + + return padding_mask + + def make_frame( + self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None + ) -> Generator: + """Makes a frame starting from a dataset iterator""" + item = next(dataset_iterator) + item = item_to_torch(item) + + updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features) + + # Get episode index from the item + ep_idx = item["episode_index"] + + # "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps) + current_ts = item["index"] / self.fps + + episode_boundaries_ts = { + key: ( + self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"], + self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"], + ) + for key in self.meta.video_keys + } + + # Apply delta querying logic if necessary + if self.delta_indices is not None: + query_result, padding = self._get_delta_frames(dataset_iterator, item) + updates.append(query_result) + updates.append(padding) + + # Load video frames, when needed + if len(self.meta.video_keys) > 0: + original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) + + # Some timestamps might not result available considering the episode's boundaries + query_timestamps = self._get_query_timestamps( + current_ts, self.delta_indices, episode_boundaries_ts + ) + video_frames = self._query_videos(query_timestamps, ep_idx) + + if self.image_transforms is not None: + image_keys = self.meta.camera_keys + for cam in image_keys: + video_frames[cam] = self.image_transforms(video_frames[cam]) + + updates.append(video_frames) + + if self.delta_indices is not None: + # We always return the same number of frames. Unavailable frames are padded. + padding_mask = self._get_video_frame_padding_mask( + video_frames, query_timestamps, original_timestamps + ) + updates.append(padding_mask) + + result = item.copy() + for update in updates: + result.update(update) + + result["task"] = self.meta.tasks.iloc[item["task_index"]].name + + yield result + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + episode_boundaries_ts: dict[str, tuple[float, float]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices) + for key in self.meta.video_keys: + if query_indices is not None and key in query_indices: + timestamps = keys_to_timestamps[key] + # Clamp out timesteps outside of episode boundaries + query_timestamps[key] = torch.clamp( + torch.tensor(timestamps), *episode_boundaries_ts[key] + ).tolist() + + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + + item = {} + for video_key, query_ts in query_timestamps.items(): + root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root + video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" + frames = decode_video_frames_torchcodec( + video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache + ) + + item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames + + return item + + def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict): + # TODO(fracapuano): Modularize this function, refactor the code + """Get frames with delta offsets using the backtrackable iterator. + + Args: + current_item (dict): Current item from the iterator. + ep_idx (int): Episode index. + + Returns: + tuple: (query_result, padding) - frames at delta offsets and padding info. + """ + current_episode_idx = current_item["episode_index"] + + # Prepare results + query_result = {} + padding = {} + + for key, delta_indices in self.delta_indices.items(): + if key in self.meta.video_keys: + continue # visual frames are decoded separately + + target_frames = [] + is_pad = [] + + # Create a results dictionary to store frames in processing order, then reconstruct original order for stacking + delta_results = {} + + # Separate and sort deltas by difficulty (easier operations first) + negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...] + positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...] + zero_deltas = [d for d in delta_indices if d == 0] + + # Process zero deltas (current frame) + for delta in zero_deltas: + delta_results[delta] = ( + current_item[key], + False, + ) + + # Process negative deltas in order of increasing difficulty + lookback_failed = False + + last_successful_frame = current_item[key] + + for delta in negative_deltas: + if lookback_failed: + delta_results[delta] = (last_successful_frame, True) + continue + + try: + steps_back = abs(delta) + if dataset_iterator.can_peek_back(steps_back): + past_item = dataset_iterator.peek_back(steps_back) + past_item = item_to_torch(past_item) + + if past_item["episode_index"] == current_episode_idx: + delta_results[delta] = (past_item[key], False) + last_successful_frame = past_item[key] + + else: + raise LookBackError("Retrieved frame is from different episode!") + else: + raise LookBackError("Cannot go back further than the history buffer!") + + except LookBackError: + delta_results[delta] = (last_successful_frame, True) + lookback_failed = True # All subsequent negative deltas will also fail + + # Process positive deltas in order of increasing difficulty + lookahead_failed = False + last_successful_frame = current_item[key] + + for delta in positive_deltas: + if lookahead_failed: + delta_results[delta] = (last_successful_frame, True) + continue + + try: + if dataset_iterator.can_peek_ahead(delta): + future_item = dataset_iterator.peek_ahead(delta) + future_item = item_to_torch(future_item) + + if future_item["episode_index"] == current_episode_idx: + delta_results[delta] = (future_item[key], False) + last_successful_frame = future_item[key] + + else: + raise LookAheadError("Retrieved frame is from different episode!") + else: + raise LookAheadError("Cannot go ahead further than the lookahead buffer!") + + except LookAheadError: + delta_results[delta] = (last_successful_frame, True) + lookahead_failed = True # All subsequent positive deltas will also fail + + # Reconstruct original order for stacking + for delta in delta_indices: + frame, is_padded = delta_results[delta] + + # add batch dimension for stacking + target_frames.append(frame) # frame.unsqueeze(0)) + is_pad.append(is_padded) + + # Stack frames and add to results + if target_frames: + query_result[key] = torch.stack(target_frames) + padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad) + + return query_result, padding + + def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None: + """ + Validate that all keys in delta_timestamps correspond to actual features in the dataset. + + Raises: + ValueError: If any delta timestamp key doesn't correspond to a dataset feature. + """ + if delta_timestamps is None: + return + + # Get all available feature keys from the dataset metadata + available_features = set(self.meta.features.keys()) + + # Get all keys from delta_timestamps + delta_keys = set(delta_timestamps.keys()) + + # Find any keys that don't correspond to features + invalid_keys = delta_keys - available_features + + if invalid_keys: + raise ValueError( + f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. " + f"Available features are: {sorted(available_features)}" + ) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 2b0d95e17..c840d5bc1 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,10 +17,11 @@ import contextlib import importlib.resources import json import logging -from collections.abc import Iterator +from collections import deque +from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat -from typing import Any +from typing import Any, Deque, Generic, TypeVar import datasets import numpy as np @@ -86,6 +87,8 @@ DEFAULT_FEATURES = { "task_index": {"dtype": "int64", "shape": (1,), "names": None}, } +T = TypeVar("T") + def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: metadata = pq.read_metadata(parquet_path) @@ -776,3 +779,230 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: """ # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) + + +def item_to_torch(item: dict) -> dict: + """Convert all items in a dictionary to PyTorch tensors where appropriate. + + This function is used to convert an item from a streaming dataset to PyTorch tensors. + + Args: + item (dict): Dictionary of items from a dataset. + + Returns: + dict: Dictionary with all tensor-like items converted to torch.Tensor. + """ + for key, val in item.items(): + if isinstance(val, (np.ndarray, list)) and key not in ["task"]: + # Convert numpy arrays and lists to torch tensors + item[key] = torch.tensor(val) + return item + + +def is_float_in_list(target, float_list, threshold=1e-6): + return any(abs(target - x) <= threshold for x in float_list) + + +def find_float_index(target, float_list, threshold=1e-6): + for i, x in enumerate(float_list): + if abs(target - x) <= threshold: + return i + return -1 + + +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable(Generic[T]): + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. + + Example: + ------- + ```python + ds = load_dataset("c4", "en", streaming=True, split="train") + rev = Backtrackable(ds, history=3, lookahead=2) + + x0 = next(rev) # forward + x1 = next(rev) + x2 = next(rev) + + # Look ahead + x3_peek = rev.peek_ahead(1) # next item without moving cursor + x4_peek = rev.peek_ahead(2) # two items ahead + + # Look back + x1_again = rev.peek_back(1) # previous item without moving cursor + x0_again = rev.peek_back(2) # two items back + + # Move backward + x1_back = rev.prev() # back one step + next(rev) # returns x2, continues forward from where we were + ``` + """ + + __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead") + + def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): + if history < 1: + raise ValueError("history must be >= 1") + if lookahead <= 0: + raise ValueError("lookahead must be > 0") + + self._source: Iterator[T] = iter(iterable) + self._back_buf: Deque[T] = deque(maxlen=history) + self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._cursor: int = 0 + self._history = history + self._lookahead = lookahead + + def __iter__(self) -> "Backtrackable[T]": + return self + + def __next__(self) -> T: + # If we've stepped back, consume from back buffer first + if self._cursor < 0: # -1 means "last item", etc. + self._cursor += 1 + return self._back_buf[self._cursor] + + # If we have items in the ahead buffer, use them first + item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) + + # Add current item to back buffer and reset cursor + self._back_buf.append(item) + self._cursor = 0 + return item + + def prev(self) -> T: + """ + Step one item back in history and return it. + Raises IndexError if already at the oldest buffered item. + """ + if len(self._back_buf) + self._cursor <= 1: + raise LookBackError("At start of history") + + self._cursor -= 1 + return self._back_buf[self._cursor] + + def peek_back(self, n: int = 1) -> T: + """ + Look `n` items back (n=1 == previous item) without moving the cursor. + """ + if n < 0 or n + 1 > len(self._back_buf) + self._cursor: + raise LookBackError("peek_back distance out of range") + + return self._back_buf[self._cursor - (n + 1)] + + def peek_ahead(self, n: int = 1) -> T: + """ + Look `n` items ahead (n=1 == next item) without moving the cursor. + Fills the ahead buffer if necessary. + """ + if n < 1: + raise LookAheadError("peek_ahead distance must be 1 or more") + elif n > self._lookahead: + raise LookAheadError("peek_ahead distance exceeds lookahead limit") + + # Fill ahead buffer if we don't have enough items + while len(self._ahead_buf) < n: + try: + item = next(self._source) + self._ahead_buf.append(item) + + except StopIteration as err: + raise LookAheadError("peek_ahead: not enough items in source") from err + + return self._ahead_buf[n - 1] + + def history(self) -> list[T]: + """ + Return a copy of the buffered history (most recent last). + The list length ≤ `history` argument passed at construction. + """ + if self._cursor == 0: + return list(self._back_buf) + + # When cursor<0, slice so the order remains chronological + return list(self._back_buf)[: self._cursor or None] + + def lookahead_buffer(self) -> list[T]: + """ + Return a copy of the current lookahead buffer. + """ + return list(self._ahead_buf) + + def can_peek_back(self, steps: int = 1) -> bool: + """ + Check if we can go back `steps` items without raising an IndexError. + """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. + This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False + + def reset_cursor(self) -> None: + """ + Reset cursor to the most recent position (equivalent to calling next() + until you're back to the latest item). + """ + self._cursor = 0 + + def clear_ahead_buffer(self) -> None: + """ + Clear the ahead buffer, discarding any pre-fetched items. + """ + self._ahead_buf.clear() + + def switch_source_iterable(self, new_source: Iterable[T]) -> None: + """ + Switch the source of the backtrackable to a new iterable, keeping the history. + + This is useful when iterating over a sequence of datasets. The history from the + previous source is kept, but the lookahead buffer is cleared. The cursor is reset + to the present. + """ + self._source = iter(new_source) + self.clear_ahead_buffer() + self.reset_cursor() + + +def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: + """ + Safe shards the dataset. + """ + shard_idx = min(dataset.num_shards, index + 1) - 1 + + return dataset.shard(num_shards, index=shard_idx) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 9d7df8d61..9da89022b 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -21,9 +21,11 @@ import tempfile import warnings from dataclasses import dataclass, field from pathlib import Path +from threading import Lock from typing import Any, ClassVar import av +import fsspec import pyarrow as pa import torch import torchvision @@ -169,15 +171,68 @@ def decode_video_frames_torchvision( return closest_frames +class VideoDecoderCache: + """Thread-safe cache for video decoders to avoid expensive re-initialization.""" + + def __init__(self): + self._cache: dict[str, tuple[Any, Any]] = {} + self._lock = Lock() + + def get_decoder(self, video_path: str): + """Get a cached decoder or create a new one.""" + if importlib.util.find_spec("torchcodec"): + from torchcodec.decoders import VideoDecoder + else: + raise ImportError("torchcodec is required but not available.") + + video_path = str(video_path) + + with self._lock: + if video_path not in self._cache: + file_handle = fsspec.open(video_path).__enter__() + decoder = VideoDecoder(file_handle, seek_mode="approximate") + self._cache[video_path] = (decoder, file_handle) + + return self._cache[video_path][0] + + def clear(self): + """Clear the cache and close file handles.""" + with self._lock: + for _, file_handle in self._cache.values(): + file_handle.close() + self._cache.clear() + + def size(self) -> int: + """Return the number of cached decoders.""" + with self._lock: + return len(self._cache) + + +class FrameTimestampError(ValueError): + """Helper error to indicate the retrieved timestamps exceed the queried ones""" + + pass + + +_default_decoder_cache = VideoDecoderCache() + + def decode_video_frames_torchcodec( video_path: Path | str, timestamps: list[float], tolerance_s: float, - device: str = "cpu", log_loaded_timestamps: bool = False, + decoder_cache: VideoDecoderCache | None = None, ) -> torch.Tensor: """Loads frames associated with the requested timestamps of a video using torchcodec. + Args: + video_path: Path to the video file. + timestamps: List of timestamps to extract frames. + tolerance_s: Allowed deviation in seconds for frame retrieval. + log_loaded_timestamps: Whether to log loaded timestamps. + decoder_cache: Optional decoder cache instance. Uses default if None. + Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors. Note: Video benefits from inter-frame compression. Instead of storing every frame individually, @@ -186,27 +241,24 @@ def decode_video_frames_torchcodec( and all subsequent frames until reaching the requested frame. The number of key frames in a video can be adjusted during encoding to take into account decoding time and video size in bytes. """ + if decoder_cache is None: + decoder_cache = _default_decoder_cache - if importlib.util.find_spec("torchcodec"): - from torchcodec.decoders import VideoDecoder - else: - raise ImportError("torchcodec is required but not available.") + # Use cached decoder instead of creating new one each time + decoder = decoder_cache.get_decoder(str(video_path)) - # initialize video decoder - decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") - loaded_frames = [] loaded_ts = [] + loaded_frames = [] + # get metadata for frame information metadata = decoder.metadata average_fps = metadata.average_fps - # convert timestamps to frame indices frame_indices = [round(ts * average_fps) for ts in timestamps] - # retrieve frames based on indices frames_batch = decoder.get_frames_at(indices=frame_indices) - for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False): + for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True): loaded_frames.append(frame) loaded_ts.append(pts.item()) if log_loaded_timestamps: @@ -237,10 +289,14 @@ def decode_video_frames_torchcodec( if log_loaded_timestamps: logging.info(f"{closest_ts=}") - # convert to float32 in [0,1] range (channel first) - closest_frames = closest_frames.type(torch.float32) / 255 + # convert to float32 in [0,1] range + closest_frames = (closest_frames / 255.0).type(torch.float32) + + if not len(timestamps) == len(closest_frames): + raise FrameTimestampError( + f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}" + ) - assert len(timestamps) == len(closest_frames) return closest_frames diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index ba3db6075..398bea90e 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -179,10 +179,11 @@ def train(cfg: TrainPipelineConfig): dataset, num_workers=cfg.num_workers, batch_size=cfg.batch_size, - shuffle=shuffle, + shuffle=shuffle and not cfg.dataset.streaming, sampler=sampler, pin_memory=device.type == "cuda", drop_last=False, + prefetch_factor=2, ) dl_iter = cycle(dataloader) @@ -208,6 +209,9 @@ def train(cfg: TrainPipelineConfig): for key in batch: if isinstance(batch[key], torch.Tensor): + if batch[key].dtype != torch.bool: + batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key] + batch[key] = batch[key].to(device, non_blocking=device.type == "cuda") train_tracker, output_dict = update_policy( diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py new file mode 100644 index 000000000..506be3ecf --- /dev/null +++ b/tests/datasets/test_streaming.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest +import torch + +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset +from lerobot.datasets.utils import safe_shard +from tests.fixtures.constants import DUMMY_REPO_ID + + +def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]: + """Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices.""" + rng = np.random.default_rng(streaming_ds.seed) + buffer_size = streaming_ds.buffer_size + num_shards = streaming_ds.num_shards + + shards_indices = [] + for shard_idx in range(num_shards): + shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx) + shard_indices = [item["index"] for item in shard] + shards_indices.append(shard_indices) + + shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)} + + buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size) + + frames_buffer = [] + expected_indices = [] + + while shard_iterators: # While there are still available shards + available_shard_keys = list(shard_iterators.keys()) + if not available_shard_keys: + break + + # Call _infinite_generator_over_elements with current available shards (key difference!) + shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys)) + + try: + frame_index = next(shard_iterators[shard_key]) + + if len(frames_buffer) == buffer_size: + i = next(buffer_indices_generator) + expected_indices.append(frames_buffer[i]) + frames_buffer[i] = frame_index + else: + frames_buffer.append(frame_index) + + except StopIteration: + del shard_iterators[shard_key] # Remove exhausted shard + + rng.shuffle(frames_buffer) + expected_indices.extend(frames_buffer) + + return expected_indices + + +def test_single_frame_consistency(tmp_path, lerobot_dataset_factory): + """Test if are correctly accessed""" + ds_num_frames = 400 + ds_num_episodes = 10 + buffer_size = 100 + + local_path = tmp_path / "test" + repo_id = f"{DUMMY_REPO_ID}" + + ds = lerobot_dataset_factory( + root=local_path, + repo_id=repo_id, + total_episodes=ds_num_episodes, + total_frames=ds_num_frames, + ) + + streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size)) + + key_checks = [] + for _ in range(ds_num_frames): + streaming_frame = next(streaming_ds) + frame_idx = streaming_frame["index"] + target_frame = ds[frame_idx] + + for key in streaming_frame: + left = streaming_frame[key] + right = target_frame[key] + + if isinstance(left, str): + check = left == right + + elif isinstance(left, torch.Tensor): + check = torch.allclose(left, right) and left.shape == right.shape + + elif isinstance(left, float): + check = left == right.item() # right is a torch.Tensor + + key_checks.append((key, check)) + + assert all(t[1] for t in key_checks), ( + f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (frame_idx: {frame_idx})" + ) + + +@pytest.mark.parametrize( + "shuffle", + [False, True], +) +def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle): + """Test if streamed frames correspond to shuffling operations over in-memory dataset.""" + ds_num_frames = 400 + ds_num_episodes = 10 + buffer_size = 100 + seed = 42 + n_epochs = 3 + + local_path = tmp_path / "test" + repo_id = f"{DUMMY_REPO_ID}" + + lerobot_dataset_factory( + root=local_path, + repo_id=repo_id, + total_episodes=ds_num_episodes, + total_frames=ds_num_frames, + ) + + streaming_ds = StreamingLeRobotDataset( + repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle + ) + + first_epoch_indices = [frame["index"] for frame in streaming_ds] + expected_indices = get_frames_expected_order(streaming_ds) + + assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices" + + expected_indices = get_frames_expected_order(streaming_ds) + for _ in range(n_epochs): + streaming_indices = [frame["index"] for frame in streaming_ds] + frames_match = all( + s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True) + ) + + if shuffle: + assert not frames_match + else: + assert frames_match + + +@pytest.mark.parametrize( + "shuffle", + [False, True], +) +def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle): + """Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards.""" + ds_num_frames = 100 + ds_num_episodes = 10 + buffer_size = 10 + + seed = 42 + n_epochs = 3 + data_file_size_mb = 0.001 + + chunks_size = 1 + + local_path = tmp_path / "test" + repo_id = f"{DUMMY_REPO_ID}-ciao" + + lerobot_dataset_factory( + root=local_path, + repo_id=repo_id, + total_episodes=ds_num_episodes, + total_frames=ds_num_frames, + data_files_size_in_mb=data_file_size_mb, + chunks_size=chunks_size, + ) + + streaming_ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=local_path, + buffer_size=buffer_size, + seed=seed, + shuffle=shuffle, + max_num_shards=4, + ) + + first_epoch_indices = [frame["index"] for frame in streaming_ds] + expected_indices = get_frames_expected_order(streaming_ds) + + assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices" + + for _ in range(n_epochs): + streaming_indices = [ + frame["index"] for frame in streaming_ds + ] # NOTE: this is the same as first_epoch_indices + frames_match = all( + s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True) + ) + if shuffle: + assert not frames_match + else: + assert frames_match + + +@pytest.mark.parametrize( + "state_deltas, action_deltas", + [ + ([-1, -0.5, -0.20, 0], [0, 1, 2, 3]), + ([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]), + ([-2, -1, -0.5, 0], [0, 1, 2, 3]), + ([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]), + ], +) +def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas): + ds_num_frames = 500 + ds_num_episodes = 10 + buffer_size = 100 + + seed = 42 + + local_path = tmp_path / "test" + repo_id = f"{DUMMY_REPO_ID}-ciao" + camera_key = "phone" + + delta_timestamps = { + camera_key: state_deltas, + "state": state_deltas, + "action": action_deltas, + } + + ds = lerobot_dataset_factory( + root=local_path, + repo_id=repo_id, + total_episodes=ds_num_episodes, + total_frames=ds_num_frames, + delta_timestamps=delta_timestamps, + ) + + streaming_ds = iter( + StreamingLeRobotDataset( + repo_id=repo_id, + root=local_path, + buffer_size=buffer_size, + seed=seed, + shuffle=False, + delta_timestamps=delta_timestamps, + ) + ) + + for i in range(ds_num_frames): + streaming_frame = next(streaming_ds) + frame_idx = streaming_frame["index"] + target_frame = ds[frame_idx] + + assert set(streaming_frame.keys()) == set(target_frame.keys()), ( + f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}" + ) + + key_checks = [] + for key in streaming_frame: + left = streaming_frame[key] + right = target_frame[key] + + if isinstance(left, str): + check = left == right + + elif isinstance(left, torch.Tensor): + if ( + key not in ds.meta.camera_keys + and "is_pad" not in key + and f"{key}_is_pad" in streaming_frame + ): + # comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting + left = left[~streaming_frame[f"{key}_is_pad"]] + right = right[~target_frame[f"{key}_is_pad"]] + + check = torch.allclose(left, right) and left.shape == right.shape + + key_checks.append((key, check)) + + assert all(t[1] for t in key_checks), ( + f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})" + ) + + +@pytest.mark.parametrize( + "state_deltas, action_deltas", + [ + ([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]), + ([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]), + ([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]), + ([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]), + ], +) +def test_frames_with_delta_consistency_with_shards( + tmp_path, lerobot_dataset_factory, state_deltas, action_deltas +): + ds_num_frames = 100 + ds_num_episodes = 10 + buffer_size = 10 + data_file_size_mb = 0.001 + chunks_size = 1 + + seed = 42 + + local_path = tmp_path / "test" + repo_id = f"{DUMMY_REPO_ID}-ciao" + camera_key = "phone" + + delta_timestamps = { + camera_key: state_deltas, + "state": state_deltas, + "action": action_deltas, + } + + ds = lerobot_dataset_factory( + root=local_path, + repo_id=repo_id, + total_episodes=ds_num_episodes, + total_frames=ds_num_frames, + delta_timestamps=delta_timestamps, + data_files_size_in_mb=data_file_size_mb, + chunks_size=chunks_size, + ) + streaming_ds = StreamingLeRobotDataset( + repo_id=repo_id, + root=local_path, + buffer_size=buffer_size, + seed=seed, + shuffle=False, + delta_timestamps=delta_timestamps, + max_num_shards=4, + ) + + iter(streaming_ds) + + num_shards = 4 + shards_indices = [] + for shard_idx in range(num_shards): + shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards) + shard_indices = [item["index"] for item in shard] + shards_indices.append(shard_indices) + + streaming_ds = iter(streaming_ds) + + for i in range(ds_num_frames): + streaming_frame = next(streaming_ds) + frame_idx = streaming_frame["index"] + target_frame = ds[frame_idx] + + assert set(streaming_frame.keys()) == set(target_frame.keys()), ( + f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}" + ) + + key_checks = [] + for key in streaming_frame: + left = streaming_frame[key] + right = target_frame[key] + + if isinstance(left, str): + check = left == right + + elif isinstance(left, torch.Tensor): + if ( + key not in ds.meta.camera_keys + and "is_pad" not in key + and f"{key}_is_pad" in streaming_frame + ): + # comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting + left = left[~streaming_frame[f"{key}_is_pad"]] + right = right[~target_frame[f"{key}_is_pad"]] + + check = torch.allclose(left, right) and left.shape == right.shape + + elif isinstance(left, float): + check = left == right.item() # right is a torch.Tensor + + key_checks.append((key, check)) + + assert all(t[1] for t in key_checks), ( + f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})" + ) From 847e74f62827253507dd7aca18da028596811f31 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 15 Sep 2025 18:52:30 +0200 Subject: [PATCH 087/158] Update dataset card by default (#1936) * remove condition on model card update --- src/lerobot/datasets/lerobot_dataset.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 9cd4b6bff..4ac7a841c 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -29,7 +29,6 @@ import PIL.Image import torch import torch.utils from huggingface_hub import HfApi, snapshot_download -from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.errors import RevisionNotFoundError from lerobot.constants import HF_LEROBOT_HOME @@ -675,11 +674,10 @@ class LeRobotDataset(torch.utils.data.Dataset): else: hub_api.upload_folder(**upload_kwargs) - if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): - card = create_lerobot_dataset_card( - tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs - ) - card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) if tag_version: with contextlib.suppress(RevisionNotFoundError): From 55e752f0c2e7fab0d989c5ff999fbe3b6d8872ab Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 16 Sep 2025 17:45:38 +0200 Subject: [PATCH 088/158] docs(dataset): add dataset v3 documentation (#1956) * add v3 doc * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * update changes * iterate on review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add changes * create dataset section * Update docs/source/lerobot-dataset-v3.mdx Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * Update docs/source/lerobot-dataset-v3.mdx Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * Update docs/source/lerobot-dataset-v3.mdx Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi Co-authored-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- docs/source/_toctree.yml | 6 +- docs/source/lerobot-dataset-v3.mdx | 169 +++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 docs/source/lerobot-dataset-v3.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 5f5a509c7..9f5de8230 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,9 +19,13 @@ title: Train RL in Simulation - local: async title: Use Async Inference + title: "Tutorials" +- sections: + - local: lerobot-dataset-v3 + title: Using LeRobotDataset - local: porting_datasets_v3 title: Porting Large Datasets - title: "Tutorials" + title: "Datasets" - sections: - local: smolvla title: Finetune SmolVLA diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx new file mode 100644 index 000000000..4f33d9a25 --- /dev/null +++ b/docs/source/lerobot-dataset-v3.mdx @@ -0,0 +1,169 @@ +# LeRobotDataset v3.0 + +`LeRobotDataset v3.0` is a standardized format for robot learning data. It provides unified access to multi-modal time-series data, sensorimotor signals and multi‑camera video, as well as rich metadata for indexing, search, and visualization on the Hugging Face Hub. + +This docs will guide you to: + +- Understand the v3.0 design and directory layout +- Record a dataset and push it to the Hub +- Load datasets for training with `LeRobotDataset` +- Stream datasets without downloading using `StreamingLeRobotDataset` +- Migrate existing `v2.1` datasets to `v3.0` + +## What’s new in `v3` + +- **File-based storage**: Many episodes per Parquet/MP4 file (v2 used one file per episode). +- **Relational metadata**: Episode boundaries and lookups are resolved through metadata, not filenames. +- **Hub-native streaming**: Consume datasets directly from the Hub with `StreamingLeRobotDataset`. +- **Lower file-system pressure**: Fewer, larger files ⇒ faster initialization and fewer issues at scale. +- **Unified organization**: Clean directory layout with consistent path templates across data and videos. + +## Installation + +`LeRobotDataset v3.0` will be included in `lerobot >= 0.4.0`. + +Until that stable release, you can use the main branch by following the [build from source instructions](./installation#from-source). + +## Record a dataset + +Run the command below to record a dataset with the SO-101 and push to the Hub: + +```bash +lerobot-record \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 \ + --robot.id=my_awesome_follower_arm \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=my_awesome_leader_arm \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/record-test \ + --dataset.num_episodes=5 \ + --dataset.single_task="Grab the black cube" +``` + +See the [recording guide](./il_robots#record-a-dataset) for more details. + +## Format design + +A core v3 principle is **decoupling storage from the user API**: data is stored efficiently (few large files), while the public API exposes intuitive episode-level access. + +`v3` has three pillars: + +1. **Tabular data**: Low‑dimensional, high‑frequency signals (states, actions, timestamps) stored in **Apache Parquet**. Access is memory‑mapped or streamed via the `datasets` stack. +2. **Visual data**: Camera frames concatenated and encoded into **MP4**. Frames from the same episode are grouped; videos are sharded per camera for practical sizes. +3. **Metadata**: JSON/Parquet records describing schema (feature names, dtypes, shapes), frame rates, normalization stats, and **episode segmentation** (start/end offsets into shared Parquet/MP4 files). + +> To scale to millions of episodes, tabular rows and video frames from multiple episodes are **concatenated** into larger files. Episode‑specific views are reconstructed **via metadata**, not file boundaries. + +
+
+ LeRobotDataset v3 diagram +
+ From episode‑based to file‑based datasets +
+
+
+ +### Directory layout (simplified) + +- **`meta/info.json`**: canonical schema (features, shapes/dtypes), FPS, codebase version, and **path templates** to locate data/video shards. +- **`meta/stats.json`**: global feature statistics (mean/std/min/max) used for normalization; exposed as `dataset.meta.stats`. +- **`meta/tasks.jsonl`**: natural‑language task descriptions mapped to integer IDs for task‑conditioned policies. +- **`meta/episodes/`**: per‑episode records (lengths, tasks, offsets) stored as **chunked Parquet** for scalability. +- **`data/`**: frame‑by‑frame **Parquet** shards; each file typically contains **many episodes**. +- **`videos/`**: **MP4** shards per camera; each file typically contains **many episodes**. + +## Load a dataset for training + +`LeRobotDataset` returns Python dictionaries of PyTorch tensors and integrates with `torch.utils.data.DataLoader`. Here is a code example showing its use: + +```python +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +repo_id = "yaak-ai/L2D-v3" + +# 1) Load from the Hub (cached locally) +dataset = LeRobotDataset(repo_id) + +# 2) Random access by index +sample = dataset[100] +print(sample) +# { +# 'observation.state': tensor([...]), +# 'action': tensor([...]), +# 'observation.images.front_left': tensor([C, H, W]), +# 'timestamp': tensor(1.234), +# ... +# } + +# 3) Temporal windows via delta_timestamps (seconds relative to t) +delta_timestamps = { + "observation.images.front_left": [-0.2, -0.1, 0.0] # 0.2s and 0.1s before current frame +} + +dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) + +# Accessing an index now returns a stack for the specified key(s) +sample = dataset[100] +print(sample["observation.images.front_left"].shape) # [T, C, H, W], where T=3 + +# 4) Wrap with a DataLoader for training +batch_size = 16 +data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) + +device = "cuda" if torch.cuda.is_available() else "cpu" +for batch in data_loader: + observations = batch["observation.state"].to(device) + actions = batch["action"].to(device) + images = batch["observation.images.front_left"].to(device) + # model.forward(batch) +``` + +## Stream a dataset (no downloads) + +Use `StreamingLeRobotDataset` to iterate directly from the Hub without local copies. This allows to stream large datasets without the need to downloading them onto disk or loading them onto memory, and is a key feature of the new dataset format. + +```python +from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset + +repo_id = "yaak-ai/L2D-v3" +dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub +``` + +
+
+ StreamingLeRobotDataset +
+ Stream directly from the Hub for on‑the‑fly training. +
+
+
+ +## Migrate `v2.1` → `v3.0` + +A converter aggregates per‑episode files into larger shards and writes episode offsets/metadata. Convert your dataset using the instructions below. + +```bash +# Pre-release build with v3 support: +pip install "https://github.com/huggingface/lerobot/archive/33cad37054c2b594ceba57463e8f11ee374fa93c.zip" + +# Convert an existing v2.1 dataset hosted on the Hub: +python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id= +``` + +**What it does** + +- Aggregates parquet files: `episode-0000.parquet`, `episode-0001.parquet`, … → **`file-0000.parquet`**, … +- Aggregates mp4 files: `episode-0000.mp4`, `episode-0001.mp4`, … → **`file-0000.mp4`**, … +- Updates `meta/episodes/*` (chunked Parquet) with per‑episode lengths, tasks, and byte/frame offsets. From 78b866116fd73f9c0c83f9a7bc8fde8ea4b9b96d Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 18 Sep 2025 15:25:26 +0200 Subject: [PATCH 089/158] feat(processors): use pipelines across the codebase (#1452) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * Feat/pipeline add feature contract (#1637) * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * docs(pipeline): Clarify transition handling and hook behavior - Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps. * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * refactor(pipeline): Remove model card generation and streamline processor methods - Eliminated the _generate_model_card method from RobotProcessor, which was responsible for generating README.md files from a template. - Updated save_pretrained method to remove model card generation, focusing on serialization of processor definitions and parameters. - Added default implementations for get_config, state_dict, load_state_dict, reset, and feature_contract methods in various processor classes to enhance consistency and usability. * refactor(observation): Streamline observation preprocessing and remove unused processor methods - Updated the `preprocess_observation` function to enhance image handling and ensure proper tensor formatting. - Removed the `RobotProcessor` and associated transition handling from the `rollout` function, simplifying the observation processing flow. - Integrated direct calls to `preprocess_observation` for improved clarity and efficiency in the evaluation script. * refactor(pipeline): Rename parameters for clarity and enhance save/load functionality - Updated parameter names in the save_pretrained and from_pretrained methods for improved readability, changing destination_path to save_directory and source to pretrained_model_name_or_path. - Enhanced the save_pretrained method to ensure directory creation and file handling is consistent with the new parameter names. - Streamlined the loading process in from_pretrained to utilize loaded_config for better clarity and maintainability. * refactor(pipeline): minor improvements (#1684) * chore(pipeline): remove unused features + device torch + envtransition keys * refactor(pipeline): ImageProcessor & StateProcessor are both implemented directly in VanillaObservationPRocessor * refactor(pipeline): RenameProcessor now inherits from ObservationProcessor + remove unused code * test(pipeline): fix broken test after refactors * docs(pipeline): update docstrings VanillaObservationProcessor * chore(pipeline): move None check to base pipeline classes * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. * fix(train.py) push postprocessor with preprocessor - Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * refactor(device_processor): Update device handling and improve type hints - Changed device attribute type from torch.device to str for better clarity. - Introduced a private _device attribute to store the actual torch.device instance. - Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments. - Refactored device-related assertions in tests to use a consistent approach for device type verification. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test(tokenizer_processor): Add require_package decorator for transformers - Introduced @require_package("transformers") decorator in multiple test functions to ensure the transformers package is available before running tests. - This change enhances test reliability by preventing failures due to missing dependencies. * refactor(migrate_policy_normalization): Enhance preprocessor and postprocessor structure - Introduced RenameProcessor in the preprocessor to handle renaming features. - Combined input and output features in a single NormalizerProcessor for improved efficiency. - Updated RobotProcessor initialization to clarify step naming for preprocessor and postprocessor. - Added DeviceProcessor to both preprocessor and postprocessor for better device management. * Integrate pipeline and add phone teleop (#1681) * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * fix(ci): temporary fix on dataset deps version * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * refactor(train): Update memory pinning logic for mps compatibility * feat: initial commit phone teleop * ugly delta control * use quaternion * Refactor observation preprocessing to use a modular pipeline system - Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactor observation processing and improve modularity - Updated `ObservationProcessor` to enhance the modular design for processing observations. - Cleaned up imports and improved code readability by removing unnecessary lines and comments. - Ensured backward compatibility while integrating new processing components. - Added tests to validate the functionality of the updated processing architecture. * Remove redundant tests for None observation and serialization methods in `test_observation_processor.py` to streamline the test suite and improve maintainability. * Refactor processing architecture to use RobotProcessor - Replaced instances of RobotPipeline with RobotProcessor across the codebase for improved modularity and clarity. - Introduced ProcessorStepRegistry for better management of processing steps. - Updated relevant documentation and tests to reflect the new processing structure. - Enhanced the save/load functionality to support the new processor design. - Added a model card template for RobotProcessor to facilitate sharing and documentation. * Add RobotProcessor tutorial to documentation - Introduced a new tutorial on using RobotProcessor for preprocessing robot data. - Added a section in the table of contents for easy navigation to the new tutorial. - The tutorial covers key concepts, real-world scenarios, and practical examples for effective use of the RobotProcessor pipeline. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add normalization processor and related components - Introduced `NormalizationProcessor` to handle both observation normalization and action unnormalization. - Added `ObservationNormalizer` and `ActionUnnormalizer` classes for specific normalization tasks. - Updated `__init__.py` to include the new `NormalizationProcessor` in the module exports. - Enhanced `ObservationProcessor` with registration in the `ProcessorStepRegistry` for better modularity. - Created `RenameProcessor` for renaming keys in observations, improving flexibility in data processing. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Enhance processing architecture with new components - Added `RenameProcessor` to facilitate key renaming in observations, improving data handling flexibility. - Updated `__init__.py` to include `RenameProcessor` in module exports. - Refactored `NormalizationProcessor` and `ObservationNormalizer` to use `rsplit` for better key handling. - Introduced comprehensive tests for `NormalizationProcessor` and `RenameProcessor` to ensure functionality and robustness. * chore (docs): add docstring for processor * fix (test): test factory * fix(test): policies * Update tests/processor/test_observation_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine * chore(test): add suggestion made by copilot regarding numpy test * fix(test): import issue * Refactor normalization components and update tests - Renamed `ObservationNormalizer` to `NormalizerProcessor` and `ActionUnnormalizer` to `UnnormalizerProcessor` for clarity. - Consolidated normalization logic for both observations and actions into `NormalizerProcessor` and `UnnormalizerProcessor`. - Updated tests to reflect the new class names and ensure proper functionality of normalization and unnormalization processes. - Enhanced handling of missing statistics in normalization processes. * chore (docstrin):Improve docstring for NormalizerProcessor * feat (device processor): Implement device processor * chore (batch handling): Enhance processing components with batch conversion utilities * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(test): linting issue * chore (output format): improves output format * chore (type): add typing for multiprocess envs * feat (overrides): Implement support for loading processors with parameter overrides - Added the ability to provide non-serializable objects when loading processors from saved configurations using the `overrides` parameter. - Enhanced error handling for invalid override keys and instantiation errors. - Updated documentation and examples to illustrate the usage of overrides for both registered and unregistered steps. - Added comprehensive tests to validate the new functionality and ensure backward compatibility. * chore(normalization): addressing comments from copilot * chore(learner): nit comment from copilot * feat(pipeline): Enhance step_through method to support both tuple and dict inputs * refactor(pipeline): Simplify observation and padding data handling in batch transitions * Apply suggestions from code review Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Signed-off-by: Adil Zouitine * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Transition from tuple to dictionary format for EnvTransition - Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase. * refactor(observation_processor): Improve observation processing by using constants and simplifying pixel handling - Introduced constants for observation keys to enhance readability. - Streamlined the handling of the "pixels" key by copying observations first and processing images more clearly. - Updated the environment state and agent position assignments to use the new constants, improving maintainability. * feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. * refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. * feat(pipeline): Add __repr__ method to RobotProcessor for improved readability - Implemented a __repr__ method in the RobotProcessor class to provide a clear string representation of the processor, including step names and optional parameters like name and seed. - Added comprehensive tests to validate the __repr__ output for various scenarios, including empty processors, single and multiple steps, custom names, and seed values. - Ensured that the representation handles long lists of steps with truncation for better readability. * chore(pipeline): Move _CFG_NAME along other class member * refactor(pipeline): Utilize get_safe_torch_device for device assignment - Replaced direct torch.device instantiation with get_safe_torch_device to ensure safe device handling. - This change enhances code readability and maintains consistency in device management across the RobotProcessor class. * refactor(pipeline): Enhance state filename generation and profiling method - Updated state filename generation to use the registry name when available, improving clarity in saved files. - Modified the profile_steps method to include a warmup_runs parameter, allowing for more controlled performance profiling. - Ensured consistent conditions during profiling by deep copying transitions for each run, enhancing accuracy in timing results. * chore(doc): address pip install commant lerobot that not exist yet * feat(pipeline): Enhance configuration filename handling and state file naming - Introduced support for custom configuration filenames in the `save_pretrained` method, allowing users to specify a filename instead of the default. - Improved state file naming to include step indices, preventing conflicts when multiple processors of the same type are saved. - Added automatic detection for configuration files when loading from a directory, with error handling for multiple files. - Updated tests to validate new features, including custom filenames and automatic config detection. * refactor(pipeline): Improve state file naming conventions for clarity and uniqueness - Enhanced state file naming to include the processor's sanitized name, ensuring uniqueness when multiple processors are saved in the same directory. - Updated tests to reflect changes in state file naming, verifying that filenames now include the processor name and step indices to prevent conflicts. - Added a new test to validate state file naming when using multiple processors, ensuring distinct filenames for each processor's state files. * docs(pipeline): Add clarification for repo name sanitization process * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * Add debug + calib * cleanup * Add pipeline * fix int * Add record example * nit * Add feature contract to pipelinestep and pipeline * Add tests * Add processor tests * PR feedback * encorperate pr feedback * type in doc * oops * cleaned up steps and integrated pipeline with feature_contract * refactor steps and robot to pipeline * cleanup pipeline * cleanup code further * make it run * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * Do some todos and cleanup * change feature_contract to dataset_features * use one method for conversion pipeline output to add_frame dict and use base processors where possible * Add back in and use record_loop * update todo * rename to_dataset_frame * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * fix reference frame * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * update data visualization * update teleop example * fix record bugs * Add replay * Not code * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * Add eval script * fix `q_curr` in InverseKinematicsEEToJoints to the IK solution * feat(processors): Introduce processors for various policy types - Added `make_processor` function to create processor instances for different policy types, including `tdmpc`, `diffusion`, `act`, `vqbet`, `pi0`, `pi0fast`, `sac`, and `reward_classifier`. - Implemented corresponding processor files for each policy type, encapsulating normalization and unnormalization steps. - Updated existing policies to remove direct normalization dependencies, enhancing modularity and clarity. - Enhanced test coverage to validate the integration of new processors with existing policy configurations. * refactor(learner): Remove normalization from cached image features retrieval - Simplified the retrieval of observation features by removing the normalization step from the `get_cached_image_features` method calls. - This change enhances clarity and aligns with the recent updates to policy processors. * refactor(policies): Remove unnormalization step from action predictions - Eliminated the unnormalization of actions in both `TDMPCPolicy` and `VQBeTPolicy` classes to streamline action prediction. - This change improves code clarity and aligns with recent updates to policy processors. * feat(train): Integrate preprocessor into training pipeline * refactor(train): Update preprocessor initialization to include dataset statistics * refactor(policies): Enhance processor creation and add NaN detection hook * feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. * feat(migration): Add script for migrating policy models with normalization layers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(migrate): Enhance migration script to create preprocessor and postprocessor for policy models - Updated the migration script to generate both a preprocessor and a postprocessor, improving the handling of normalization for training and inference. - Added functionality to convert features to PolicyFeature objects, ensuring compatibility with the new processor architecture. - Refined the extraction and removal of normalization statistics and layers, streamlining the migration process. - Improved error handling for missing mandatory configuration fields during model instantiation. * feat(migrate): Add model card generation and saving to migration script - Implemented functionality to generate and save a model card for the migrated model, including metadata such as dataset repository ID, license, and tags. - Enhanced the script to push the model card to the hub if requested, improving model documentation and accessibility. - Refactored the saving process to ensure the model card is saved locally and uploaded correctly when pushing to the hub. * feat(processor): Introduce ToBatchProcessor for handling observation batching - Added ToBatchProcessor to ensure observations have proper batch dimensions for model processing. - Implemented functionality to add batch dimensions to state and image observations as needed. - Created comprehensive unit tests to validate the processor's behavior with various tensor dimensions and types. - Ensured compatibility with existing transition keys and maintained the integrity of non-observation data. * feat(processors): Add ToBatchProcessor to multiple policy processors - Integrated ToBatchProcessor into various policy processors to handle observation batching. - Updated make functions for act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet processors to include the new batching functionality. - Ensured consistency across all processor implementations for improved data handling. * refactor(factory): Remove unused imports and NaN detection hook from processor creation * feat(batch_processor): Enhance ToBatchProcessor to handle action batching - Updated ToBatchProcessor to add batch dimensions to actions in addition to observations. - Implemented separate methods for processing observations and actions, improving code readability. - Added comprehensive unit tests to validate action batching functionality across various tensor dimensions and types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. * refactor(factory): Clean up imports in factory.py - Removed unused import of IdentityProcessor to streamline the code. * feat(migrate): Extend load_model_from_hub to include train configuration - Updated load_model_from_hub to return the train configuration alongside the model state_dict and config. - Modified main function to handle the additional train configuration when loading models from both the hub and local paths. - Adjusted dataset_repo_id extraction to utilize the train configuration for improved accuracy. * refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. * feat(batch_processor): Add task field processing to ToBatchProcessor - Enhanced ToBatchProcessor to wrap string tasks in a list, adding batch dimensions for compatibility with model inference. - Implemented a new method for processing complementary data, ensuring that task values are correctly handled as either strings or lists of strings. - Added comprehensive unit tests to validate task processing, including edge cases and in-place mutation of complementary data. * feat(normalization): Implement IDENTITY mode for normalization and unnormalization - Enhanced NormalizerProcessor and UnnormalizerProcessor to support IDENTITY mode, allowing features to bypass normalization when specified. - Updated processing logic to check normalization modes and handle missing statistics gracefully. - Added comprehensive unit tests to validate IDENTITY mode functionality for both observations and actions, ensuring correct behavior across various scenarios. - Improved error handling for unsupported normalization modes. * fix(rebase): remove residual normalization layer: * refactor(diffusion): remove normalization layer from input processing * refactor(normalization): Remove unused state dict transformation methods and streamline imports - Eliminated the _transform_state_dict_keys and _load_as_safetensor methods from PI0Policy, simplifying the model loading process. - Cleaned up imports in modeling_pi0.py by removing log_model_loading_keys and init_logging. - Updated TDMPCPolicy and VQBeTPolicy to handle action removal from batches during offline evaluation. - Introduced hotswap_stats function in normalize_processor.py to update normalization statistics dynamically, with corresponding tests to ensure functionality. * refactor(normalization): Clean up imports in normalize_processor.py * feat(batch_processor): Add feature_contract method to ToBatchProcessor - Introduced feature_contract method that returns features without modification, maintaining the no-op behavior of the processor. - This addition enhances the flexibility of the ToBatchProcessor for future feature processing needs. * fix(dependencies): Update transformers dependency constraint to allow only versions up to 4.52.0 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feature(pipeline): port tokenizer pipeline for VLA (#1645) * feat(tokenizer): Introduce TokenizerProcessor for text tokenization - Added TokenizerProcessor class to handle tokenization of task strings using Hugging Face's AutoTokenizer. - Supports both string and list inputs, with customizable parameters for task key, output key, and tokenization settings. - Implemented comprehensive unit tests to validate functionality, including handling of various input scenarios and integration with RobotProcessor. - Updated types.py to include LANGUAGE feature type and modified __init__.py to register the new processor. * feat(language): Enhance language processing in TokenizerProcessor - Added OBS_LANGUAGE constant to define the observation language key. - Updated TokenizerProcessor to store tokenized task data in the observation dictionary, ensuring compatibility with the new language feature. - Introduced Pi0NewLineProcessor to append newlines to tasks for proper tokenization. - Modified tests to validate the integration of language tokens and attention masks in the observation structure. * feat(tokenizer): Add padding configuration to TokenizerProcessor - Introduced `padding_side` parameter to the TokenizerProcessor for customizable padding direction. - Updated the `make_pi0_processor` function to include the new padding configuration. - Enhanced unit tests to validate the functionality of the `padding_side` parameter in various scenarios. * feat(processor): Add state management methods to Pi0NewLineProcessor * feat(normalization): Track normalization and unnormalization info in complementary data - Updated NormalizerProcessor and UnnormalizerProcessor to accept additional parameters for tracking normalization modes. - Enhanced the __call__ methods to store normalization and unnormalization information in the complementary data of transitions. - Added unit tests to verify the correct tracking of normalization info, including scenarios with missing stats and selective normalization keys. * feat(factory): Add preprocessor and postprocessor overrides to ProcessorConfigKwargs - Updated ProcessorConfigKwargs to include optional overrides for preprocessor and postprocessor configurations. - Enhanced the make_processor function to utilize the new overrides, allowing for more flexible processor initialization. * feat(processors): Integrate RenameProcessor into various processor configurations - Added RenameProcessor to the input steps of multiple processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Consolidated normalization features from input and output into a single NormalizerProcessor for improved efficiency. - Updated the input steps to ensure compatibility with the new RenameProcessor integration. * feat(smolvla): Refactor language processing and introduce new line processor (#1658) - Removed the prepare_language method and directly accessed language tokens and masks from the batch using the OBS_LANGUAGE constant. - Added SmolVLANewLineProcessor to ensure tasks end with a newline, enhancing tokenization compatibility. - Updated the make_smolvla_processor function to include the new line processor and tokenizer processor for improved input handling. * feture(policies): add device processor (#1659) * feat(processors): Integrate DeviceProcessor into multiple processor configurations - Added DeviceProcessor to the input and output steps of various processor functions, including make_act_processor, make_diffusion_processor, make_pi0_processor, make_pi0fast_processor, make_sac_processor, make_tdmpc_processor, make_vqbet_processor, and make_smolvla_processor. - Enhanced the DeviceProcessor class with state management methods and ensured compatibility with existing processor pipelines. - Introduced unit tests for DeviceProcessor to validate functionality across different scenarios, including CPU and CUDA operations. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(pipeline): Remove to() method for device management - Eliminated the to() method from RobotProcessor, which was responsible for moving tensor states to specified devices. - Removed associated unit tests that validated the functionality of the to() method across various scenarios. - Streamlined the pipeline code by focusing on other device management strategies. * feat(processor): Enhance DeviceProcessor with float dtype conversion - Added support for optional float dtype conversion in DeviceProcessor, allowing tensors to be converted to specified floating-point types while preserving non-float types. - Implemented validation for float dtype input and updated the processor's configuration methods to include float dtype. - Refactored tensor processing logic to streamline device movement and dtype conversion. - Introduced comprehensive unit tests to validate the new float dtype functionality across various scenarios. * feat(policies): Add new line processors and update module exports * feat(processor): Enhance batch and device processors to handle index and task_index fields - Added logic to ToBatchProcessor for unsqueezing 0D tensors for index and task_index fields, ensuring they are processed as 1D tensors. - Updated DeviceProcessor to process index and task_index fields in complementary data, preserving their tensor types and ensuring non-tensor fields remain unchanged. - Enhanced unit tests to validate the correct handling of index and task_index fields across various scenarios, including device compatibility and dtype preservation. * refactor(processors): Standardize processor naming conventions - Updated processor names across various files to use a consistent "robot_preprocessor" and "robot_postprocessor" format. - Modified the make_processor functions in factory, act, diffusion, pi0, pi0fast, sac, smolvla, tdmpc, and vqbet to reflect the new naming scheme. - Enhanced the pipeline configuration to align with the updated processor names, improving clarity and maintainability. * refactor(factory): Update processor configuration and type hints - Changed return type of get_policy_class to type[PreTrainedPolicy] for improved type safety. - Enhanced make_processor function to utilize dataset_stats in processor creation for better flexibility. - Updated ProcessorConfigKwargs to include dataset_stats, allowing for more comprehensive processor configurations. - Streamlined processor initialization by removing unnecessary kwargs and ensuring clarity in processor type handling. * Fix eval and android gripper * add some tests * refactor(factory, pi0fast): Update processor function names and parameters - Renamed make_pi0_processor to make_pi0fast_processor for clarity and consistency. - Updated parameter names in the factory's make_processor function to use pretrained_model_name_or_path instead of source, enhancing readability and alignment with naming conventions. * fix(train.py) push postprocessor with preprocessor - Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py) * Cleanup pr * fix more git diff pr issues * add path as type in save_pretrained * small nit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename test file * fix: make dataset_features/feature_contract is optional * fix tests * Encorperate pr feedback * clean up record.py * add ascii art, fix normal record * remove merge issues * fix merge * remove features * Add feedback PR * fix last 4 tests * remove features check * rename to transform_features * add transform_features * fix lekiwi eval and update eval api example --------- Signed-off-by: Adil Zouitine Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Michel Aractingi * refactor(TokenizerProcessor): improve dependency handling and observation management - Updated TokenizerProcessor to conditionally import AutoTokenizer based on the availability of the transformers library, enhancing flexibility. - Modified tokenizer attribute type to Any to accommodate scenarios where transformers may not be installed. - Improved observation handling by using a more concise approach to manage the transition dictionary, ensuring compatibility with existing data structures. - Added error handling for missing transformers library, providing clear guidance for users on installation requirements. * feat(dependencies): Add scipy as a required dependency - Included `scipy>=1.15.2` in the project dependencies to enhance functionality and support for scientific computing tasks. * feat(policies): convert save_policy_to_safetensors with pipeline * refactor(normalization): remove Normalize and Unnormalize classes - Deleted the Normalize and Unnormalize classes from the normalization module to streamline the codebase. - Updated tests to ensure compatibility with the removal of these classes, focusing on the new NormalizerProcessor and UnnormalizerProcessor implementations. - Enhanced the handling of normalization statistics and improved overall code clarity. * refactor(factory): streamline processor loading by removing unused comments - Removed commented-out code related to loading pretrained processors in the make_processor function. - This change enhances code clarity and maintains focus on the current implementation. * feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion - Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios. - Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions. - Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios. * feat(tests): Add comprehensive tests for various policy processors - Introduced new test files for ACT, Classifier, Diffusion, PI0, SAC, SmolVLA, TDMPC, and VQBeT policy processors. - Each test file includes unit tests to validate functionality, including handling of batch sizes, device management, and data type conversions. - Enhanced test coverage to ensure robustness and reliability of processor implementations across different scenarios. * refactor(train): Remove unnecessary tensor device handling in training loop * Refactor`gym_manipulator.py` using the universal pipeline (#1650) * Migrate gym_manipulator to use the pipeline Added get_teleop_events function to capture relevant events from teleop devices unrelated to actions * Added the capability to record a dataset * Added the replay functionality with the pipeline * Refactored `actor.py` to use the pipeline * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * RL works at this commit - fixed actor.py and bugs in gym_manipulator * change folder structure to reduce the size of gym_manip * Refactored hilserl config * Remove dataset and mode from HilSerlEnvConfig to a GymManipulatorConfig to reduce verbose of configs during training * format docs * removed get_teleop_events from abc * Refactor environment configuration and processing pipeline for GymHIL support. Removed device attribute from HILSerlRobotEnvConfig, added DummyTeleopDevice for simulation, and updated processor creation to accommodate GymHIL environments. * Improved typing for HILRobotEnv config and GymManipulator config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Migrated `gym_manipulator` to use a more modular structure similar to phone teleop * Refactor gripper handling and transition processing in HIL and robot kinematic processors - Updated gripper position handling to use a consistent key format across processors - Improved the EEReferenceAndDelta class to handle reference joint positions. - Added support for discrete gripper actions in the GripperVelocityToJoint processor. - Refactored the gym manipulator to improve modularity and clarity in processing steps. * Added delta_action_processor mapping wrapper * Added missing file delta_action_processor and improved imports in `gym_manipulator` * nit * Added missing file joint_observation_processor * Enhance processing architecture with new teleoperation processors - Introduced `AddTeleopActionAsComplimentaryData` and `AddTeleopEventsAsInfo` for integrating teleoperator actions and events into transitions. - Added `Torch2NumpyActionProcessor` and `Numpy2TorchActionProcessor` for seamless conversion between PyTorch tensors and NumPy arrays. - Updated `__init__.py` to include new processors in module exports, improving modularity and clarity in the processing pipeline. - GymHIL is now fully supported with HIL using the pipeline * Refactor configuration structure for gym_hil integration - Renamed sections for better readability, such as changing "Gym Wrappers Configuration" to "Processor Configuration." - Enhanced documentation with clear examples for dataset collection and policy evaluation configurations. * Enhance reset configuration and teleoperation event handling - Added `terminate_on_success` parameter to `ResetConfig` and `InterventionActionProcessor` for controlling episode termination behavior upon success detection. - Updated documentation to clarify the impact of `terminate_on_success` on data collection for reward classifier training. - Refactored teleoperation event handling to use `TeleopEvents` constants for improved readability and maintainability across various modules. * fix(keyboard teleop), delta action keys * Added transform features and feature contract * Added transform features for image crop * Enum for TeleopEvents * Update tranform_features delta action proc --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Remove HILEnvConfig references * chore(processor): Add default names for preprocessor and postprocessor in constants - Introduced `PREPROCESSOR_DEFAULT_NAME` and `POSTPROCESSOR_DEFAULT_NAME` constants for consistent naming across various processor implementations. - Updated processor creation in multiple policy files to utilize these constants, enhancing code readability and maintainability. - Modified the training script to load and save the preprocessor and postprocessor using the new constants. * feat(processor): multiple improvements to the pipeline porting (#1749) * [Port codebase pipeline] General fixes for RL and scripts (#1748) * Refactor dataset configuration in documentation and codebase - Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency. - Adjusted replay episode handling by renaming `episode` to `replay_episode`. - Enhanced documentation - added specific processor to transform from policy actions to delta actions * Added Robot action to tensor processor Added new processor script for dealing with gym specific action processing * removed RobotAction2Tensor processor; imrpoved choosing observations in actor * nit in delta action * added missing reset functions to kinematics * Adapt teleoperate and replay to pipeline similar to record * refactor(processors): move to inheritance (#1750) * fix(teleoperator): improvements phone implementation (#1752) * fix(teleoperator): protect shared state in phone implementation * refactor(teleop): separate classes in phone * fix: solve breaking changes (#1753) * refactor(policies): multiple improvements (#1754) * refactor(processor): simpler logic in device processor (#1755) * refactor(processor): euclidean distance in delta action processor (#1757) * refactor(processor): improvements to joint observations processor migration (#1758) * refactor(processor): improvements to tokenizer migration (#1759) * refactor(processor): improvements to tokenizer migration * fix(tests): tokenizer tests regression from #1750 * fix(processors): fix float comparison and config in hil processors (#1760) * chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761) * refactor(processor): improvements normalize pipeline migration (#1756) * refactor(processor): several improvements normalize processor step * refactor(processor): more improvements normalize processor * refactor(processor): more changes to normalizer * refactor(processor): take a different approach to DRY * refactor(processor): final design * chore(record): revert comment and continue deleted (#1764) * refactor(examples): pipeline phone examples (#1769) * refactor(examples): phone teleop + teleop script * refactor(examples): phone replay + replay * chore(examples): rename phone example files & folders * feat(processor): fix improvements to the pipeline porting (#1796) * refactor(processor): enhance tensor device handling in normalization process (#1795) * refactor(tests): remove unsupported device detection test for complementary data (#1797) * chore(tests): update ToBatchProcessor test (#1798) * refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor * test(tests): add tests for action and task processing in batch processor * add names for android and ios phone (#1799) * use _tensor_stats in normalize processor (#1800) * fix(normalize_processor): correct device reference for tensor epsilon handling (#1801) * add point 5 add missing feature contracts (#1806) * Fix PR comments 1452 (#1807) * use key to determine image * Address rest of PR comments * use PolicyFeatures in transform_features --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --------- Co-authored-by: Michel Aractingi Co-authored-by: Adil Zouitine Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * refactor(constants, processor): standardize action and observation keys across multiple files (#1808) - Added new constants for truncated and done states in constants.py. - Updated references to action and observation keys in pipeline_features.py, converters.py, hil_processor.py, tokenizer_processor.py, and robot_kinematic_processor.py to use the new constants for improved readability and maintainability. * refactor(processor): improve processor pipeline typing with generic type (#1810) * refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided. * refactor(converters): implement unified tensor conversion function (#1830) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior. * Revert "refactor(converters): implement unified tensor conversion function (#…" (#1840) This reverts commit a837685bf870919fc07ada287a71711cebabb1ea. * refactor(converters): implement unified tensor conversion function (#1841) - Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior. Co-authored-by: AdilZouitine * refactor(converters): gather converters and refactor the logic (#1833) * refactor(converters): move batch transition functions to converters module - Moved `_default_batch_to_transition` and `_default_transition_to_batch` functions from `pipeline.py` to `converters.py` for better organization and separation of concerns. - Updated references in `RobotProcessor` to use the new location of these functions. - Added tests to ensure correct functionality of the transition functions, including handling of index and task_index fields. - Removed redundant tests from `pipeline.py` to streamline the test suite. * refactor(processor): reorganize EnvTransition and TransitionKey definitions - Moved `EnvTransition` and `TransitionKey` classes from `pipeline.py` to a new `core.py` module for better structure and maintainability. - Updated import statements across relevant modules to reflect the new location of these definitions, ensuring consistent access throughout the codebase. * refactor(converters): rename and update dataset frame conversion functions - Replaced `to_dataset_frame` with `transition_to_dataset_frame` for clarity and consistency in naming. - Updated references in `record.py`, `pipeline.py`, and tests to use the new function name. - Introduced `merge_transitions` to streamline the merging of transitions, enhancing readability and maintainability. - Adjusted related tests to ensure correct functionality with the new naming conventions. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix(processor): solve conflict artefacts * refactor(converters): remove unused identity function and update type hints for merge_transitions * refactor(processor): remove unused identity import and clean up gym_manipulator.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * refactor(processors): add transform_features method to various processors (#1843) * refactor(processors): update transition handling in RewardClassifierProcessor and InverseKinematicsEEToJoints (#1844) * refactor(processors): unify import statements by consolidating pipeline imports into the main processor module (#1845) * refactor(processors): add extended api for specialized pipelines (#1848) * refactor(processors): enhance transform_features method across multiple processors (#1849) * refactor(processors): enhance transform_features method across multiple processors - Updated the transform_features method in various processors to utilize a copy of the features dictionary, ensuring immutability of the original features. - Added handling for new feature keys and removed obsolete ones in the MapTensorToDeltaActionDict, JointVelocityProcessor, and others. - Improved readability and maintainability by following consistent patterns in feature transformation. * refactor(processors): standardize action and observation keys in delta_action_processor and joint_observations_processor - Updated action and observation keys to use constants for improved readability and maintainability. - Refactored the transform_features method in multiple processors to ensure consistent handling of feature keys. - Enhanced error handling by raising exceptions for missing required components in action and observation processing. - Removed obsolete code and improved overall structure for better clarity. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor(processors): remove unused import in joint_observations_processor * refactor(processors): simplify transform_features method in delta_action_processor * refactor(processors): streamline transform_features method in ImageCropResizeProcessor * refactor(processors): improve error handling and streamline transform_features method in phone_processor - Raised a ValueError for missing position and rotation in action to enhance error handling. * refactor(processors): enhance error handling in JointVelocityProcessor - Added a ValueError raise for missing current joint positions in the observation method to improve error handling and ensure the integrity of the transform_features method. * refactor(processors): simplify transform_features method in robot kinematic processors * refactor(processors): standardize action keys in phone_processor * fix(processor): RKP feature obs -> act --------- Signed-off-by: Adil Zouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * chore(processor): rename RobotProcessor -> DataProcessorPipeline (#1850) * chore(processor): rename specialized processor -> XYZProcessorStep (#1852) * chore(processor): rename converters function names (#1853) * chore(processor): rename to_transition_teleop_action -> action_to_transition * chore(processor): rename to_transition_robot_observation -> observation_to_transition * chore(processor): rename to_output_robot_action -> transition_to_robot_action * chore(processor): add Step suffix to all processors (#1854) * refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency * refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules * refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency * refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency * refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency * refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency * refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency * refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency * refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency * refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency * refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency * refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency * refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency * refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency * refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency * refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency * refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency * refactor(processor): update config file name in test for RenameProcessorStep consistency * refactor(processor): rename internal tokenizer variable for clarity (#1855) - Changed the internal tokenizer variable name from `_tokenizer` to `input_tokenizer` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. * chore(processor): rename merge_features -> combine_feature_dicts (#1856) * refactor(processor): rename internal device variable for clarity (#1857) - Changed the internal device variable from `_device` to `tensor_device` for improved readability and consistency. - Updated references throughout the class to reflect the new variable name. * chore(processor): rename teleop_phone variable names (#1858) * chore(processor): add type alias RobotProcessorPipeline and PolicyProcessorPipeline (#1859) * feat(processor): introduce PolicyProcessorPipeline and RobotProcessorPipeline as type aliases for DataProcessorPipeline - Added PolicyProcessorPipeline and RobotProcessorPipeline type aliases to enhance clarity and maintainability in the processor module. - Updated the __all__ list to include the new pipelines for better module export consistency. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline across multiple modules - Updated all instances of DataProcessorPipeline to PolicyProcessorPipeline in various processor files for consistency and clarity. - Adjusted function signatures to reflect the new pipeline type, enhancing maintainability and readability. * refactor(processor): update hotswap_stats function to use PolicyProcessorPipeline - Changed the parameter name from robot_processor to policy_processor for clarity. - Ensured consistency with recent updates to the processor module by reflecting the new pipeline type in the function signature. * refactor(processor): replace DataProcessorPipeline with PolicyProcessorPipeline in migrate_policy_normalization.py - Updated the preprocessor and postprocessor to use PolicyProcessorPipeline for consistency with recent changes in the processor module. - Enhanced clarity and maintainability by aligning with the new pipeline structure. * refactor(processor): update hotswap_stats to use PolicyProcessorPipeline - Changed the parameter type in hotswap_stats from DataProcessorPipeline to PolicyProcessorPipeline for consistency with recent updates. - Enhanced clarity by updating the function documentation to reflect the new pipeline type. * refactor(processor): replace DataProcessorPipeline with RobotProcessorPipeline across multiple files - Updated instances of DataProcessorPipeline to RobotProcessorPipeline in evaluate.py, record.py, replay.py, teleoperate.py, and other relevant files for consistency and clarity. - Adjusted function signatures and variable types to reflect the new pipeline structure, enhancing maintainability and readability. * refactor(processor): enforce config_filename requirement for HF Hub loading (#1860) - Updated the DataProcessorPipeline to require a specific config_filename when loading from Hugging Face Hub, enhancing clarity and preventing errors. - Simplified local path checks and improved error handling for invalid paths. - Adjusted tests to reflect the new requirement and ensure proper error handling for various loading scenarios. * feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861) - Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording. - Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features. * refactor(pipeline): enforce ProcessorStep inheritance for pipeline steps (#1862) - Updated the DataProcessorPipeline to require that all steps inherit from ProcessorStep, enhancing type safety and clarity. - Adjusted tests to utilize a MockTokenizerProcessorStep that adheres to the ProcessorStep interface, ensuring consistent behavior across tests. - Refactored various mock step classes in tests to inherit from ProcessorStep for improved consistency and maintainability. * refactor(dependencies): remove scipy dependency and introduce custom rotation utilities (#1863) - Removed the scipy dependency from the project to streamline requirements. - Added a new `rotation.py` module containing a custom `Rotation` class that replicates essential functionalities of `scipy.spatial.transform.Rotation`, allowing for rotation vector, matrix, and quaternion conversions without external dependencies. - Updated the `robot_kinematic_processor.py` to utilize the new custom rotation utilities. * feat(teleoperation): introduce HasTeleopEvents protocol and enhance teleop event handling (#1866) - Added the HasTeleopEvents protocol to define a standard for teleoperators that provide control events. - Implemented a runtime check to ensure teleoperators implement the get_teleop_events() method. - Updated AddTeleopEventsAsInfoStep to utilize the new protocol, enhancing compatibility with custom teleoperators. - Improved documentation for clarity on teleoperation event extraction and compatibility with built-in teleoperators. * fix(deps): use in-house rotation utils over scipy throughout the codebase * refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868) - Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context. - Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase. * refactor(tests): update processor test assertions to reflect new preprocessor and postprocessor names (#1869) - Changed assertions in multiple processor test files to verify the updated names from "robot_preprocessor" and "robot_postprocessor" to "policy_preprocessor" and "policy_postprocessor" for consistency with recent refactoring. * refactor(utils): simplify log_rerun_data function (#1864) * refactor(logging): enhance log_rerun_data to handle observation and action separately - Updated the `log_rerun_data` function to accept and log observation and action data more clearly, improving readability and maintainability. - Refactored the `record_loop` and `teleop_loop` functions to extract and pass observation and action data to `log_rerun_data`, ensuring consistent logging format. * refactor(tests): update test_log_rerun_data to align with log_rerun_data changes - Modified test cases in `test_visualization_utils.py` to extract and pass observation and action data separately to `log_rerun_data`, improving clarity and consistency with recent function updates. - Ensured that the tests reflect the new structure of `log_rerun_data` for better maintainability. * refactor(processors): simplify calls to log_rerun + replace lambda functions with identity_transition --------- Co-authored-by: Steven Palma * fix(processor): recover type inference for use of processors (#1873) * refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880) * refactor(processors): reorder processor steps for consistency across implementations - Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep. - Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability. * refactor(normalization): remove dtype specification in tensor conversion for adaptation logic - Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types. - Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types. * chore(docs): update doctrines pipeline files (#1872) * docs(processor): update docstrings batch_processor * docs(processor): update docstrings device_processor * docs(processor): update docstrings tokenizer_processor * update docstrings processor_act * update docstrings for pipeline_features * update docstrings for utils * update docstring for processor_diffusion * update docstrings factory * add docstrings to pi0 processor * add docstring to pi0fast processor * add docstring classifier processor * add docstring to sac processor * add docstring smolvla processor * add docstring to tdmpc processor * add docstring to vqbet processor * add docstrings to converters * add docstrings for delta_action_processor * add docstring to gym action processor * update hil processor * add docstring to joint obs processor * add docstring to migrate_normalize_processor * update docstrings normalize processor * update docstring normalize processor * update docstrings observation processor * update docstrings rename_processor * add docstrings robot_kinematic_processor * cleanup rl comments * add docstring to train.py * add docstring to teleoperate.py * add docstrings to phone_processor.py * add docstrings to teleop_phone.py * add docstrings to control_utils.py * add docstrings to visualization_utils.py --------- Co-authored-by: Pepijn * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions (#1900) * refactor(eval): integrate preprocessor and postprocessor into rollout and eval_policy functions - Updated the `rollout` and `eval_policy` functions to accept preprocessor and postprocessor parameters, enhancing the flexibility of the evaluation pipeline. - Adjusted the implementation to apply preprocessing and postprocessing steps during policy evaluation, improving the overall data handling and processing flow. * refactor(eval): remove redundant observation device conversion in rollout function - Eliminated unnecessary device conversion for the observation dictionary within the `rollout` function, streamlining the code and enhancing readability. - This change simplifies the observation handling process, aligning with the preference for clearer solutions. * debug * refactor(utils): enhance task handling in add_envs_task function - Improved the `add_envs_task` function to validate the output of `task_description` and `task` calls, ensuring they return lists of strings. - Removed the use of `else` statement for environments without language instructions, simplifying the logic and enhancing readability. - Streamlined the observation dictionary handling by ensuring consistent data types for task attributes. * refactor(converters): rename _from_tensor to from_tensor_to_numpy for clarity (#1902) - Updated the function name from _from_tensor to from_tensor_to_numpy to better reflect its purpose of converting PyTorch tensors to numpy arrays or scalars. - Adjusted all references to the renamed function throughout the codebase to maintain consistency. - Enhanced the _NormalizationMixin class to reconstruct the stats dictionary from tensor stats using the new function, ensuring compatibility after loading state dicts. - Added tests to verify the correct reconstruction of stats and functionality of methods dependent on self.stats after loading. * refactor(pipeline): feature contract now categorizes between OBS or Action (#1867) * refactor(processor): signature of transform_features * refactor(processor): remove prefixes + processor respect new transform_features signature + update test accordingly * refactor(processor): rename now is only for visual * refactor(processor): update normalize processor * refactor(processor): update vanilla processor features * refactor(processor): feature contract now uses its own enum * chore(processor): rename renameprocessor * chore(processor): minor changes * refactor(processor): add create & change aggregate * refactor(processor): update aggregate * refactor(processor): simplify to functions, fix features contracts and rename function * test(processor): remove to converter tests as now they are very simple * chore(docs): recover docs joint observations processor * fix(processor): update RKP * fix(tests): recv diff test_pipeline * chore(tests): add docs to test * chore(processor): leave obs language constant untouched * fix(processor): correct new shape of feature in crop image processor * refactor(eval): specify type parameters for preprocessor and postprocessor in eval_policy function (#1904) * chore(processor): remove action prefixes (#1905) * test(processor): all processors use now the same create_transition (#1906) * test(processor): all processors use now the same create_transition * test(processor): use identity instead of lambda for transition in pipelines * fix(processor): specialized processors respect contract by raising if none (#1909) * fix(processor): specialized processor now raise * test(processor): fix tests for now raise specialized processors * test(processor): use identity in newly introduced pipeline * refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908) * refactor(processor): split action from policy, robots and environment - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions. - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention. - Enhanced type annotations for action parameters to improve code readability and maintainability. * refactor(converters): rename robot_transition_to_action to transition_to_robot_action - Updated function names across multiple files to improve clarity and consistency in processing robot actions. - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention. - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence. * refactor(converters): update references to transition_to_robot_action - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions. - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability. * refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing. - This modification enhances the clarity of the class's role in the processing pipeline. * fix(processor): main action processor can take also EnvAction --------- Co-authored-by: Steven Palma * refactor(processor): phone processor is now an RobotActionProcessorStep * fix(processor): use subprocessors in AddBatchDimensionProcessorStep only if we have the ingredients * fix(robots): remove action prefix hard-coded in teleop keyboard and gamepad * feat(processor): enhance type safety with generic DataProcessorPipeline for policy and robot pipelines (#1915) * refactor(processor): enhance type annotations for processors in record, replay, teleoperate, and control utils - Updated type annotations for preprocessor and postprocessor parameters in record_loop and predict_action functions to specify the expected dictionary types. - Adjusted robot_action_processor type in ReplayConfig and TeleoperateConfig to improve clarity and maintainability. - Ensured consistency in type definitions across multiple files, enhancing overall code readability. * refactor(processor): enhance type annotations for RobotProcessorPipeline in various files - Updated type annotations for RobotProcessorPipeline instances in evaluate.py, record.py, replay.py, teleoperate.py, and other related files to specify input and output types more clearly. - Introduced new type conversions for PolicyAction and EnvTransition to improve type safety and maintainability across the processing pipelines. - Ensured consistency in type definitions, enhancing overall code readability and reducing potential runtime errors. * refactor(processor): update transition handling in processors to use transition_to_batch - Replaced direct transition handling with transition_to_batch in various processor tests and implementations to ensure consistent batching of input data. - Updated assertions in tests to reflect changes in data structure, enhancing clarity and maintainability. - Improved overall code readability by standardizing the way transitions are processed across different processor types. * refactor(tests): standardize transition key usage in processor tests - Updated assertions in processor test files to utilize the TransitionKey for action references, enhancing consistency across tests. - Replaced direct string references with TransitionKey constants for improved readability and maintainability. - Ensured that all relevant tests reflect these changes, contributing to a more uniform approach in handling transitions. * refactor(processor): unify action imports and enhance type clarity across multiple files - Updated imports in various files to include RobotAction and PolicyAction directly from the processor module, improving clarity and consistency. - Removed redundant imports from core, streamlining the codebase and enhancing maintainability. - Adjusted type annotations and references in the RobotProcessorPipeline and related components to align with the new import structure, ensuring better type safety and readability. * refactor(processor): migrate policy normalization to use factory functions - Updated the migration script to utilize `make_pre_post_processors` and `make_policy_config` from `lerobot.policies.factory`, enhancing consistency with the current codebase. - Improved normalization statistics extraction and processor pipeline creation, ensuring compatibility with the new `PolicyProcessorPipeline` architecture. - Cleaned up configuration handling by removing unnecessary fields and adding normalization mapping directly to the config. - Enhanced type safety and readability by refining feature type and normalization mode handling. * debug(scripts): simplify record with processors (#1918) Co-authored-by: Adil Zouitine * refactor(processor): update migration script for policy normalization and hub integration - Modified the migration script to include a branch argument for pushing to the hub, enhancing flexibility in version control. - Improved error handling by ensuring the policy type is extracted from the configuration, promoting robustness. - Streamlined the process of saving and pushing model components to the hub, allowing for a single commit with optional PR creation. - Updated the commit message and description for better clarity on the migration changes and benefits, ensuring users are informed of the new architecture and usage. * fixes for processors used in phone teleop * fixes for rotation matrix * add empty obs and act in create_initial_features * use observation instead of obs * docs(processor): update docstrings pipeline (#1920) * chore(docs): Processor doc (#1685) * chore(docs): initialize doc * Added script for the second part of the processor doc * precommit style nit * improved part 2 of processor guide * Add comprehensive documentation for processors in robotics - Introduced a detailed guide on processors, covering their role in transforming raw robot data into model-ready inputs and vice versa. - Explained core concepts such as EnvTransition, ProcessorStep, and RobotProcessor, along with their functionalities. - Included examples of common processor steps like normalization, device management, batch processing, and text tokenization. - Provided insights on building complete pipelines, integrating processors into training loops, and saving/loading configurations. - Emphasized best practices and advanced features for effective usage of processors in robotics applications. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat(docs): Enhance introduction to processors with additional converter functions - Updated the introduction to processors documentation to include default batch-to-transition and transition-to-batch converters. - Added detailed descriptions and examples for new specialized converter functions: `to_transition_teleop_action`, `to_transition_robot_observation`, `to_output_robot_action`, and `to_dataset_frame`. - Improved clarity on how these converters facilitate integration with existing robotics applications. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Improved doc implement_your_own_pipeline - Use normalization processor as default example - Add section on transform features - Add section on overrides. * Add phone docs and use pipeline for robots/teleop docs * Fix typo in documentation for adapters in robots/teleop section * Enhance documentation for processors with detailed explanations and examples - Updated the introduction to processors, clarifying the role of `EnvTransition` and `ProcessorStep`. - Introduced `DataProcessorPipeline` as a generic orchestrator for chaining processor steps. - Added comprehensive descriptions of new converter functions and their applications. - Improved clarity on type safety and the differences between `RobotProcessorPipeline` and `PolicyProcessorPipeline`. - Included examples for various processing scenarios, emphasizing best practices for data handling in robotics. * Enhance documentation for processor migration and debugging - Added detailed sections on the migration of models to the new `PolicyProcessorPipeline` system, including breaking changes and migration scripts. - Introduced a comprehensive guide for debugging processor pipelines, covering common issues, step-by-step inspection, and runtime monitoring techniques. - Updated examples to reflect new usage patterns and best practices for processor implementation and error handling. - Clarified the role of various processor steps and their configurations in the context of robotics applications. --------- Co-authored-by: Michel Aractingi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pepijn * docs: Add new section for debugging processor pipelines - Introduced a new documentation entry for debugging processor pipelines, enhancing the existing guide on processors. - This addition aims to provide users with insights and best practices for troubleshooting and optimizing their processor workflows. * fix(processor): phone examples (#1921) * fix(processor): phone examples * chore(processor): simplify gripper in phone example kinematic chain --------- Co-authored-by: Steven Palma * refactor(processors): several additions (#1926) * chore(processor): remove merge_transitions functions (#1925) * refactor(processors): move processors out of configs (#1927) * chore(processor): streamline combine_features_dict (#1928) * chore(policies): use new constants (#1929) * fix(deps): right version transformers (#1930) * fix(tests): add none + disable async tests for now (#1931) * refactor(processor): transform_features loop + EAFP (#1932) * fix(processors): make sure nested dict are also shallow copied (#1939) * refactor(processor): replace ModelHubMixin with HubMixin and enhance save_pretrained method (#1937) - Updated DataProcessorPipeline to use HubMixin instead of ModelHubMixin for improved functionality. - Refactored save_pretrained method to handle saving * refactor(docs): streamline monitoring hooks and enhance performance reporting - Removed the log_shapes and measure_performance hooks, simplifying the monitoring process to focus on NaN checks. - Updated performance reporting to include maximum processing times alongside average times for better insights. - Clarified documentation regarding the processing pipeline and feature transformations. * fix teleop, record and eval (#1940) * fix cmd record, eval * chore(processor): update input output of main 3 processors for better semantics (#1942) * chore(processor): update input output of main 3 processors for better semantics * refactor(processor): replace Any with RobotObservation for improved type safety in processors * fix(processors): no PolicyObservation * chore(processor): update with RobotObservation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: AdilZouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * test(processor): fix batch expectation * feat(example): Add SO100 EE pipeline control (teleop+record) (#1943) * feat(examples): add ee so100 processors teleop & record * refactor(processor): improve FK processor for better use compatability * docs(processor): enhance tutorial on implementing custom processors - Updated the tutorial to use `NormalizerProcessorStep` as the primary example, clarifying its role in normalizing observations and actions. - Improved explanations of the need for custom processors, emphasizing data compatibility and processing requirements. - Added code snippets demonstrating the normalization process and the configuration of processor pipelines. - Enhanced the introduction to processors, detailing their function as translators between raw robot data and model inputs. - Included examples of real-world processor configurations for both training and inference scenarios. * docs(debug): enhance debugging guide for processor pipelines - Streamlined the introduction to clarify the challenges of debugging complex processor pipelines. - Expanded the section on hooks, detailing their purpose and implementation for runtime monitoring. - Introduced step-by-step debugging techniques, emphasizing the use of the `step_through()` method for inspecting intermediate states. - Added examples of feature validation to ensure data structure contracts are met. - Consolidated best practices for debugging, highlighting the synergy between hooks, step-through debugging, and feature validation. * chore(processors): tokenizers raises and remove tensor conversion (#1949) * chore(processor): remove unused transition_features dict * feat(ee): add so100_to_so100_EE replay and evaluate examples * chore(examples): homogenize style across example files (#1955) * chore(examples): homogenize style across example files * chore(examples): homogenize style across example files eval + replay * chore(examples): homogenize headers * test(async): fix feature manipulation (#1957) * test(async): fix feature manipulation * chore(processor): remove unused functions * fix(processor): Preserve stats overrides in normalizer load_state_dict and fix training resumption (#1958) * feat(processor): enhance normalization handling and state management - Added support for additional normalization modes including IDENTITY. - Introduced a new function `clean_state_dict` to remove specific substrings from state dict keys. - Implemented preservation of explicitly provided normalization statistics during state loading. - Updated training script to conditionally provide dataset statistics based on resume state. - Expanded tests to verify the correct behavior of stats override preservation and loading. * fix(train): remove redundant comment regarding state loading - Removed a comment that noted the preprocessor and postprocessor state is already loaded when resuming training, as it was deemed unnecessary for clarity. * test(processor): update tests to handle missing or invalid task keys - Modified tests to assert that the processor raises appropriate exceptions when the task key is missing or has an invalid value in the complementary data. - Ensured that the tests cover cases for None, integer, and mixed list task values, improving robustness against invalid inputs. * fix(processor): enforce signatures * chore(processor): update comments in record.py * test(processor): fix isinstance and cuda test * modify phone docs * fix(processor): reorder output steps to ensure correct processing sequence (#1961) - Moved DeviceProcessorStep to the end of the output steps in multiple processor files to maintain the intended processing order. - Updated corresponding tests to reflect the change in step order. * fix(processors): assumptions for robot_action_processor & teleop_action_processor (#1964) * fix(processors): new assumptions pipeline * fix(processors): ee jj phone teleop replay record working * chore(processors): update comments and default vars * chore(processor): remove unnecessary copy * chore(processor): added todo assumption gripper * fix(processors): eval using detected device * finish phone docs * fix correct image link * feat(processor): implement migration detection and error handling for processor configurations (#1968) * feat(processor): implement migration detection and error handling for processor configurations - Added ProcessorMigrationError to handle migration requirements for old model formats. - Enhanced DataProcessorPipeline.from_pretrained to include robust migration detection logic. - Implemented methods for resolving configuration sources, validating loaded configs, and checking for valid processor configurations. - Introduced comprehensive tests for migration detection and configuration validation to ensure correct behavior. * refactor(processor): simplify loading logic and enhance migration detection - Refactored DataProcessorPipeline to implement a simplified three-way loading strategy for configuration files. - Introduced explicit config_filename parameter to avoid ambiguity during loading. - Updated ProcessorMigrationError to provide clearer error messages for migration requirements. - Enhanced tests to cover new loading logic and ensure proper migration detection. - Removed deprecated methods related to config source resolution. * fix(processor) RL (#1953) * fix(gym_manipulator) general fixes to make it compitable * fix for dataset v3.0 * fix for gym_manipulator * add map policy action to robot action wrappers in a seperate scripts * added unittest for policy to robot bridge * fixes for gripper penalty * fix style * fix gamepad controller * fixes for sim teleop * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * modify numpy2torch to a regular processor as a quick fix * missing imports?! * - Removed the use of `AddRobotObservationAsComplimentaryData` from `gym_manipulator` and thus the codebase - Added get_raw_joint_positions functions to RobotEnv - Pass raw_joint_positions as input to the action_pipeline in `gym_manipulator` - Add `InverseKinematicsRLStep` to be tailored towards the need of RL which requires the use of the IK solution as the main reference point of the control loop - Added the option `use_ik_solution` in `EEReferenceDelta` step to rely on the ik solution rather than the joint values * -Updated links to all the config files to place them in the new repo with configs compatible with the pipeline --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma * fix(tests): update test cases for loading pipelines with specific config filenames - Modified test cases to include explicit configuration filenames when loading pipelines in `test_policy_robot_bridge.py`. - Ensured that the tests reflect the correct loading behavior for both robot-to-policy and policy-to-robot transitions. * fix(examples): train mps processor (#1970) * fix(examples): train mps processor * fix(processor): add MPS compatibility for float64 tensors - Implemented a workaround to convert float64 tensors to float32 when using the MPS device, as MPS does not support float64. - Added unit tests to verify the automatic conversion of float64 tensors to float32 and ensure compatibility with various tensor types on the MPS device. --------- Co-authored-by: AdilZouitine --------- Signed-off-by: Adil Zouitine Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: Michel Aractingi Co-authored-by: Steven Palma Co-authored-by: Pepijn --- .gitignore | 4 + docs/source/_toctree.yml | 19 +- docs/source/backwardcomp.mdx | 56 + docs/source/debug_processor_pipeline.mdx | 299 ++ docs/source/hilserl.mdx | 443 ++- docs/source/hilserl_sim.mdx | 90 +- docs/source/il_robots.mdx | 15 +- docs/source/il_sim.mdx | 60 +- docs/source/implement_your_own_processor.mdx | 273 ++ docs/source/introduction_processors.mdx | 314 ++ docs/source/phone_teleop.mdx | 192 ++ docs/source/processors_robots_teleop.mdx | 151 + examples/3_train_policy.py | 8 +- examples/5_train_with_streaming.py | 14 +- examples/lekiwi/evaluate.py | 63 +- examples/lekiwi/record.py | 51 +- examples/lekiwi/replay.py | 33 +- examples/lekiwi/teleoperate.py | 35 +- examples/phone_to_so100/evaluate.py | 197 ++ examples/phone_to_so100/record.py | 204 ++ examples/phone_to_so100/replay.py | 99 + examples/phone_to_so100/teleoperate.py | 114 + examples/so100_to_so100_EE/evaluate.py | 198 ++ examples/so100_to_so100_EE/record.py | 203 ++ examples/so100_to_so100_EE/replay.py | 100 + examples/so100_to_so100_EE/teleoperate.py | 122 + pyproject.toml | 6 +- src/lerobot/configs/policies.py | 3 +- src/lerobot/configs/types.py | 6 + src/lerobot/constants.py | 9 + src/lerobot/datasets/pipeline_features.py | 141 + src/lerobot/datasets/utils.py | 455 ++- src/lerobot/envs/configs.py | 143 +- src/lerobot/envs/factory.py | 4 +- src/lerobot/envs/utils.py | 24 +- src/lerobot/policies/__init__.py | 11 + src/lerobot/policies/act/modeling_act.py | 16 - src/lerobot/policies/act/processor_act.py | 85 + .../policies/diffusion/modeling_diffusion.py | 16 - .../policies/diffusion/processor_diffusion.py | 92 + src/lerobot/policies/factory.py | 243 +- src/lerobot/policies/normalize.py | 420 --- src/lerobot/policies/pi0/modeling_pi0.py | 146 +- src/lerobot/policies/pi0/processor_pi0.py | 166 ++ .../policies/pi0fast/modeling_pi0fast.py | 15 - .../policies/pi0fast/processor_pi0fast.py | 92 + src/lerobot/policies/sac/modeling_sac.py | 63 +- src/lerobot/policies/sac/processor_sac.py | 92 + .../sac/reward_model/modeling_classifier.py | 15 - .../sac/reward_model/processor_classifier.py | 82 + .../policies/smolvla/modeling_smolvla.py | 171 +- .../policies/smolvla/processor_smolvla.py | 141 + src/lerobot/policies/tdmpc/modeling_tdmpc.py | 24 +- src/lerobot/policies/tdmpc/processor_tdmpc.py | 90 + src/lerobot/policies/vqbet/modeling_vqbet.py | 17 +- src/lerobot/policies/vqbet/processor_vqbet.py | 91 + src/lerobot/processor/__init__.py | 133 +- src/lerobot/processor/batch_processor.py | 254 ++ src/lerobot/processor/converters.py | 412 +++ src/lerobot/processor/core.py | 56 + .../processor/delta_action_processor.py | 145 + src/lerobot/processor/device_processor.py | 188 +- src/lerobot/processor/factory.py | 62 + src/lerobot/processor/gym_action_processor.py | 97 + src/lerobot/processor/hil_processor.py | 596 ++++ .../processor/joint_observations_processor.py | 211 ++ .../processor/migrate_policy_normalization.py | 646 +++++ src/lerobot/processor/normalize_processor.py | 710 +++-- .../processor/observation_processor.py | 149 +- src/lerobot/processor/pipeline.py | 2174 ++++++++------ src/lerobot/processor/policy_robot_bridge.py | 52 + src/lerobot/processor/rename_processor.py | 62 +- src/lerobot/processor/tokenizer_processor.py | 270 ++ src/lerobot/record.py | 171 +- src/lerobot/replay.py | 19 +- src/lerobot/robots/so100_follower/__init__.py | 3 +- .../so100_follower/config_so100_follower.py | 32 - .../robot_kinematic_processor.py | 616 ++++ .../so100_follower_end_effector.py | 200 -- src/lerobot/robots/utils.py | 5 +- src/lerobot/scripts/eval.py | 46 +- src/lerobot/scripts/rl/actor.py | 121 +- src/lerobot/scripts/rl/gym_manipulator.py | 2553 ++++------------- src/lerobot/scripts/rl/learner.py | 26 +- src/lerobot/scripts/train.py | 70 +- src/lerobot/teleoperate.py | 95 +- src/lerobot/teleoperators/__init__.py | 2 +- .../teleoperators/gamepad/gamepad_utils.py | 24 +- .../teleoperators/gamepad/teleop_gamepad.py | 43 + .../teleoperators/keyboard/teleop_keyboard.py | 74 +- src/lerobot/teleoperators/phone/__init__.py | 18 + .../teleoperators/phone/config_phone.py | 36 + .../teleoperators/phone/phone_processor.py | 110 + .../teleoperators/phone/teleop_phone.py | 421 +++ src/lerobot/teleoperators/utils.py | 12 + src/lerobot/utils/control_utils.py | 96 +- src/lerobot/utils/import_utils.py | 1 + src/lerobot/utils/rotation.py | 270 ++ src/lerobot/utils/train_utils.py | 12 +- src/lerobot/utils/visualization_utils.py | 83 +- .../actions.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../actions.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../pusht_diffusion_/actions.safetensors | 2 +- .../pusht_diffusion_/grad_stats.safetensors | 2 +- .../pusht_diffusion_/param_stats.safetensors | 4 +- .../policies/save_policy_to_safetensors.py | 13 +- .../actions.safetensors | 2 +- .../grad_stats.safetensors | 2 +- .../output_dict.safetensors | 2 +- .../param_stats.safetensors | 4 +- .../actions.safetensors | 2 +- .../grad_stats.safetensors | 2 +- .../output_dict.safetensors | 2 +- .../param_stats.safetensors | 4 +- tests/conftest.py | 10 +- tests/datasets/test_dataset_utils.py | 132 + tests/datasets/test_utils.py | 86 - tests/policies/test_policies.py | 110 +- tests/processor/test_act_processor.py | 412 +++ tests/processor/test_batch_conversion.py | 56 +- tests/processor/test_batch_processor.py | 1184 ++++++++ tests/processor/test_classifier_processor.py | 362 +++ tests/processor/test_converters.py | 292 ++ tests/processor/test_device_processor.py | 1161 ++++++++ tests/processor/test_diffusion_processor.py | 398 +++ tests/processor/test_migration_detection.py | 341 +++ tests/processor/test_normalize_processor.py | 1415 ++++++++- tests/processor/test_observation_processor.py | 209 +- tests/processor/test_pi0_processor.py | 424 +++ tests/processor/test_pipeline.py | 753 +++-- .../test_pipeline_from_pretrained_helpers.py | 259 ++ tests/processor/test_policy_robot_bridge.py | 525 ++++ tests/processor/test_rename_processor.py | 200 +- tests/processor/test_sac_processor.py | 414 +++ tests/processor/test_smolvla_processor.py | 459 +++ tests/processor/test_tdmpc_processor.py | 467 +++ tests/processor/test_tokenizer_processor.py | 1029 +++++++ tests/processor/test_vqbet_processor.py | 462 +++ tests/utils/test_visualization_utils.py | 209 ++ 141 files changed, 23478 insertions(+), 5556 deletions(-) create mode 100644 docs/source/debug_processor_pipeline.mdx create mode 100644 docs/source/implement_your_own_processor.mdx create mode 100644 docs/source/introduction_processors.mdx create mode 100644 docs/source/phone_teleop.mdx create mode 100644 docs/source/processors_robots_teleop.mdx create mode 100644 examples/phone_to_so100/evaluate.py create mode 100644 examples/phone_to_so100/record.py create mode 100644 examples/phone_to_so100/replay.py create mode 100644 examples/phone_to_so100/teleoperate.py create mode 100644 examples/so100_to_so100_EE/evaluate.py create mode 100644 examples/so100_to_so100_EE/record.py create mode 100644 examples/so100_to_so100_EE/replay.py create mode 100644 examples/so100_to_so100_EE/teleoperate.py create mode 100644 src/lerobot/datasets/pipeline_features.py create mode 100644 src/lerobot/policies/act/processor_act.py create mode 100644 src/lerobot/policies/diffusion/processor_diffusion.py delete mode 100644 src/lerobot/policies/normalize.py create mode 100644 src/lerobot/policies/pi0/processor_pi0.py create mode 100644 src/lerobot/policies/pi0fast/processor_pi0fast.py create mode 100644 src/lerobot/policies/sac/processor_sac.py create mode 100644 src/lerobot/policies/sac/reward_model/processor_classifier.py create mode 100644 src/lerobot/policies/smolvla/processor_smolvla.py create mode 100644 src/lerobot/policies/tdmpc/processor_tdmpc.py create mode 100644 src/lerobot/policies/vqbet/processor_vqbet.py create mode 100644 src/lerobot/processor/batch_processor.py create mode 100644 src/lerobot/processor/converters.py create mode 100644 src/lerobot/processor/core.py create mode 100644 src/lerobot/processor/delta_action_processor.py create mode 100644 src/lerobot/processor/factory.py create mode 100644 src/lerobot/processor/gym_action_processor.py create mode 100644 src/lerobot/processor/hil_processor.py create mode 100644 src/lerobot/processor/joint_observations_processor.py create mode 100644 src/lerobot/processor/migrate_policy_normalization.py create mode 100644 src/lerobot/processor/policy_robot_bridge.py create mode 100644 src/lerobot/processor/tokenizer_processor.py create mode 100644 src/lerobot/robots/so100_follower/robot_kinematic_processor.py delete mode 100644 src/lerobot/robots/so100_follower/so100_follower_end_effector.py create mode 100644 src/lerobot/teleoperators/phone/__init__.py create mode 100644 src/lerobot/teleoperators/phone/config_phone.py create mode 100644 src/lerobot/teleoperators/phone/phone_processor.py create mode 100644 src/lerobot/teleoperators/phone/teleop_phone.py create mode 100644 src/lerobot/utils/rotation.py create mode 100644 tests/datasets/test_dataset_utils.py delete mode 100644 tests/datasets/test_utils.py create mode 100644 tests/processor/test_act_processor.py create mode 100644 tests/processor/test_batch_processor.py create mode 100644 tests/processor/test_classifier_processor.py create mode 100644 tests/processor/test_converters.py create mode 100644 tests/processor/test_device_processor.py create mode 100644 tests/processor/test_diffusion_processor.py create mode 100644 tests/processor/test_migration_detection.py create mode 100644 tests/processor/test_pi0_processor.py create mode 100644 tests/processor/test_pipeline_from_pretrained_helpers.py create mode 100644 tests/processor/test_policy_robot_bridge.py create mode 100644 tests/processor/test_sac_processor.py create mode 100644 tests/processor/test_smolvla_processor.py create mode 100644 tests/processor/test_tdmpc_processor.py create mode 100644 tests/processor/test_tokenizer_processor.py create mode 100644 tests/processor/test_vqbet_processor.py create mode 100644 tests/utils/test_visualization_utils.py diff --git a/.gitignore b/.gitignore index c4d1f769f..b47e22cbf 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,7 @@ outputs/ # Dev folders .cache/* +*.stl +*.urdf +*.xml +*.part diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 9f5de8230..7d6b69fba 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -30,9 +30,18 @@ - local: smolvla title: Finetune SmolVLA title: "Policies" + +- sections: + - local: introduction_processors + title: Introduction to Robot Processors + - local: debug_processor_pipeline + title: Debug your processor pipeline + - local: implement_your_own_processor + title: Implement your own processor + - local: processors_robots_teleop + title: Processors for Robots and Teleoperators + title: "Robot Processors" - sections: - - local: hope_jr - title: Hope Jr - local: so101 title: SO-101 - local: so100 @@ -41,9 +50,15 @@ title: Koch v1.1 - local: lekiwi title: LeKiwi + - local: hope_jr + title: Hope Jr - local: reachy2 title: Reachy 2 title: "Robots" +- sections: + - local: phone_teleop + title: Phone + title: "Teleoperators" - sections: - local: notebooks title: Notebooks diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx index 0e1d01636..3366c8ab9 100644 --- a/docs/source/backwardcomp.mdx +++ b/docs/source/backwardcomp.mdx @@ -1,5 +1,61 @@ # Backward compatibility +## Policy Normalization Migration (PR #1452) + +**Breaking Change**: LeRobot policies no longer have built-in normalization layers embedded in their weights. Normalization is now handled by external `PolicyProcessorPipeline` components. + +### What changed? + +| | Before PR #1452 | After PR #1452 | +| -------------------------- | ------------------------------------------------ | ------------------------------------------------------------ | +| **Normalization Location** | Embedded in model weights (`normalize_inputs.*`) | External `PolicyProcessorPipeline` components | +| **Model State Dict** | Contains normalization statistics | **Clean weights only** - no normalization parameters | +| **Usage** | `policy(batch)` handles everything | `preprocessor(batch)` → `policy(...)` → `postprocessor(...)` | + +### Impact on existing models + +- Models trained **before** PR #1452 have normalization embedded in their weights +- These models need migration to work with the new `PolicyProcessorPipeline` system +- The migration extracts normalization statistics and creates separate processor pipelines + +### Migrating old models + +Use the migration script to convert models with embedded normalization: + +```shell +python src/lerobot/processor/migrate_policy_normalization.py \ + --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ + --push-to-hub \ + --branch migrated +``` + +The script: + +1. **Extracts** normalization statistics from model weights +2. **Creates** external preprocessor and postprocessor pipelines +3. **Removes** normalization layers from model weights +4. **Saves** clean model + processor pipelines +5. **Pushes** to Hub with automatic PR creation + +### Using migrated models + +```python +# New usage pattern (after migration) +from lerobot.policies.factory import make_policy, make_pre_post_processors + +# Load model and processors separately +policy = make_policy(config, ds_meta=dataset.meta) +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=config, + dataset_stats=dataset.meta.stats +) + +# Process data through pipeline +processed_batch = preprocessor(raw_batch) +action = policy.select_action(processed_batch) +final_action = postprocessor(action) +``` + ## Hardware API redesign PR [#777](https://github.com/huggingface/lerobot/pull/777) improves the LeRobot calibration but is **not backward-compatible**. Below is a overview of what changed and how you can continue to work with datasets created before this pull request. diff --git a/docs/source/debug_processor_pipeline.mdx b/docs/source/debug_processor_pipeline.mdx new file mode 100644 index 000000000..4826c947e --- /dev/null +++ b/docs/source/debug_processor_pipeline.mdx @@ -0,0 +1,299 @@ +# Debug Your Processor Pipeline + +Processor pipelines can be complex, especially when chaining multiple transformation steps. +Unlike simple function calls, pipelines lack natural observability, you can't easily see what happens +between each step or where things go wrong. +This guide provides debugging tools and techniques specifically designed to address these challenges +and help you understand data flow through your pipelines. + +We'll explore three complementary debugging approaches: **hooks** for runtime monitoring, **step-through debugging** for detailed inspection, and **feature validation** for catching structural mismatches. Each serves a different purpose and together they provide complete visibility into your pipeline's behavior. + +## Understanding Hooks + +Hooks are functions that get called at specific points during pipeline execution. +They provide a way to inspect, monitor, or modify data without changing your pipeline code. +Think of them as "event listeners" for your pipeline. + +### What is a Hook? + +A hook is a callback function that gets automatically invoked at specific moments during pipeline execution. +The concept comes from event-driven programming, imagine you could "hook into" the pipeline's execution flow to observe or react to what's happening. + +Think of hooks like inserting checkpoints into your pipeline. Every time the pipeline reaches one of these checkpoints, it pauses briefly to call your hook function, giving you a chance to inspect the current state, log information, and validate data. + +A hook is simply a function that accepts two parameters: + +- `step_idx: int` - The index of the current processing step (0, 1, 2, etc.) +- `transition: EnvTransition` - The data transition at that point in the pipeline + +The beauty of hooks is their non-invasive nature: you can add monitoring, validation, or debugging logic without changing a single line of your pipeline code. The pipeline remains clean and focused on its core logic, while hooks handle the cross-cutting concerns like logging, monitoring, and debugging. + +### Before vs After Hooks + +The pipeline supports two types of hooks: + +- **Before hooks** (`register_before_step_hook`) - Called before each step executes +- **After hooks** (`register_after_step_hook`) - Called after each step completes + +```python +def before_hook(step_idx: int, transition: EnvTransition): + """Called before step processes the transition.""" + print(f"About to execute step {step_idx}") + # Useful for: logging, validation, setup + +def after_hook(step_idx: int, transition: EnvTransition): + """Called after step has processed the transition.""" + print(f"Completed step {step_idx}") + # Useful for: monitoring results, cleanup, debugging + +processor.register_before_step_hook(before_hook) +processor.register_after_step_hook(after_hook) +``` + +### Implementing a NaN Detection Hook + +Here's a practical example of a hook that detects NaN values: + +```python +def check_nans(step_idx: int, transition: EnvTransition): + """Check for NaN values in observations.""" + obs = transition.get(TransitionKey.OBSERVATION) + if obs: + for key, value in obs.items(): + if isinstance(value, torch.Tensor) and torch.isnan(value).any(): + print(f"NaN detected in {key} at step {step_idx}") + +# Register the hook to run after each step +processor.register_after_step_hook(check_nans) + +# Process your data - the hook will be called automatically +output = processor(input_data) + +# Remove the hook when done debugging +processor.unregister_after_step_hook(check_nans) +``` + +### How Hooks Work Internally + +Understanding the internal mechanism helps you use hooks more effectively. The pipeline maintains two separate lists: one for before-step hooks and another for after-step hooks. When you register a hook, it's simply appended to the appropriate list. + +During execution, the pipeline follows a strict sequence: for each processing step, it first calls all before-hooks in registration order, then executes the actual step transformation, and finally calls all after-hooks in registration order. This creates a predictable, sandwich-like structure around each step. + +The key insight is that hooks don't change the core pipeline logic—they're purely additive. The pipeline's `_forward` method orchestrates this dance between hooks and processing steps, ensuring that your debugging or monitoring code runs at exactly the right moments without interfering with the main data flow. + +Here's a simplified view of how the pipeline executes hooks: + +```python +class DataProcessorPipeline: + def __init__(self): + self.steps = [...] + self.before_step_hooks = [] # List of before hooks + self.after_step_hooks = [] # List of after hooks + + def _forward(self, transition): + """Internal method that processes the transition through all steps.""" + for step_idx, processor_step in enumerate(self.steps): + # 1. Call all BEFORE hooks + for hook in self.before_step_hooks: + hook(step_idx, transition) + + # 2. Execute the actual processing step + transition = processor_step(transition) + + # 3. Call all AFTER hooks + for hook in self.after_step_hooks: + hook(step_idx, transition) + + return transition + + def register_before_step_hook(self, hook_fn): + self.before_step_hooks.append(hook_fn) + + def register_after_step_hook(self, hook_fn): + self.after_step_hooks.append(hook_fn) +``` + +### Execution Flow + +The execution flow looks like this: + +``` +Input → Before Hook → Step 0 → After Hook → Before Hook → Step 1 → After Hook → ... → Output +``` + +For example, with 3 steps and both hook types: + +```python +def timing_before(step_idx, transition): + print(f"⏱️ Starting step {step_idx}") + +def validation_after(step_idx, transition): + print(f"✅ Completed step {step_idx}") + +processor.register_before_step_hook(timing_before) +processor.register_after_step_hook(validation_after) + +# This will output: +# ⏱️ Starting step 0 +# ✅ Completed step 0 +# ⏱️ Starting step 1 +# ✅ Completed step 1 +# ⏱️ Starting step 2 +# ✅ Completed step 2 +``` + +### Multiple Hooks + +You can register multiple hooks of the same type - they execute in the order registered: + +```python +def log_shapes(step_idx: int, transition: EnvTransition): + obs = transition.get(TransitionKey.OBSERVATION) + if obs: + print(f"Step {step_idx} observation shapes:") + for key, value in obs.items(): + if isinstance(value, torch.Tensor): + print(f" {key}: {value.shape}") + +processor.register_after_step_hook(check_nans) # Executes first +processor.register_after_step_hook(log_shapes) # Executes second + +# Both hooks will be called after each step in registration order +output = processor(input_data) +``` + +While hooks are excellent for monitoring specific issues (like NaN detection) or gathering metrics during normal pipeline execution, sometimes you need to dive deeper. When you want to understand exactly what happens at each step or debug complex transformation logic, step-through debugging provides the detailed inspection you need. + +## Step-Through Debugging + +Step-through debugging is like having a slow-motion replay for your pipeline. Instead of watching your data get transformed in one quick blur from input to output, you can pause and examine what happens after each individual step. + +This approach is particularly valuable when you're trying to understand a complex pipeline, debug unexpected behavior, or verify that each transformation is working as expected. Unlike hooks, which are great for automated monitoring, step-through debugging gives you manual, interactive control over the inspection process. + +The `step_through()` method is a generator that yields the transition state after each processing step, allowing you to inspect intermediate results. Think of it as creating a series of snapshots of your data as it flows through the pipeline—each snapshot shows you exactly what your data looks like after one more transformation has been applied. + +### How Step-Through Works + +The `step_through()` method fundamentally changes how the pipeline executes. Instead of running all steps in sequence and only returning the final result, it transforms the pipeline into an iterator that yields intermediate results. + +Here's what happens internally: the method starts by converting your input data into the pipeline's internal transition format, then yields this initial state. Next, it applies the first processing step and yields the result. Then it applies the second step to that result and yields again, and so on. Each `yield` gives you a complete snapshot of the transition at that point. + +This generator pattern is powerful because it's lazy—the pipeline only computes the next step when you ask for it. This means you can stop at any point, inspect the current state thoroughly, and decide whether to continue. You're not forced to run the entire pipeline just to debug one problematic step. + +Instead of running the entire pipeline and only seeing the final result, `step_through()` pauses after each step and gives you the intermediate transition: + +```python +# This creates a generator that yields intermediate states +for i, intermediate_result in enumerate(processor.step_through(input_data)): + print(f"=== After step {i} ===") + + # Inspect the observation at this stage + obs = intermediate_result.get(TransitionKey.OBSERVATION) + if obs: + for key, value in obs.items(): + if isinstance(value, torch.Tensor): + print(f"{key}: shape={value.shape}, dtype={value.dtype}") +``` + +### Interactive Debugging with Breakpoints + +You can add breakpoints in the step-through loop to interactively debug: + +```python +# Step through the pipeline with debugging +for i, intermediate in enumerate(processor.step_through(data)): + print(f"Step {i}: {processor.steps[i].__class__.__name__}") + + # Set a breakpoint to inspect the current state + breakpoint() # Debugger will pause here + + # You can now inspect 'intermediate' in the debugger: + # - Check tensor shapes and values + # - Verify expected transformations + # - Look for unexpected changes +``` + +During the debugger session, you can: + +- Examine `intermediate[TransitionKey.OBSERVATION]` to see observation data +- Check `intermediate[TransitionKey.ACTION]` for action transformations +- Inspect any part of the transition to understand what each step does + +Step-through debugging is perfect for understanding the _data_ transformations, but what about the _structure_ of that data? While hooks and step-through help you debug runtime behavior, you also need to ensure your pipeline produces data in the format expected by downstream components. This is where feature contract validation comes in. + +## Validating Feature Contracts + +Feature contracts define what data structure your pipeline expects as input and produces as output. +Validating these contracts helps catch mismatches early. + +### Understanding Feature Contracts + +Each processor step has a `transform_features()` method that describes how it changes the data structure: + +```python +# Get the expected output features from your pipeline +initial_features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)), + "observation.image": PolicyFeature(type=FeatureType.IMAGE, shape=(3, 224, 224)) + }, + PipelineFeatureType.ACTION: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)) + } +} + +# Check what your pipeline will output +output_features = processor.transform_features(initial_features) + +print("Input features:") +for feature_type, features in initial_features.items(): + print(f" {feature_type}:") + for key, feature in features.items(): + print(f" {key}: {feature.type.value}, shape={feature.shape}") + +print("\nOutput features:") +for feature_type, features in output_features.items(): + print(f" {feature_type}:") + for key, feature in features.items(): + print(f" {key}: {feature.type.value}, shape={feature.shape}") +``` + +### Verifying Expected Features + +Check that your pipeline produces the features you expect: + +```python +# Define what features you expect the pipeline to produce +expected_keys = ["observation.state", "observation.image", "action"] + +print("Validating feature contract...") +for expected_key in expected_keys: + found = False + for feature_type, features in output_features.items(): + if expected_key in features: + feature = features[expected_key] + print(f"✅ {expected_key}: {feature.type.value}, shape={feature.shape}") + found = True + break + + if not found: + print(f"❌ Missing expected feature: {expected_key}") +``` + +This validation helps ensure your pipeline will work correctly with downstream components that expect specific data structures. + +## Summary + +Now that you understand the three debugging approaches, you can tackle any pipeline issue systematically: + +1. **Hooks** - For runtime monitoring and validation without modifying pipeline code +2. **Step-through** - For inspecting intermediate states and understanding transformations +3. **Feature validation** - For ensuring data structure contracts are met + +**When to use each approach:** + +- Start with **step-through debugging** when you need to understand what your pipeline does or when something unexpected happens +- Add **hooks** for continuous monitoring during development and production to catch issues automatically +- Use **feature validation** before deployment to ensure your pipeline works with downstream components + +These three tools work together to give you the complete observability that complex pipelines naturally lack. With hooks watching for issues, step-through helping you understand behavior, and feature validation ensuring compatibility, you'll be able to debug any pipeline confidently and efficiently. diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index f8a5c69b2..f6bac1ffa 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -4,7 +4,13 @@ In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. -It combines three key ingredients: 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. +It combines three key ingredients: + +1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. + +2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. + +3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines. @@ -56,30 +62,242 @@ pip install -e ".[hilserl]" ### Understanding Configuration -The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/envs/configs.py`. Which is defined as: +The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: ```python +class GymManipulatorConfig: + env: HILSerlRobotEnvConfig # Environment configuration (nested) + dataset: DatasetConfig # Dataset recording/replay configuration (nested) + mode: str | None = None # "record", "replay", or None (for training) + device: str = "cpu" # Compute device + class HILSerlRobotEnvConfig(EnvConfig): robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/robots`) - teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/teleoperators`) - wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py` - fps: int = 10 # Control frequency + teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm + processor: HILSerlProcessorConfig # Processing pipeline configuration (nested) name: str = "real_robot" # Environment name - mode: str = None # "record", "replay", or None (for training) - repo_id: str | None = None # LeRobot dataset repository ID - dataset_root: str | None = None # Local dataset root (optional) - task: str = "" # Task identifier - num_episodes: int = 10 # Number of episodes for recording - episode: int = 0 # episode index for replay - device: str = "cuda" # Compute device - push_to_hub: bool = True # Whether to push the recorded datasets to Hub - pretrained_policy_name_or_path: str | None = None # For policy loading - reward_classifier_pretrained_path: str | None = None # For reward model - number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier + task: str | None = None # Task identifier + fps: int = 10 # Control frequency + +# Nested processor configuration +class HILSerlProcessorConfig: + control_mode: str = "gamepad" # Control mode + observation: ObservationConfig | None = None # Observation processing settings + image_preprocessing: ImagePreprocessingConfig | None = None # Image crop/resize settings + gripper: GripperConfig | None = None # Gripper control and penalty settings + reset: ResetConfig | None = None # Environment reset and timing settings + inverse_kinematics: InverseKinematicsConfig | None = None # IK processing settings + reward_classifier: RewardClassifierConfig | None = None # Reward classifier settings + max_gripper_pos: float | None = 100.0 # Maximum gripper position + +# Sub-configuration classes +class ObservationConfig: + add_joint_velocity_to_observation: bool = False # Add joint velocities to state + add_current_to_observation: bool = False # Add motor currents to state + add_ee_pose_to_observation: bool = False # Add end-effector pose to state + display_cameras: bool = False # Display camera feeds during execution + +class ImagePreprocessingConfig: + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None # Image cropping parameters + resize_size: tuple[int, int] | None = None # Target image size + +class GripperConfig: + use_gripper: bool = True # Enable gripper control + gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage + gripper_penalty_in_reward: bool = False # Include gripper penalty in reward + +class ResetConfig: + fixed_reset_joint_positions: Any | None = None # Joint positions for reset + reset_time_s: float = 5.0 # Time to wait during reset + control_time_s: float = 20.0 # Maximum episode duration + terminate_on_success: bool = True # Whether to terminate episodes on success detection + +class InverseKinematicsConfig: + urdf_path: str | None = None # Path to robot URDF file + target_frame_name: str | None = None # End-effector frame name + end_effector_bounds: dict[str, list[float]] | None = None # EE workspace bounds + end_effector_step_sizes: dict[str, float] | None = None # EE step sizes per axis + +class RewardClassifierConfig: + pretrained_path: str | None = None # Path to pretrained reward classifier + success_threshold: float = 0.5 # Success detection threshold + success_reward: float = 1.0 # Reward value for successful episodes + +# Dataset configuration +class DatasetConfig: + repo_id: str # LeRobot dataset repository ID + task: str # Task identifier + root: str | None = None # Local dataset root directory + num_episodes_to_record: int = 5 # Number of episodes for recording + replay_episode: int | None = None # Episode index for replay + push_to_hub: bool = False # Whether to push datasets to Hub ``` +### Processor Pipeline Architecture + +HIL-SERL uses a modular processor pipeline architecture that processes robot observations and actions through a series of composable steps. The pipeline is divided into two main components: + +#### Environment Processor Pipeline + +The environment processor (`env_processor`) handles incoming observations and environment state: + +1. **VanillaObservationProcessorStep**: Converts raw robot observations into standardized format +2. **JointVelocityProcessorStep** (optional): Adds joint velocity information to observations +3. **MotorCurrentProcessorStep** (optional): Adds motor current readings to observations +4. **ForwardKinematicsJointsToEE** (optional): Computes end-effector pose from joint positions +5. **ImageCropResizeProcessorStep** (optional): Crops and resizes camera images +6. **TimeLimitProcessorStep** (optional): Enforces episode time limits +7. **GripperPenaltyProcessorStep** (optional): Applies penalties for inappropriate gripper usage +8. **RewardClassifierProcessorStep** (optional): Automated reward detection using vision models +9. **AddBatchDimensionProcessorStep**: Converts data to batch format for neural network processing +10. **DeviceProcessorStep**: Moves data to the specified compute device (CPU/GPU) + +#### Action Processor Pipeline + +The action processor (`action_processor`) handles outgoing actions and human interventions: + +1. **AddTeleopActionAsComplimentaryDataStep**: Captures teleoperator actions for logging +2. **AddTeleopEventsAsInfoStep**: Records intervention events and episode control signals +3. **InterventionActionProcessorStep**: Handles human interventions and episode termination +4. **Inverse Kinematics Pipeline** (when enabled): + - **MapDeltaActionToRobotActionStep**: Converts delta actions to robot action format + - **EEReferenceAndDelta**: Computes end-effector reference and delta movements + - **EEBoundsAndSafety**: Enforces workspace safety bounds + - **InverseKinematicsEEToJoints**: Converts end-effector actions to joint targets + - **GripperVelocityToJoint**: Handles gripper control commands + +#### Configuration Examples + +**Basic Observation Processing**: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true, + "add_current_to_observation": false, + "display_cameras": false + } + } + } +} +``` + +**Image Processing**: + +```json +{ + "env": { + "processor": { + "image_preprocessing": { + "crop_params_dict": { + "observation.images.front": [180, 250, 120, 150], + "observation.images.side": [180, 207, 180, 200] + }, + "resize_size": [128, 128] + } + } + } +} +``` + +**Inverse Kinematics Setup**: + +```json +{ + "env": { + "processor": { + "inverse_kinematics": { + "urdf_path": "path/to/robot.urdf", + "target_frame_name": "end_effector", + "end_effector_bounds": { + "min": [0.16, -0.08, 0.03], + "max": [0.24, 0.2, 0.1] + }, + "end_effector_step_sizes": { + "x": 0.02, + "y": 0.02, + "z": 0.02 + } + } + } + } +} +``` + +### Advanced Observation Processing + +The HIL-SERL framework supports additional observation processing features that can improve policy learning: + +#### Joint Velocity Processing + +Enable joint velocity estimation to provide the policy with motion information: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true + } + } + } +} +``` + +This processor: + +- Estimates joint velocities using finite differences between consecutive joint position readings +- Adds velocity information to the observation state vector +- Useful for policies that need motion awareness for dynamic tasks + +#### Motor Current Processing + +Monitor motor currents to detect contact forces and load conditions: + +```json +{ + "env": { + "processor": { + "observation": { + "add_current_to_observation": true + } + } + } +} +``` + +This processor: + +- Reads motor current values from the robot's control system +- Adds current measurements to the observation state vector +- Helps detect contact events, object weights, and mechanical resistance +- Useful for contact-rich manipulation tasks + +#### Combined Observation Processing + +You can enable multiple observation processing features simultaneously: + +```json +{ + "env": { + "processor": { + "observation": { + "add_joint_velocity_to_observation": true, + "add_current_to_observation": true, + "add_ee_pose_to_observation": false, + "display_cameras": false + } + } + } +} +``` + +**Note**: Enabling additional observation features increases the state space dimensionality, which may require adjusting your policy network architecture and potentially collecting more training data. + ### Finding Robot Workspace Bounds Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot. @@ -128,24 +346,58 @@ With the bounds defined, you can safely collect demonstrations for training. Tra **Setting Up Record Mode** -Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)): +Create a configuration file for recording demonstrations (or edit an existing one like [env_config.json](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/env_config.json)): -1. Set `mode` to `"record"` -2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name") -3. Set `num_episodes` to the number of demonstrations you want to collect -4. Set `crop_params_dict` to `null` initially (we'll determine crops later) -5. Configure `robot`, `cameras`, and other hardware settings +1. Set `mode` to `"record"` at the root level +2. Specify a unique `repo_id` for your dataset in the `dataset` section (e.g., "username/task_name") +3. Set `num_episodes_to_record` in the `dataset` section to the number of demonstrations you want to collect +4. Set `env.processor.image_preprocessing.crop_params_dict` to `{}` initially (we'll determine crops later) +5. Configure `env.robot`, `env.teleop`, and other hardware settings in the `env` section Example configuration section: ```json -"mode": "record", -"repo_id": "username/pick_lift_cube", -"dataset_root": null, -"task": "pick_and_lift", -"num_episodes": 15, -"episode": 0, -"push_to_hub": true +{ + "env": { + "type": "gym_manipulator", + "name": "real_robot", + "fps": 10, + "processor": { + "control_mode": "gamepad", + "observation": { + "display_cameras": false + }, + "image_preprocessing": { + "crop_params_dict": {}, + "resize_size": [128, 128] + }, + "gripper": { + "use_gripper": true, + "gripper_penalty": 0.0 + }, + "reset": { + "reset_time_s": 5.0, + "control_time_s": 20.0 + } + }, + "robot": { + // ... robot configuration ... + }, + "teleop": { + // ... teleoperator configuration ... + } + }, + "dataset": { + "repo_id": "username/pick_lift_cube", + "root": null, + "task": "pick_and_lift", + "num_episodes_to_record": 15, + "replay_episode": 0, + "push_to_hub": true + }, + "mode": "record", + "device": "cpu" +} ``` ### Using a Teleoperation Device @@ -191,10 +443,20 @@ The gamepad provides a very convenient way to control the robot and the episode To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file. ```json +{ + "env": { "teleop": { - "type": "gamepad", - "use_gripper": true + "type": "gamepad", + "use_gripper": true }, + "processor": { + "control_mode": "gamepad", + "gripper": { + "use_gripper": true + } + } + } +} ```

@@ -216,11 +478,21 @@ The SO101 leader arm has reduced gears that allows it to move and track the foll To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file. ```json +{ + "env": { "teleop": { - "type": "so101_leader", - "port": "/dev/tty.usbmodem585A0077921", # check your port number - "use_degrees": true + "type": "so101_leader", + "port": "/dev/tty.usbmodem585A0077921", + "use_degrees": true }, + "processor": { + "control_mode": "leader", + "gripper": { + "use_gripper": true + } + } + } +} ``` In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure. @@ -251,7 +523,7 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/e During recording: -1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` +1. The robot will reset to the initial position defined in the configuration file `env.processor.reset.fixed_reset_joint_positions` 2. Complete the task successfully 3. The episode ends with a reward of 1 when you press the "success" button 4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0 @@ -310,11 +582,19 @@ observation.images.front: [180, 250, 120, 150] Add these crop parameters to your training configuration: ```json -"crop_params_dict": { - "observation.images.side": [180, 207, 180, 200], - "observation.images.front": [180, 250, 120, 150] -}, -"resize_size": [128, 128] +{ + "env": { + "processor": { + "image_preprocessing": { + "crop_params_dict": { + "observation.images.side": [180, 207, 180, 200], + "observation.images.front": [180, 250, 120, 150] + }, + "resize_size": [128, 128] + } + } + } +} ``` **Recommended image resolution** @@ -343,26 +623,52 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/r **Key Parameters for Data Collection** -- **mode**: set it to `"record"` to collect a dataset -- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub -- **num_episodes**: Number of episodes to record -- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected -- **fps**: Number of frames per second to record -- **push_to_hub**: Whether to push the dataset to the hub +- **mode**: set it to `"record"` to collect a dataset (at root level) +- **dataset.repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub +- **dataset.num_episodes_to_record**: Number of episodes to record +- **env.processor.reset.terminate_on_success**: Whether to automatically terminate episodes when success is detected (default: `true`) +- **env.fps**: Number of frames per second to record +- **dataset.push_to_hub**: Whether to push the dataset to the hub -The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier. +The `env.processor.reset.terminate_on_success` parameter allows you to control episode termination behavior. When set to `false`, episodes will continue even after success is detected, allowing you to collect more positive examples with the reward=1 label. This is crucial for training reward classifiers as it provides more success state examples in your dataset. When set to `true` (default), episodes terminate immediately upon success detection. + +**Important**: For reward classifier training, set `terminate_on_success: false` to collect sufficient positive examples. For regular HIL-SERL training, keep it as `true` to enable automatic episode termination when the task is completed successfully. Example configuration section for data collection: ```json { + "env": { + "type": "gym_manipulator", + "name": "real_robot", + "fps": 10, + "processor": { + "reset": { + "reset_time_s": 5.0, + "control_time_s": 20.0, + "terminate_on_success": false + }, + "gripper": { + "use_gripper": true + } + }, + "robot": { + // ... robot configuration ... + }, + "teleop": { + // ... teleoperator configuration ... + } + }, + "dataset": { + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "task": "reward_classifier_task", + "num_episodes_to_record": 20, + "replay_episode": null, + "push_to_hub": true + }, "mode": "record", - "repo_id": "hf_username/dataset_name", - "dataset_root": "data/your_dataset", - "num_episodes": 20, - "push_to_hub": true, - "fps": 10, - "number_of_steps_after_success": 15 + "device": "cpu" } ``` @@ -421,9 +727,17 @@ To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to ```python -env_config = HILSerlRobotEnvConfig( - reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", - # Other environment parameters +config = GymManipulatorConfig( + env=HILSerlRobotEnvConfig( + processor=HILSerlProcessorConfig( + reward_classifier=RewardClassifierConfig( + pretrained_path="path_to_your_pretrained_trained_model" + ) + ), + # Other environment parameters + ), + dataset=DatasetConfig(...), + mode=None # For training ) ``` @@ -432,7 +746,18 @@ or set the argument in the json config file. ```json { - "reward_classifier_pretrained_path": "path_to_your_pretrained_model" + "env": { + "processor": { + "reward_classifier": { + "pretrained_path": "path_to_your_pretrained_model", + "success_threshold": 0.7, + "success_reward": 1.0 + }, + "reset": { + "terminate_on_success": true + } + } + } } ``` @@ -447,7 +772,7 @@ The reward classifier will automatically provide rewards based on the visual inp **Example Workflow for training the reward classifier** 1. **Create the configuration files**: - Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). + Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/reward_classifier/config.json). 2. **Collect a dataset**: @@ -472,7 +797,7 @@ The LeRobot system uses a distributed actor-learner architecture for training. T **Configuration Setup** -Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. +Create a training configuration file (example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/train_config.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. 1. Configure the policy settings (`type="sac"`, `device`, etc.) 2. Set `dataset` to your cropped dataset diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx index c739be835..77191fde3 100644 --- a/docs/source/hilserl_sim.mdx +++ b/docs/source/hilserl_sim.mdx @@ -26,15 +26,18 @@ pip install -e ".[hilserl]" ## Configuration -To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: +To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/env_config.json). Key configuration sections include: ### Environment Type and Task ```json { - "type": "hil", - "name": "franka_sim", - "task": "PandaPickCubeGamepad-v0", + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, "device": "cuda" } ``` @@ -45,28 +48,40 @@ Available tasks: - `PandaPickCubeGamepad-v0`: With gamepad control - `PandaPickCubeKeyboard-v0`: With keyboard control -### Gym Wrappers Configuration +### Processor Configuration ```json -"wrapper": { - "gripper_penalty": -0.02, - "control_time_s": 15.0, - "use_gripper": true, - "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785], - "end_effector_step_sizes": { - "x": 0.025, - "y": 0.025, - "z": 0.025 - }, - "control_mode": "gamepad" +{ + "env": { + "processor": { + "control_mode": "gamepad", + "gripper": { + "use_gripper": true, + "gripper_penalty": -0.02 + }, + "reset": { + "control_time_s": 15.0, + "fixed_reset_joint_positions": [ + 0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785 + ] + }, + "inverse_kinematics": { + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + } + } } + } +} ``` Important parameters: -- `gripper_penalty`: Penalty for excessive gripper movement -- `use_gripper`: Whether to enable gripper control -- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector +- `gripper.gripper_penalty`: Penalty for excessive gripper movement +- `gripper.use_gripper`: Whether to enable gripper control +- `inverse_kinematics.end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector - `control_mode`: Set to `"gamepad"` to use a gamepad controller ## Running with HIL RL of LeRobot @@ -75,39 +90,50 @@ Important parameters: To run the environment, set mode to null: - -```python +```bash python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` - ### Recording a Dataset To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: - -```python +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0" + }, + "dataset": { + "repo_id": "username/sim_dataset", + "root": null, + "task": "pick_cube", + "num_episodes_to_record": 10, + "replay_episode": null, + "push_to_hub": true + }, + "mode": "record" +} +``` + +```bash python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` - ### Training a Policy -To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: +To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/train_config.json) and run the actor and learner servers: - -```python +```bash python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json ``` - In a different terminal, run the learner server: - -```python +```bash python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json ``` - The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 905046bef..19b62167e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -519,11 +519,14 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun from lerobot.record import record_loop +from lerobot.policies.factory import make_processor NUM_EPISODES = 5 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" # Create the robot configuration camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} @@ -535,7 +538,7 @@ robot_config = SO100FollowerConfig( robot = SO100Follower(robot_config) # Initialize the policy -policy = ACTPolicy.from_pretrained("/") +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -544,7 +547,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/eval_", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -559,6 +562,12 @@ _init_rerun(session_name="recording") # Connect the robot robot.connect() +preprocessor, postprocessor = make_processor( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, +) + for episode_idx in range(NUM_EPISODES): log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") @@ -568,6 +577,8 @@ for episode_idx in range(NUM_EPISODES): events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 3dd80dc4b..6a615620b 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -22,13 +22,38 @@ pip install -e ".[hilserl]" ## Teleoperate and Record a Dataset -To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json). +To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json). -To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record". +To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection: -If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS). +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, + "dataset": { + "repo_id": "your_username/il_gym", + "root": null, + "task": "pick_cube", + "num_episodes_to_record": 30, + "replay_episode": null, + "push_to_hub": true + }, + "mode": "record", + "device": "cuda" +} +``` -By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`. +Key configuration points: + +- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"` +- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes +- Ensure `mode` is set to `"record"` +- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"` +- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"` Then we can run this command to start: @@ -140,9 +165,32 @@ huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \ ## Evaluate your policy in Sim -To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json). +To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json). -Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model` +Here's an example evaluation configuration: + +```json +{ + "env": { + "type": "gym_manipulator", + "name": "gym_hil", + "task": "PandaPickCubeGamepad-v0", + "fps": 10 + }, + "dataset": { + "repo_id": "your_username/il_sim_dataset", + "dataset_root": null, + "task": "pick_cube" + }, + "pretrained_policy_name_or_path": "your_username/il_sim_model", + "device": "cuda" +} +``` + +Make sure to replace: + +- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`) +- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`) Then you can run this command to visualize your trained policy diff --git a/docs/source/implement_your_own_processor.mdx b/docs/source/implement_your_own_processor.mdx new file mode 100644 index 000000000..5b7d4f78a --- /dev/null +++ b/docs/source/implement_your_own_processor.mdx @@ -0,0 +1,273 @@ +# Implement your own Robot Processor + +In this tutorial, you'll learn how to implement your own Robot Processor. +It begins by exploring the need for a custom processor, then uses the `NormalizerProcessorStep` as the running example to explain how to implement, configure, and serialize a processor. Finally, it lists all helper processors that ship with LeRobot. + +## Why would you need a custom processor? + +In most cases, when reading raw data from sensors or when models output actions, you need to process this data to make it compatible with your target system. For example, a common need is normalizing data ranges to make them suitable for neural networks. + +LeRobot's `NormalizerProcessorStep` handles this crucial task: + +```python +# Input: raw joint positions in [0, 180] degrees +raw_action = torch.tensor([90.0, 45.0, 135.0]) + +# After processing: normalized to [-1, 1] range for model training +normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=dataset_stats) +normalized_result = normalizer(transition) +# ... +``` + +Other common processing needs include: + +- **Device placement**: Moving tensors between CPU/GPU and converting data types +- **Format conversion**: Transforming between different data structures +- **Batching**: Adding/removing batch dimensions for model compatibility +- **Safety constraints**: Applying limits to robot commands + +```python +# Example pipeline combining multiple processors +pipeline = PolicyProcessorPipeline([ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + NormalizerProcessorStep(features=features, stats=stats), + DeviceProcessorStep(device="cuda"), + # ... +]) +``` + +LeRobot provides a pipeline mechanism to implement sequences of processing steps for both input data and output actions, making it easy to compose these transformations in the right order for optimal performance. + +## How to implement your own processor? + +We'll use the `NormalizerProcessorStep` as our main example because it demonstrates essential processor patterns including state management, configuration serialization, and tensor handling that you'll commonly need. + +Prepare the sequence of processing steps necessary for your problem. A processor step is a class that implements the following methods: + +- `__call__`: implements the processing step for the input transition. +- `get_config`: gets the configuration of the processor step. +- `state_dict`: gets the state of the processor step. +- `load_state_dict`: loads the state of the processor step. +- `reset`: resets the state of the processor step. +- `feature_contract`: displays the modification to the feature space during the processor step. + +### Implement the `__call__` method + +The `__call__` method is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. Here's how the `NormalizerProcessorStep` works: + +```python +@dataclass +@ProcessorStepRegistry.register("normalizer_processor") +class NormalizerProcessorStep(ProcessorStep): + """Normalize observations/actions using dataset statistics.""" + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + eps: float = 1e-8 + _tensor_stats: dict = field(default_factory=dict, init=False, repr=False) + + def __post_init__(self): + """Convert stats to tensors for efficient computation.""" + self.stats = self.stats or {} + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=torch.float32) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = transition.copy() + # Normalize observations + # ... + # Normalize action + # ... + return new_transition + +``` + +See the full implementation in `src/lerobot/processor/normalize_processor.py` for complete details. + +**Key principles:** + +- **Always use `transition.copy()`** to avoid side effects +- **Handle both observations and actions** consistently +- **Separate config from state**: `get_config()` returns JSON-serializable params, `state_dict()` returns tensors +- **Convert stats to tensors** in `__post_init__()` for efficient computation + +### Configuration and State Management + +Processors support serialization through three methods that separate configuration from tensor state. The `NormalizerProcessorStep` demonstrates this perfectly - it carries dataset statistics (tensors) in its state, and hyperparameters in its config: + +```python +# Continuing the NormalizerProcessorStep example... + +def get_config(self) -> dict[str, Any]: + """JSON-serializable configuration (no tensors).""" + return { + "eps": self.eps, + "features": {k: {"type": v.type.value, "shape": v.shape} for k, v in self.features.items()}, + "norm_map": {ft.value: nm.value for ft, nm in self.norm_map.items()}, + # ... + } + +def state_dict(self) -> dict[str, torch.Tensor]: + """Tensor state only (e.g., dataset statistics).""" + flat: dict[str, torch.Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat + +def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Restore tensor state at runtime.""" + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to processor's configured device + self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device + ) + # ... +``` + +**Usage:** + +```python +# Save (e.g., inside a policy) +config = normalizer.get_config() +tensors = normalizer.state_dict() + +# Restore (e.g., loading a pretrained policy) +new_normalizer = NormalizerProcessorStep(**config) +new_normalizer.load_state_dict(tensors) +# Now new_normalizer has the same stats and configuration +``` + +### Transform features + +The `transform_features` method defines how your processor transforms feature names and shapes. This is crucial for policy configuration and debugging. + +For `NormalizerProcessorStep`, features are typically preserved unchanged since normalization doesn't alter keys or shapes: + +```python +def transform_features(self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Normalization preserves all feature definitions.""" + return features # No changes to feature structure + # ... +``` + +When your processor renames or reshapes data, implement this method to reflect the mapping for downstream components. For example, a simple rename processor: + +```python +def transform_features(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # Simple renaming + if "pixels" in features: + features["observation.image"] = features.pop("pixels") + + # Pattern-based renaming + for key in list(features.keys()): + if key.startswith("env_state."): + suffix = key[len("env_state."):] + features[f"observation.{suffix}"] = features.pop(key) + # ... + + return features +``` + +**Key principles:** + +- Use `features.pop(old_key)` to remove and get the old feature +- Use `features[new_key] = old_feature` to add the renamed feature +- Always return the modified features dictionary +- Document transformations clearly in the docstring + +### Using overrides + +You can override step parameters at load-time using `overrides`. This is handy for non-serializable objects or site-specific settings. It works both in policy factories and with `DataProcessorPipeline.from_pretrained(...)`. + +**Foundational model adaptation**: This is particularly useful when working with foundational pretrained policies where you rarely have access to the original training statistics. You can inject your own dataset statistics to adapt the normalizer to your specific robot or environment data. + +Example: during policy evaluation on the robot, override the device and rename map. +Use this to run a policy trained on CUDA on a CPU-only robot, or to remap camera keys when the robot uses different names than the dataset. + +Direct usage with `from_pretrained`: + +```python +from lerobot.processor import RobotProcessorPipeline + +# Load a foundational policy trained on diverse robot data +# but adapt normalization to your specific robot/environment +new_stats = LeRobotDataset(repo_id="username/my-dataset").meta.stats +processor = RobotProcessorPipeline.from_pretrained( + "huggingface/foundational-robot-policy", # Pretrained foundation model + overrides={ + "normalizer_processor": {"stats": new_stats}, # Inject your robot's statistics + "device_processor": {"device": "cuda:0"}, # registry name for registered steps + "rename_processor": {"rename_map": robot_key_map}, # Map your robot's observation keys + # ... + }, +) +``` + +## Best Practices + +Based on analysis of all LeRobot processor implementations, here are the key patterns and practices: + +### 1. **Safe Data Handling** + +Always create copies of input data to avoid unintended side effects. Use `transition.copy()` and `observation.copy()` rather than modifying data in-place. This prevents your processor from accidentally affecting other components in the pipeline. + +Check for required data before processing and handle missing data gracefully. If your processor expects certain keys (like `"pixels"` for image processing), validate their presence first. For optional data, use safe access patterns like `transition.get()` and handle `None` values appropriately. + +When data validation fails, provide clear, actionable error messages that help users understand what went wrong and how to fix it. + +### 2. **Choose Appropriate Base Classes** + +LeRobot provides specialized base classes that reduce boilerplate code and ensure consistency. Use `ObservationProcessorStep` when you only need to modify observations, `ActionProcessorStep` for action-only processing, and `RobotActionProcessorStep` specifically for dictionary-based robot actions. + +Only inherit directly from `ProcessorStep` when you need full control over the entire transition or when processing multiple transition components simultaneously. The specialized base classes handle the transition management for you and provide type safety. + +### 3. **Registration and Naming** + +Register your processors with descriptive, namespaced names using `@ProcessorStepRegistry.register()`. Use organization prefixes like `"robotics_lab/safety_clipper"` or `"acme_corp/vision_enhancer"` to avoid naming conflicts. Avoid generic names like `"processor"` or `"step"` that could clash with other implementations. + +Good registration makes your processors discoverable and enables clean serialization/deserialization when saving and loading pipelines. + +### 4. **State Management Patterns** + +Distinguish between configuration parameters (JSON-serializable values) and internal state (tensors, buffers). Use dataclass fields with `init=False, repr=False` for internal state that shouldn't appear in the constructor or string representation. + +Implement the `reset()` method to clear internal state between episodes. This is crucial for stateful processors that accumulate data over time, like moving averages or temporal filters. + +Remember that `get_config()` should only return JSON-serializable configuration, while `state_dict()` handles tensor state separately. + +### 5. **Input Validation and Error Handling** + +Validate input types and shapes before processing. Check tensor properties like `dtype` and dimensions to ensure compatibility with your algorithms. For robot actions, verify that required pose components or joint values are present and within expected ranges. + +Use early returns for edge cases where no processing is needed. Provide clear, descriptive error messages that include the expected vs. actual data types or shapes. This makes debugging much easier for users. + +### 6. **Device and Dtype Awareness** + +Design your processors to automatically adapt to the device and dtype of input tensors. Internal tensors (like normalization statistics) should match the input tensor's device and dtype to ensure compatibility with multi-GPU training, mixed precision, and distributed setups. + +Implement a `to()` method that moves your processor's internal state to the specified device. Check device/dtype compatibility at runtime and automatically migrate internal state when needed. This pattern enables seamless operation across different hardware configurations without manual intervention. + +## Conclusion + +You now have all the tools to implement custom processors in LeRobot! The key steps are: + +1. **Define your processor** as a dataclass with the required methods (`__call__`, `get_config`, `state_dict`, `load_state_dict`, `reset`, `transform_features`) +2. **Register it** using `@ProcessorStepRegistry.register("name")` for discoverability +3. **Integrate it** into a `DataProcessorPipeline` with other processing steps +4. **Use base classes** like `ObservationProcessorStep` when possible to reduce boilerplate +5. **Implement device/dtype awareness** to support multi-GPU and mixed precision setups + +The processor system is designed to be modular and composable, allowing you to build complex data processing pipelines from simple, focused components. Whether you're preprocessing sensor data for training or post-processing model outputs for robot execution, custom processors give you the flexibility to handle any data transformation your robotics application requires. + +Key principles for robust processors: + +- **Device/dtype adaptation**: Internal tensors should match input tensors +- **Clear error messages**: Help users understand what went wrong +- **Base class usage**: Leverage specialized base classes to reduce boilerplate +- **Feature contracts**: Declare data structure changes with `transform_features()` + +Start simple, test thoroughly, and ensure your processors work seamlessly across different hardware configurations! diff --git a/docs/source/introduction_processors.mdx b/docs/source/introduction_processors.mdx new file mode 100644 index 000000000..308edbb3b --- /dev/null +++ b/docs/source/introduction_processors.mdx @@ -0,0 +1,314 @@ +# Introduction to Processors + +In robotics, there's a fundamental mismatch between the data that robots and humans produce and what machine learning models expect. +Robots output raw sensor data like camera images and joint positions that need normalization, batching, and device placement before models can process them. +Language instructions from humans must be tokenized into numerical representations, and different robots use different coordinate systems that need standardization. + +The challenge extends to model outputs as well. +Models might output end-effector positions while robots need joint-space commands, or teleoperators produce relative movements while robots expect absolute commands. +Model predictions are often normalized and need conversion back to real-world scales. + +Cross-domain translation adds another layer of complexity. +Training data from one robot setup needs adaptation for deployment on different hardware, models trained with specific camera configurations must work with new arrangements, and datasets with different naming conventions need harmonization. + +**That's where processors come in.** They serve as universal translators that bridge these gaps, ensuring seamless data flow from sensors to models to actuators. +Processors handle all the preprocessing and postprocessing steps needed to convert raw environment data into model-ready inputs and vice versa. + +This means that your favorite policy can be used like this: + +```python +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.policies.factory import make_pre_post_processors +from lerobot.policies.your_policy import YourPolicy +from lerobot.processor.pipeline import RobotProcessorPipeline, PolicyProcessorPipeline +dataset = LeRobotDataset("hf_user/dataset", episodes=[0]) +sample = dataset[10] + +model = YourPolicy.from_pretrained( + "hf_user/model", +) +model.eval() +model.to("cuda") +preprocessor, postprocessor = make_pre_post_processors(model.config, pretrained_path="hf_user/model", dataset_stats=dataset.meta.stats) + +preprocessed_sample = preprocessor(sample) +action = model.select_action(preprocessed_sample) +postprocessed_action = postprocessor(action) +``` + +## What are Processors? + +In robotics, data comes in many forms: images from cameras, joint positions from sensors, text instructions from users, and more. Each type of data requires specific transformations before a model can use it effectively. Models need this data to be: + +- **Normalized**: Scaled to appropriate ranges for neural network processing +- **Batched**: Organized with proper dimensions for batch processing +- **Tokenized**: Text converted to numerical representations +- **Device-placed**: Moved to the right hardware (CPU/GPU) +- **Type-converted**: Cast to appropriate data types + +Processors handle these transformations through composable, reusable steps that can be chained together into pipelines. Think of them as a modular assembly line where each station performs a specific transformation on your data. + +## Core Concepts + +### EnvTransition: The Universal Data Container + +The `EnvTransition` is the fundamental data structure that flows through all processors. +It's a typed dictionary that represents a complete robot-environment interaction: + +- **OBSERVATION**: All sensor data (images, states, proprioception) +- **ACTION**: The action to execute or that was executed +- **REWARD**: Reinforcement learning signal +- **DONE/TRUNCATED**: Episode boundary indicators +- **INFO**: Arbitrary metadata +- **COMPLEMENTARY_DATA**: Task descriptions, indices, padding flags, inter-step data + +### ProcessorStep: The Building Block + +A `ProcessorStep` is a single transformation unit that processes transitions. It's an abstract base class with two required methods: + +```python +from lerobot.processor import ProcessorStep, EnvTransition + +class MyProcessorStep(ProcessorStep): + """Example processor step - inherit and implement abstract methods.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Transform the transition - REQUIRED abstract method.""" + # Your processing logic here + return transition + + def transform_features(self, features): + """Declare how this step transforms feature shapes/types - REQUIRED abstract method.""" + return features # Most processors return features unchanged +``` + +`__call__` is the core of your processor step. It takes an `EnvTransition` and returns a modified `EnvTransition`. + +`transform_features` is used to declare how this step transforms feature shapes/types. + +### DataProcessorPipeline: The Generic Orchestrator + +The `DataProcessorPipeline[TInput, TOutput]` chains multiple `ProcessorStep` instances together: + +```python +from lerobot.processor import RobotProcessorPipeline, PolicyProcessorPipeline + +# For robot hardware (unbatched data) +robot_processor = RobotProcessorPipeline[RobotAction, RobotAction]( + steps=[step1, step2, step3], + name="robot_pipeline" +) + +# For model training/inference (batched data) +policy_processor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[step1, step2, step3], + name="policy_pipeline" +) +``` + +## RobotProcessorPipeline vs PolicyProcessorPipeline + +The key distinction is in the data structures they handle: + +| Aspect | RobotProcessorPipeline | PolicyProcessorPipeline | +| --------------- | -------------------------------------------- | ---------------------------------------- | +| **Input** | `dict[str, Any]` - Individual robot values | `dict[str, Any]` - Batched tensors | +| **Output** | `dict[str, Any]` - Individual robot commands | `torch.Tensor` - Policy predictions | +| **Use Case** | Real-time robot control | Model training/inference | +| **Data Format** | Unbatched, heterogeneous | Batched, homogeneous | +| **Examples** | `{"joint_1": 0.5}` | `{"observation.state": tensor([[0.5]])}` | + +**Use `RobotProcessorPipeline`** for robot hardware interfaces: + +```python +# Robot data structures: dict[str, Any] for observations and actions +robot_obs: dict[str, Any] = { + "joint_1": 0.5, # Individual joint values + "joint_2": -0.3, + "camera_0": image_array # Raw camera data +} + +robot_action: dict[str, Any] = { + "joint_1": 0.2, # Target joint positions + "joint_2": 0.1, + "gripper": 0.8 +} +``` + +**Use `PolicyProcessorPipeline`** for model training and batch processing: + +```python +# Policy data structures: batch dicts and tensors +policy_batch: dict[str, Any] = { + "observation.state": torch.tensor([[0.5, -0.3]]), # Batched states + "observation.images.camera0": torch.tensor(...), # Batched images + "action": torch.tensor([[0.2, 0.1, 0.8]]) # Batched actions +} + +policy_action: torch.Tensor = torch.tensor([[0.2, 0.1, 0.8]]) # Model output tensor +``` + +## Converter Functions + +LeRobot provides converter functions to bridge different data formats in `lerobot.processor.converters`. These functions handle the crucial translations between robot hardware data structures, policy model formats, and the internal `EnvTransition` representation that flows through processor pipelines. + +| Category | Function | Description | +| ------------------------------ | ----------------------------- | ------------------------------- | +| **Robot Hardware Converters** | `robot_action_to_transition` | Robot dict → EnvTransition | +| | `observation_to_transition` | Robot obs → EnvTransition | +| | `transition_to_robot_action` | EnvTransition → Robot dict | +| **Policy/Training Converters** | `batch_to_transition` | Batch dict → EnvTransition | +| | `transition_to_batch` | EnvTransition → Batch dict | +| | `policy_action_to_transition` | Policy tensor → EnvTransition | +| | `transition_to_policy_action` | EnvTransition → Policy tensor | +| **Utilities** | `create_transition` | Build transitions with defaults | +| | `identity_transition` | Pass-through converter | + +The key insight is that **robot hardware converters** work with individual values and dictionaries, while **policy/training converters** work with batched tensors and model outputs. The converter functions automatically handle the structural differences, so your processor steps can focus on the core transformations without worrying about data format compatibility. + +## Processor Examples + +The following examples demonstrate real-world processor configurations for policy training and inference. + +Here is an example processor for policy training and inference: + +```python +# Training data preprocessing (optimized order for GPU performance) +training_preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=[ + RenameObservationsProcessorStep(rename_map={}), # Standardize keys + AddBatchDimensionProcessorStep(), # Add batch dims + TokenizerProcessorStep(tokenizer_name="...", ...), # Tokenize language + DeviceProcessorStep(device="cuda"), # Move to GPU first + NormalizerProcessorStep(features=..., stats=...), # Normalize on GPU + ] +) + +# Model output postprocessing +training_postprocessor = PolicyProcessorPipeline[torch.Tensor, torch.Tensor]( + steps=[ + DeviceProcessorStep(device="cpu"), # Move to CPU + UnnormalizerProcessorStep(features=..., stats=...), # Denormalize + ] + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, +) +``` + +### An interaction between a robot and a policy with processors + +The most common real-world scenario combines both pipeline types robot hardware generates observations that need policy processing, and policy outputs need robot-compatible postprocessing: + +```python +# Real deployment: Robot sensors → Model → Robot commands +with torch.no_grad(): + while not done: + raw_obs = robot.get_observation() # dict[str, Any] + + # Add your robot observation to policy observation processor + + policy_input = policy_preprocessor(raw_obs) # Batched dict + + policy_output = policy.select_action(policy_input) # Policy tensor + + policy_action = policy_postprocessor(policy_output) + + # Add your robot action to policy action processor + + robot.send_action(policy_action) +``` + +## Feature Contracts: Shape and Type Transformation + +Processors don't just transform data - they can also **change the data structure itself**. The `transform_features()` method declares these changes, which is crucial for dataset recording and policy creation. + +### Why Feature Contracts Matter + +When building datasets or policies, LeRobot needs to know: + +- **What data fields will exist** after processing +- **What shapes and types** each field will have +- **How to configure models** for the expected data structure + +```python +# Example: A processor that adds velocity to observations +class VelocityProcessor(ObservationProcessorStep): + def observation(self, obs): + new_obs = obs.copy() + if "observation.state" in obs: + # concatenate computed velocity field to the state + new_obs["observation.state"] = self._compute_velocity(obs["observation.state"]) + return new_obs + + def transform_features(self, features): + """Declare the new velocity field we're adding.""" + state_feature = features[PipelineFeatureType.OBSERVATION].get("observation.state") + if state_feature: + double_shape = (state_feature.shape[0] * 2,) if state_feature.shape else (2,) + features[PipelineFeatureType.OBSERVATION]["observation.state"] = PolicyFeature( + type=FeatureType.STATE, shape=double_shape + ) + return features +``` + +### Feature Specification Functions + +`create_initial_features()` and `aggregate_pipeline_dataset_features()` solve a critical dataset creation problem: determining the exact final data structure before any data is processed. +Since processor pipelines can add new features (like velocity fields), change tensor shapes (like cropping images), or rename keys, datasets need to know the complete output specification upfront to allocate proper storage and define schemas. +These functions work together by starting with robot hardware specifications (`create_initial_features()`) then simulating the entire pipeline transformation (`aggregate_pipeline_dataset_features()`) to compute the final feature dictionary that gets passed to `LeRobotDataset.create()`, ensuring perfect alignment between what processors output and what datasets expect to store. + +```python +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features + +# Start with robot's raw features +initial_features = create_initial_features( + observation=robot.observation_features, # {"joint_1.pos": float, "camera_0": (480,640,3)} + action=robot.action_features # {"joint_1.pos": float, "gripper.pos": float} +) + +# Apply processor pipeline to compute final features +final_features = aggregate_pipeline_dataset_features( + pipeline=my_processor_pipeline, + initial_features=initial_features, + use_videos=True +) + +# Use for dataset creation +dataset = LeRobotDataset.create( + repo_id="my_dataset", + features=final_features, # Knows exactly what data to expect + ... +) +``` + +## Common Processor Steps + +LeRobot provides many registered processor steps. Here are the most commonly used core processors: + +### Essential Processors + +- **`normalizer_processor`**: Normalize observations/actions using dataset statistics (mean/std or min/max) +- **`device_processor`**: Move tensors to CPU/GPU with optional dtype conversion +- **`to_batch_processor`**: Add batch dimensions to transitions for model compatibility +- **`rename_observations_processor`**: Rename observation keys using mapping dictionaries +- **`tokenizer_processor`**: Tokenize natural language task descriptions into tokens and attention masks + +### Next Steps + +- **[Implement Your Own Processor](implement_your_own_processor.mdx)** - Create custom processor steps +- **[Debug Your Pipeline](debug_processor_pipeline.mdx)** - Troubleshoot and optimize pipelines +- **[Processors for Robots and Teleoperators](processors_robots_teleop.mdx)** - Real-world integration patterns + +## Summary + +Processors solve the data translation problem in robotics by providing: + +- **Modular transformations**: Composable, reusable processing steps +- **Type safety**: Generic pipelines with compile-time checking +- **Performance optimization**: GPU-accelerated operations +- **Robot/Policy distinction**: Separate pipelines for different data structures +- **Comprehensive ecosystem**: 30+ registered processors for common tasks + +The key insight: `RobotProcessorPipeline` handles unbatched robot hardware data, while `PolicyProcessorPipeline` handles batched model data. Choose the right tool for your data structure! diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx new file mode 100644 index 000000000..71d5457fb --- /dev/null +++ b/docs/source/phone_teleop.mdx @@ -0,0 +1,192 @@ +# Phone + +Use your phone (iOS or Android) to control your robot. + +**In this guide you'll learn:** + +- How to connect an iOS/Android phone +- How phone pose is mapped to robot end‑effector (EE) targets +- How to tweak safety limits, gripper control, and IK settings + +To use phone to control your robot, install the relevant dependencies with: + +```bash +pip install lerobot[phone] +``` + +## Get started + +### Supported platforms + +- iOS: Uses the HEBI Mobile I/O app (ARKit pose + buttons). Download the app first, open it and the examples will discover it on your network and stream the phone pose and inputs. +- Android: Uses the `teleop` package (WebXR). When you start the Python process, it prints a local URL. Open the link on your phone, tap Start, then use Move to stream pose. + +Links: + +- Android WebXR library: [`teleop` on PyPI](https://pypi.org/project/teleop/) +- iOS app: [HEBI Mobile I/O](https://docs.hebi.us/tools.html#mobile-io) + +### Phone orientation and controls + +- Orientation: hold the phone with the screen facing up and the top edge pointing in the same direction as the robot gripper. This ensures calibration aligns the phone’s frame with the robot frame so motion feels natural, see the image below for reference. +- Enable/disable: + - iOS: Hold `B1` to enable teleoperation, release to stop. The first press captures a reference pose. + - Android: Press and hold the `Move` button, release to stop. The first press captures a reference pose. +- Gripper control: + - iOS: Analog input `A3` controls the gripper as velocity input. + - Android: Buttons `A` and `B` act like increment/decrement (A opens, B closes). You can tune velocity in the `GripperVelocityToJoint` step. + +Phone teleop orientation + +### Step 1: Choose the platform + +Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`. The API is identical across platforms, only the input source differs. All examples are under `examples/` and have `phone_so100_*.py` variants. + +Teleoperation example: + +```36:43:examples/phone_so100_teleop.py +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS + +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID +teleop_device = Phone(teleop_config) +``` + +### Step 2: Connect and calibrate + +When `Phone(teleop_config)` is created and `connect()` is called, calibration is prompted automatically. Hold the phone in the orientation described above, then: + +- iOS: press and hold `B1` to capture the reference pose. +- Android: press `Move` button on the WebXR page to capture the reference pose. + +Why calibrate? We capture the current pose so subsequent poses are expressed in a robot aligned frame. When you again press the button to enable control, the position is recaptured to avoid drift when your phone is repositioned while it was disabled. + +### Step 3: Run an example + +Run on of the examples scripts to teleoperate, record a dataset, replay a dataset or evaluate a policy. + +All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port. + +Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf) + +- Run this example to teleoperate: + + ```bash + python examples/phone_to_so100/teleoperate.py + ``` + +After running the example: + +- Android: after starting the script, open the printed local URL on your phone, tap Start, then press and hold Move. +- iOS: open HEBI Mobile I/O first; B1 enables motion. A3 controls the gripper. + +Additionally you can customize mapping or safety limits by editing the processor steps shown in the examples. You can also remap inputs (e.g., use a different analog input) or adapt the pipeline to other robots (e.g., LeKiwi) by modifying the input and kinematics steps. More about this in the [Processors for Robots and Teleoperators](./processors_robots_teleop.mdx) guide. + +- Run this example to record a dataset, which saves absolute end effector observations and actions: + + ```bash + python examples/phone_to_so100/record.py + ``` + +- Run this example to replay recorded episodes: + + ```bash + python examples/phone_to_so100/replay.py + ``` + +- Run this example to evaluate a pretrained policy: + + ```bash + python examples/phone_to_so100/evaluate.py + ``` + +### Important pipeline steps and options + +- Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame. + + ```examples/phone_to_so100/teleoperate.py + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), + ) + + ``` + +- The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs. + + ```src/lerobot/teleoperators/phone/phone_processor.py + action["enabled"] = enabled + action["target_x"] = -pos[1] if enabled else 0.0 + action["target_y"] = pos[0] if enabled else 0.0 + action["target_z"] = pos[2] if enabled else 0.0 + action["target_wx"] = rotvec[1] if enabled else 0.0 + action["target_wy"] = rotvec[0] if enabled else 0.0 + action["target_wz"] = -rotvec[2] if enabled else 0.0 + action["gripper_vel"] = gripper_vel # Still send gripper action when disabled + ``` + +- The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed. + + ```examples/phone_to_so100/teleoperate.py + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + ``` + +- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits. + + ```examples/phone_to_so100/teleoperate.py + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ) + ``` + +- The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied. + + ```examples/phone_to_so100/teleoperate.py + GripperVelocityToJoint(speed_factor=20.0) + ``` + +#### Different IK initial guesses + +We use different IK initial guesses in the kinematic steps. As initial guess either the current measured joints or the previous IK solution is used. + +- Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame. + + ```examples/phone_to_so100/record.py + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, # closed loop + ) + ``` + +- Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback. + + ```examples/phone_to_so100/replay.py + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # open loop + ) + ``` + +### Pipeline steps explained + +- MapPhoneActionToRobotAction: converts calibrated phone pose and inputs into target deltas and a gripper command. Motion is gated by an enable signal (B1 on iOS, Move on Android). +- EEReferenceAndDelta: latches a reference EE pose on enable and combines it with target deltas to produce an absolute desired EE pose each frame. When disabled, it keeps sending the last commanded pose. +- EEBoundsAndSafety: clamps the EE pose to a workspace and rate‑limits jumps for safety. Also declares `action.ee.*` features. +- InverseKinematicsEEToJoints: turns an EE pose into joint positions with IK. `initial_guess_current_joints=True` is recommended for closed‑loop control; set `False` for open‑loop replay for stability. +- GripperVelocityToJoint: integrates a velocity‑like gripper input into an absolute gripper position using the current measured state. +- ForwardKinematicsJointsToEE: computes `observation.state.ee.*` from observed joints for logging and training on EE state. + +### Troubleshooting + +- iOS not discovered: ensure HEBI Mobile I/O is open and your laptop/phone are on the same network. +- Android URL not reachable: check local you used `https` instead of `http`, use the exact IP printed by the script and allow your browser to enter and ignore the certificate issue. +- Motion feels inverted: adjust the sign flips in `MapPhoneActionToRobotAction` or swap axes to match your setup. diff --git a/docs/source/processors_robots_teleop.mdx b/docs/source/processors_robots_teleop.mdx new file mode 100644 index 000000000..c4fcbe03d --- /dev/null +++ b/docs/source/processors_robots_teleop.mdx @@ -0,0 +1,151 @@ +# Processors for Robots and Teleoperators + +This guide shows how to build and modify processing pipelines that connect teleoperators (e.g., phone) to robots and datasets. Pipelines standardize conversions between different action/observation spaces so you can swap teleops and robots without rewriting glue code. + +We use the Phone to SO‑100 follower examples for concreteness, but the same patterns apply to other robots. + +**What you'll learn** + +- Absolute vs. relative EE control: What each means, trade‑offs, and how to choose for your task. +- Three-pipeline pattern: How to map teleop actions → dataset actions → robot commands, and robot observations → dataset observations. +- Adapters (`to_transition` / `to_output`): How these convert raw dicts to `EnvTransition` and back to reduce boilerplate. +- Dataset feature contracts: How steps declare features via `transform_features(...)`, and how to aggregate/merge them for recording. +- Choosing a representation: When to store joints, absolute EE poses, or relative EE deltas—and how that affects training. +- Pipeline customization guidance: How to swap robots/URDFs safely and tune bounds, step sizes, and options like IK initialization. + +### Absolute vs relative EE control + +The examples in this guide use absolute end effector (EE) poses because they are easy to reason about. In practice, relative EE deltas or joint position are often preferred as learning features. + +With processors, you choose the learning features you want to use for your policy. This could be joints positions/velocities, absolute EE, or relative EE positions. You can also choose to store other features, such as joint torques, motor currents, etc. + +## Three pipelines + +We often compose three pipelines. Depending on your setup, some can be empty if action and observation spaces already match. +Each of these pipelines handle different conversions between different action and observation spaces. Below is a quick explanation of each pipeline. + +1. Pipeline 1: Teleop action space → dataset action space (phone pose → EE targets) +2. Pipeline 2: Dataset action space → robot command space (EE targets → joints) +3. Pipeline 3: Robot observation space → dataset observation space (joints → EE pose) + +Below is an example of the three pipelines that we use in the phone to SO-100 follower examples: + +```69:90:examples/phone_so100_record.py +phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()), + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint(), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +robot_ee_to_joints_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # dataset action -> robot + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( # robot obs -> dataset obs + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) +``` + +## Why to_transition / to_output + +To convert from robot/teleoperator to pipeline and back, we use the `to_transition` and `to_output` pipeline adapters. +They standardize conversions to reduce boilerplate code, and form the bridge between the robot and teleoperators raw dictionaries and the pipeline’s `EnvTransition` format. +In the phone to SO-100 follower examples we use the following adapters: + +- `robot_action_to_transition`: transforms the teleop action dict to a pipeline transition. +- `transition_to_robot_action`: transforms the pipeline transition to a robot action dict. +- `observation_to_transition`: transforms the robot observation dict to a pipeline transition. +- `transition_to_observation`: transforms the pipeline transition to a observation dict. + +Checkout [src/lerobot/processor/converters.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/processor/converters.py) for more details. + +## Dataset feature contracts + +Dataset features are determined by the keys saved in the dataset. Each step can declare what features it modifies in a contract called `transform_features(...)`. Once you build a processor, the processor can then aggregate all of these features with `aggregate_pipeline_dataset_features()` and merge multiple feature dicts with `combine_feature_dicts(...)`. + +Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples: + +```src/lerobot/robots/so100_follower/robot_kinematic_processor.py + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features +``` + +Here we declare what PolicyFeatures we modify in this step, so we know what features we can expect when we run the processor. These features can then be aggregated and used to create the dataset features. + +Below is an example of how we aggregate and merge features in the phone to SO-100 record example: + +```121:145:examples/phone_so100_record.py +features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose_processor, + initial_features=create_initial_features(action=phone.action_features), # <- Action features we can expect, these come from our teleop device (phone) and action processor + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=create_initial_features(observation=robot.observation_features), # <- Observation features we can expect, these come from our robot and observation processor + use_videos=True, + patterns=["observation.state.ee"], # <- Here you could optionally filter the features we want to store in the dataset, with a specific pattern + + ), + ), +``` + +How it works: + +- `aggregate_pipeline_dataset_features(...)`: applies `transform_features` across the pipeline and filters by patterns (images included when `use_videos=True`, and state features included when `patterns` is specified). +- `combine_feature_dicts(...)`: combine multiple feature dicts. +- Recording with `record_loop(...)` uses `build_dataset_frame(...)` to build frames consistent with `dataset.features` before we call `add_frame(...)` to add the frame to the dataset. + +## Guidance when customizing robot pipelines + +You can store any of the following features as your action/observation space: + +- Joint positions +- Absolute EE poses +- Relative EE deltas +- Other features: joint velocity, torques, etc. + +Pick what you want to use for your policy action and observation space and configure/modify the pipelines and steps accordingly. + +### Different robots + +- You can easily reuse pipelines, for example to use another robot with phone teleop, modify the examples and swap the robot `RobotKinematics` (URDF) and `motor_names` to use your own robot with Phone teleop. Additionally you should ensure `target_frame_name` points to your gripper/wrist. + +### Safety first + +- When changing pipelines, start with tight bounds, implement safety steps when working with real robots. +- Its advised to start with simulation first and then move to real robots. + +Thats it! We hope this guide helps you get started with customizing your robot pipelines, If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support. diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index f2de79db8..7f3fad36c 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetad from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy +from lerobot.policies.factory import make_pre_post_processors def main(): @@ -56,9 +57,10 @@ def main(): cfg = DiffusionConfig(input_features=input_features, output_features=output_features) # We can now instantiate our policy with this config and the dataset stats. - policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats) + policy = DiffusionPolicy(cfg) policy.train() policy.to(device) + preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats) # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames # which can differ for inputs, outputs and rewards (if there are some). @@ -99,7 +101,7 @@ def main(): done = False while not done: for batch in dataloader: - batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = preprocessor(batch) loss, _ = policy.forward(batch) loss.backward() optimizer.step() @@ -114,6 +116,8 @@ def main(): # Save a policy checkpoint. policy.save_pretrained(output_directory) + preprocessor.save_pretrained(output_directory) + postprocessor.save_pretrained(output_directory) if __name__ == "__main__": diff --git a/examples/5_train_with_streaming.py b/examples/5_train_with_streaming.py index 17818410d..93d13535f 100644 --- a/examples/5_train_with_streaming.py +++ b/examples/5_train_with_streaming.py @@ -30,6 +30,7 @@ from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors def main(): @@ -60,9 +61,10 @@ def main(): # We can now instantiate our policy with this config and the dataset stats. cfg = ACTConfig(input_features=input_features, output_features=output_features) - policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats) + policy = ACTPolicy(cfg) policy.train() policy.to(device) + preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats) # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy. # Here, we use delta-timestamps to only provide ground truth actions for supervision @@ -89,13 +91,7 @@ def main(): done = False while not done: for batch in dataloader: - batch = { - k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v) - for k, v in batch.items() - } - batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} - - # batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + batch = preprocessor(batch) loss, _ = policy.forward(batch) loss.backward() optimizer.step() @@ -110,6 +106,8 @@ def main(): # Save a policy checkpoint. policy.save_pretrained(output_directory) + preprocessor.save_pretrained(output_directory) + postprocessor.save_pretrained(output_directory) if __name__ == "__main__": diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 57fb62e10..3dbb10f56 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -1,6 +1,24 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import make_default_processors from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.utils.control_utils import init_keyboard_listener @@ -11,12 +29,16 @@ NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 60 TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" -# Create the robot and teleoperator configurations +# Create the robot configuration & robot robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + robot = LeKiwiClient(robot_config) -policy = ACTPolicy.from_pretrained("/") +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") @@ -25,7 +47,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_DATASET_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -33,33 +55,52 @@ dataset = LeRobotDataset.create( image_writer_threads=4, ) +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() -_init_rerun(session_name="recording") +# TODO(Steven): Update this example to use pipelines +teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() +# Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() +_init_rerun(session_name="lekiwi_evaluate") if not robot.is_connected: raise ValueError("Robot is not connected!") +print("Starting evaluate loop...") recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Run the policy inference loop + # Main record loop record_loop( robot=robot, events=events, fps=FPS, policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) - # Logic for reset env + # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] ): @@ -71,6 +112,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) if events["rerecord_episode"]: @@ -80,11 +124,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: dataset.clear_episode_buffer() continue + # Save episode dataset.save_episode() recorded_episodes += 1 -# Upload to hub and clean up -dataset.push_to_hub() - +# Clean up +log_say("Stop recording") robot.disconnect() listener.stop() +dataset.push_to_hub() diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 11a716761..f5d109d5d 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -1,5 +1,22 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.processor import make_default_processors from lerobot.record import record_loop from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient @@ -9,21 +26,26 @@ from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun -NUM_EPISODES = 3 +NUM_EPISODES = 2 FPS = 30 EPISODE_TIME_SEC = 30 RESET_TIME_SEC = 10 TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" # Create the robot and teleoperator configurations robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") keyboard_config = KeyboardTeleopConfig() +# Initialize the robot and teleoperator robot = LeKiwiClient(robot_config) leader_arm = SO100Leader(leader_arm_config) keyboard = KeyboardTeleop(keyboard_config) +# TODO(Steven): Update this example to use pipelines +teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") obs_features = hw_to_dataset_features(robot.observation_features, "observation") @@ -31,7 +53,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="/", + repo_id=HF_REPO_ID, fps=FPS, features=dataset_features, robot_type=robot.name, @@ -39,23 +61,25 @@ dataset = LeRobotDataset.create( image_writer_threads=4, ) +# Connect the robot and teleoperator # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() leader_arm.connect() keyboard.connect() +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() _init_rerun(session_name="lekiwi_record") -listener, events = init_keyboard_listener() - if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot, leader arm of keyboard is not connected!") + raise ValueError("Robot or teleop is not connected!") +print("Starting record loop...") recorded_episodes = 0 while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: log_say(f"Recording episode {recorded_episodes}") - # Run the record loop + # Main record loop record_loop( robot=robot, events=events, @@ -65,9 +89,12 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) - # Logic for reset env + # Reset the environment if not stopping or re-recording if not events["stop_recording"] and ( (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] ): @@ -80,6 +107,9 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, ) if events["rerecord_episode"]: @@ -89,13 +119,14 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: dataset.clear_episode_buffer() continue + # Save episode dataset.save_episode() recorded_episodes += 1 -# Upload to hub and clean up -dataset.push_to_hub() - +# Clean up +log_say("Stop recording") robot.disconnect() leader_arm.disconnect() keyboard.disconnect() listener.stop() +dataset.push_to_hub() diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 248354df9..0f8eabdff 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -1,3 +1,19 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -8,25 +24,36 @@ from lerobot.utils.utils import log_say EPISODE_IDX = 0 +# Initialize the robot config robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + +# Initialize the robot robot = LeKiwiClient(robot_config) +# Fetch the dataset to replay dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) -actions = dataset.hf_dataset.select_columns("action") +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") +# Connect to the robot robot.connect() if not robot.is_connected: raise ValueError("Robot is not connected!") +print("Starting replay loop...") log_say(f"Replaying episode {EPISODE_IDX}") -for idx in range(dataset.num_frames): +for idx in range(len(episode_frames)): t0 = time.perf_counter() + # Get recorded action from dataset action = { name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) } - robot.send_action(action) + + # Send action to robot + _ = robot.send_action(action) busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index 8358a2b93..cde4000df 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -1,3 +1,19 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig @@ -13,35 +29,44 @@ robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi") teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard") +# Initialize the robot and teleoperator robot = LeKiwiClient(robot_config) leader_arm = SO100Leader(teleop_arm_config) keyboard = KeyboardTeleop(keyboard_config) +# Connect to the robot and teleoperator # To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi` robot.connect() leader_arm.connect() keyboard.connect() +# Init rerun viewer _init_rerun(session_name="lekiwi_teleop") if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot, leader arm of keyboard is not connected!") + raise ValueError("Robot or teleop is not connected!") +print("Starting teleop loop...") while True: t0 = time.perf_counter() + # Get robot observation observation = robot.get_observation() + # Get teleop action + # Arm arm_action = leader_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - + # Keyboard keyboard_keys = keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_keys) - log_rerun_data(observation, {**arm_action, **base_action}) - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action - robot.send_action(action) + # Send action to robot + _ = robot.send_action(action) + + # Visualize + log_rerun_data(observation=observation, action=action) busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py new file mode 100644 index 000000000..e76b11350 --- /dev/null +++ b/examples/phone_to_so100/evaluate.py @@ -0,0 +1,197 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_teleop_action_processor, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Create the robot configuration & robot +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +robot = SO100Follower(robot_config) + +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joints observation to EE observation +robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + # User for now should be explicit on the feature keys that were used for record + # Alternatively, the user can pass the processor step that has the right features + aggregate_pipeline_dataset_features( + pipeline=make_default_teleop_action_processor(), + initial_features=create_initial_features( + action={ + f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] + } + ), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot +robot.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="phone_so100_evaluate") + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting evaluate loop...") +episode_idx = 0 +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py new file mode 100644 index 000000000..768041d63 --- /dev/null +++ b/examples/phone_to_so100/record.py @@ -0,0 +1,204 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEE, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 2 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Create the robot and teleoperator configurations +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +phone = Phone(teleop_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to EE action +phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.20, + max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint(speed_factor=20.0), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joint observation to EE observation +robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose_processor, + initial_features=create_initial_features(action=phone.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Connect the robot and teleoperator +robot.connect() +phone.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="phone_so100_record") + +if not robot.is_connected or not phone.is_connected: + raise ValueError("Robot or teleop is not connected!") + + +print("Starting record loop. Move your phone to teleoperate the robot...") +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +phone.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py new file mode 100644 index 000000000..80c65a4c2 --- /dev/null +++ b/examples/phone_to_so100/replay.py @@ -0,0 +1,99 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.utils import log_say + +EPISODE_IDX = 0 +HF_REPO_ID = "/" + +# Initialize the robot config +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Fetch the dataset to replay +dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") + +# Connect to the robot +robot.connect() + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting replay loop...") +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(len(episode_frames)): + t0 = time.perf_counter() + + # Get recorded action from dataset + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + # Get robot observation + robot_obs = robot.get_observation() + + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +# Clean up +robot.disconnect() diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py new file mode 100644 index 000000000..eb5ed3526 --- /dev/null +++ b/examples/phone_to_so100/teleoperate.py @@ -0,0 +1,114 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specif + +import time + +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.teleoperators.phone.teleop_phone import Phone +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data + +FPS = 30 + +# Initialize the robot and teleoperator +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) +teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +teleop_device = Phone(teleop_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert phone action to ee pose action to joint action +phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + GripperVelocityToJoint( + speed_factor=20.0, + ), + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Connect to the robot and teleoperator +robot.connect() +teleop_device.connect() + +# Init rerun viewer +_init_rerun(session_name="phone_so100_teleop") + +if not robot.is_connected or not teleop_device.is_connected: + raise ValueError("Robot or teleop is not connected!") + +print("Starting teleop loop. Move your phone to teleoperate the robot...") +while True: + t0 = time.perf_counter() + + # Get robot observation + robot_obs = robot.get_observation() + + # Get teleop action + phone_obs = teleop_device.get_action() + + # Phone -> EE pose -> Joints transition + joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + # Visualize + log_rerun_data(observation=phone_obs, action=joint_action) + + busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py new file mode 100644 index 000000000..fd10bf865 --- /dev/null +++ b/examples/so100_to_so100_EE/evaluate.py @@ -0,0 +1,198 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.policies.factory import make_pre_post_processors +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_teleop_action_processor, +) +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" +HF_DATASET_ID = "/" + +# Create the robot configuration & robot +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, +) + +robot = SO100Follower(robot_config) + +# Create policy +policy = ACTPolicy.from_pretrained(HF_MODEL_ID) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert joints observation to EE observation +robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_DATASET_ID, + fps=FPS, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + # User for now should be explicit on the feature keys that were used for record + # Alternatively, the user can pass the processor step that has the right features + aggregate_pipeline_dataset_features( + pipeline=make_default_teleop_action_processor(), + initial_features=create_initial_features( + action={ + f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,)) + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"] + } + ), + use_videos=True, + ), + ), + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Build Policy Processors +preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=policy, + pretrained_path=HF_MODEL_ID, + dataset_stats=dataset.meta.stats, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, +) + +# Connect the robot and teleoperator +robot.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="so100_so100_evaluate") + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting evaluate loop...") +episode_idx = 0 +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py new file mode 100644 index 000000000..abb8fb99d --- /dev/null +++ b/examples/so100_to_so100_EE/record.py @@ -0,0 +1,203 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import combine_feature_dicts +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.record import record_loop +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun + +NUM_EPISODES = 2 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 30 +TASK_DESCRIPTION = "My task description" +HF_REPO_ID = "/" + +# Create the robot and teleoperator configurations +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +follower_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True +) +leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm") + +# Initialize the robot and teleoperator +follower = SO100Follower(follower_config) +leader = SO100Leader(leader_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +follower_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(follower.bus.motors.keys()), +) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +leader_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(leader.bus.motors.keys()), +) + +# Build pipeline to convert follower joints to EE observation +follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys()) + ), + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, +) + +# Build pipeline to convert leader joints to EE action +leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys()) + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Build pipeline to convert EE action to follower joints +ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + [ + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + InverseKinematicsEEToJoints( + kinematics=follower_kinematics_solver, + motor_names=list(follower.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id=HF_REPO_ID, + fps=FPS, + features=combine_feature_dicts( + # Run the feature contract of the pipelines + # This tells you how the features would look like after the pipeline steps + aggregate_pipeline_dataset_features( + pipeline=leader_joints_to_ee, + initial_features=create_initial_features(action=leader.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=follower_joints_to_ee, + initial_features=create_initial_features(observation=follower.observation_features), + use_videos=True, + ), + ), + robot_type=follower.name, + use_videos=True, + image_writer_threads=4, +) + + +# Connect the robot and teleoperator +leader.connect() +follower.connect() + +# Initialize the keyboard listener and rerun visualization +listener, events = init_keyboard_listener() +_init_rerun(session_name="recording_phone") + +if not leader.is_connected or not follower.is_connected: + raise ValueError("Robot or teleop is not connected!") + +print("Starting record loop...") +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + # Main record loop + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + # Save episode + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +leader.disconnect() +follower.disconnect() +listener.stop() +dataset.push_to_hub() diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py new file mode 100644 index 000000000..6987f4839 --- /dev/null +++ b/examples/so100_to_so100_EE/replay.py @@ -0,0 +1,100 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.utils import log_say + +EPISODE_IDX = 0 +HF_REPO_ID = "/" + +# Initialize the robot config +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), +) + +# Build pipeline to convert EE action to joints action +robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=False, # Because replay is open loop + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Fetch the dataset to replay +dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) +# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 +episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) +actions = episode_frames.select_columns("action") + +# Connect to the robot +robot.connect() + +if not robot.is_connected: + raise ValueError("Robot is not connected!") + +print("Starting replay loop...") +log_say(f"Replaying episode {EPISODE_IDX}") +for idx in range(len(episode_frames)): + t0 = time.perf_counter() + + # Get recorded action from dataset + ee_action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + + # Get robot observation + robot_obs = robot.get_observation() + + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + + # Send action to robot + _ = robot.send_action(joint_action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +# Clean up +robot.disconnect() diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py new file mode 100644 index 000000000..ab54e7236 --- /dev/null +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -0,0 +1,122 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline +from lerobot.processor.converters import ( + robot_action_observation_to_transition, + robot_action_to_transition, + transition_to_robot_action, +) +from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.utils.robot_utils import busy_wait +from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data + +FPS = 30 + +# Initialize the robot and teleoperator config +follower_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True +) +leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm") + +# Initialize the robot and teleoperator +follower = SO100Follower(follower_config) +leader = SO100Leader(leader_config) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +follower_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(follower.bus.motors.keys()), +) + +# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf +leader_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(leader.bus.motors.keys()), +) + +# Build pipeline to convert teleop joints to EE action +leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys()) + ), + ], + to_transition=robot_action_to_transition, + to_output=transition_to_robot_action, +) + +# build pipeline to convert EE action to robot joints +ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + [ + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + max_ee_twist_step_rad=0.50, + ), + InverseKinematicsEEToJoints( + kinematics=follower_kinematics_solver, + motor_names=list(follower.bus.motors.keys()), + initial_guess_current_joints=False, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, +) + +# Connect to the robot and teleoperator +follower.connect() +leader.connect() + +# Init rerun viewer +_init_rerun(session_name="so100_so100_EE_teleop") + +print("Starting teleop loop...") +while True: + t0 = time.perf_counter() + + # Get robot observation + robot_obs = follower.get_observation() + + # Get teleop observation + leader_joints_obs = leader.get_action() + + # teleop joints -> teleop EE action + leader_ee_act = leader_to_ee(leader_joints_obs) + + # teleop EE -> robot joints + follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs)) + + # Send action to robot + _ = follower.send_action(follower_joints_act) + + # Visualize + log_rerun_data(observation=leader_ee_act, action=follower_joints_act) + + busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0)) diff --git a/pyproject.toml b/pyproject.toml index 7241a78f9..70755cf9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.50.3,<4.52.0"] # TODO: Bumb dependency +transformers-dep = ["transformers>=4.52.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors @@ -111,6 +111,7 @@ intelrealsense = [ "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", ] +phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # stretch = [ # "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", # "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", @@ -153,7 +154,8 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]" + "lerobot[xarm]", + "lerobot[phone]", ] [project.scripts] diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index f5fa727cf..7532f0612 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig @@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): """ n_obs_steps: int = 1 - normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict) input_features: dict[str, PolicyFeature] = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict) diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 6040ff70b..e02527840 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -24,6 +24,12 @@ class FeatureType(str, Enum): ENV = "ENV" ACTION = "ACTION" REWARD = "REWARD" + LANGUAGE = "LANGUAGE" + + +class PipelineFeatureType(str, Enum): + ACTION = "ACTION" + OBSERVATION = "OBSERVATION" class NormalizationMode(str, Enum): diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 382435a9f..464969c72 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -21,8 +21,14 @@ OBS_ENV_STATE = "observation.environment_state" OBS_STATE = "observation.state" OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" +OBS_LANGUAGE = "observation.language" ACTION = "action" REWARD = "next.reward" +TRUNCATED = "next.truncated" +DONE = "next.done" + +OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" +OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ROBOTS = "robots" ROBOT_TYPE = "robot_type" @@ -39,6 +45,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" +POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor" +POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor" + if "LEROBOT_HOME" in os.environ: raise ValueError( f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py new file mode 100644 index 000000000..b55ccf8a9 --- /dev/null +++ b/src/lerobot/datasets/pipeline_features.py @@ -0,0 +1,141 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections.abc import Sequence +from typing import Any + +from lerobot.configs.types import PipelineFeatureType +from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.processor import DataProcessorPipeline + + +def create_initial_features( + action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None +) -> dict[PipelineFeatureType, dict[str, Any]]: + """ + Creates the initial features dict for the dataset from action and observation specs. + + Args: + action: A dictionary of action feature names to their types/shapes. + observation: A dictionary of observation feature names to their types/shapes. + + Returns: + The initial features dictionary structured by PipelineFeatureType. + """ + features = {PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}} + if action: + features[PipelineFeatureType.ACTION] = action + if observation: + features[PipelineFeatureType.OBSERVATION] = observation + return features + + +# Helper to filter state/action keys based on regex patterns. +def should_keep(key: str, patterns: tuple[str]) -> bool: + if patterns is None: + return True + return any(re.search(pat, key) for pat in patterns) + + +def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str: + for prefix in prefixes_to_strip: + if key.startswith(prefix): + return key[len(prefix) :] + return key + + +# Define prefixes to strip from feature keys for clean names. +# Handles both fully qualified (e.g., "action.state") and short (e.g., "state") forms. +PREFIXES_TO_STRIP = tuple( + f"{token}." for const in (ACTION, OBS_STATE, OBS_IMAGES) for token in (const, const.split(".")[-1]) +) + + +def aggregate_pipeline_dataset_features( + pipeline: DataProcessorPipeline, + initial_features: dict[PipelineFeatureType, dict[str, Any]], + *, + use_videos: bool = True, + patterns: Sequence[str] | None = None, +) -> dict[str, dict]: + """ + Aggregates and filters pipeline features to create a dataset-ready features dictionary. + + This function transforms initial features using the pipeline, categorizes them as action or observations + (image or state), filters them based on `use_videos` and `patterns`, and finally + formats them for use with a Hugging Face LeRobot Dataset. + + Args: + pipeline: The DataProcessorPipeline to apply. + initial_features: A dictionary of raw feature specs for actions and observations. + use_videos: If False, image features are excluded. + patterns: A sequence of regex patterns to filter action and state features. + Image features are not affected by this filter. + + Returns: + A dictionary of features formatted for a Hugging Face LeRobot Dataset. + """ + all_features = pipeline.transform_features(initial_features) + + # Intermediate storage for categorized and filtered features. + processed_features: dict[str, dict[str, Any]] = { + "action": {}, + "observation": {}, + } + images_token = OBS_IMAGES.split(".")[-1] + + # Iterate through all features transformed by the pipeline. + for ptype, feats in all_features.items(): + if ptype not in [PipelineFeatureType.ACTION, PipelineFeatureType.OBSERVATION]: + continue + + for key, value in feats.items(): + # 1. Categorize the feature. + is_action = ptype == PipelineFeatureType.ACTION + # Observations are classified as images if their key matches image-related tokens or if the shape of the feature is 3. + # All other observations are treated as state. + is_image = not is_action and ( + (isinstance(value, tuple) and len(value) == 3) + or ( + key.startswith(f"{OBS_IMAGES}.") + or key.startswith(f"{images_token}.") + or f".{images_token}." in key + ) + ) + + # 2. Apply filtering rules. + if is_image and not use_videos: + continue + if not is_image and not should_keep(key, patterns): + continue + + # 3. Add the feature to the appropriate group with a clean name. + name = strip_prefix(key, PREFIXES_TO_STRIP) + if is_action: + processed_features["action"][name] = value + else: + processed_features["observation"][name] = value + + # Convert the processed features into the final dataset format. + dataset_features = {} + if processed_features["action"]: + dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) + if processed_features["observation"]: + dataset_features.update( + hw_to_dataset_features(processed_features["observation"], "observation", use_videos) + ) + + return dataset_features diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index c840d5bc1..922fc4e3f 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -150,14 +150,20 @@ def get_video_size_in_mb(mp4_path: Path) -> float: def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: - """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. + """Flatten a nested dictionary by joining keys with a separator. - For example: - ``` - >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}` - >>> print(flatten_dict(dct)) - {"a/b": 1, "a/c/d": 2, "e": 3} - ``` + Example: + >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3} + >>> print(flatten_dict(dct)) + {'a/b': 1, 'a/c/d': 2, 'e': 3} + + Args: + d (dict): The dictionary to flatten. + parent_key (str): The base key to prepend to the keys in this level. + sep (str): The separator to use between keys. + + Returns: + dict: A flattened dictionary. """ items = [] for k, v in d.items(): @@ -170,6 +176,20 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: def unflatten_dict(d: dict, sep: str = "/") -> dict: + """Unflatten a dictionary with delimited keys into a nested dictionary. + + Example: + >>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3} + >>> print(unflatten_dict(flat_dct)) + {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3} + + Args: + d (dict): A dictionary with flattened keys. + sep (str): The separator used in the keys. + + Returns: + dict: A nested dictionary. + """ outdict = {} for key, value in d.items(): parts = key.split(sep) @@ -183,6 +203,19 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: + """Serialize a dictionary containing tensors or numpy arrays to be JSON-compatible. + + Converts torch.Tensor, np.ndarray, and np.generic types to lists or native Python types. + + Args: + stats (dict): A dictionary that may contain non-serializable numeric types. + + Returns: + dict: A dictionary with all values converted to JSON-serializable types. + + Raises: + NotImplementedError: If a value has an unsupported type. + """ serialized_dict = {} for key, value in flatten_dict(stats).items(): if isinstance(value, (torch.Tensor, np.ndarray)): @@ -199,6 +232,17 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: + """Embed image bytes into the dataset table before saving to Parquet. + + This function prepares a Hugging Face dataset for serialization by converting + image objects into an embedded format that can be stored in Arrow/Parquet. + + Args: + dataset (datasets.Dataset): The input dataset, possibly containing image features. + + Returns: + datasets.Dataset: The dataset with images embedded in the table storage. + """ # Embed image bytes into the table before saving to parquet format = dataset.format dataset = dataset.with_format("arrow") @@ -208,11 +252,27 @@ def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: def load_json(fpath: Path) -> Any: + """Load data from a JSON file. + + Args: + fpath (Path): Path to the JSON file. + + Returns: + Any: The data loaded from the JSON file. + """ with open(fpath) as f: return json.load(f) def write_json(data: dict, fpath: Path) -> None: + """Write data to a JSON file. + + Creates parent directories if they don't exist. + + Args: + data (dict): The dictionary to write. + fpath (Path): The path to the output JSON file. + """ fpath.parent.mkdir(exist_ok=True, parents=True) with open(fpath, "w") as f: json.dump(data, f, indent=4, ensure_ascii=False) @@ -223,6 +283,16 @@ def write_info(info: dict, local_dir: Path) -> None: def load_info(local_dir: Path) -> dict: + """Load dataset info metadata from its standard file path. + + Also converts shape lists to tuples for consistency. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + dict: The dataset information dictionary. + """ info = load_json(local_dir / INFO_PATH) for ft in info["features"].values(): ft["shape"] = tuple(ft["shape"]) @@ -230,16 +300,40 @@ def load_info(local_dir: Path) -> dict: def write_stats(stats: dict, local_dir: Path) -> None: + """Serialize and write dataset statistics to their standard file path. + + Args: + stats (dict): The statistics dictionary (can contain tensors/numpy arrays). + local_dir (Path): The root directory of the dataset. + """ serialized_stats = serialize_dict(stats) write_json(serialized_stats, local_dir / STATS_PATH) def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: + """Recursively cast numerical values in a stats dictionary to numpy arrays. + + Args: + stats (dict): The statistics dictionary. + + Returns: + dict: The statistics dictionary with values cast to numpy arrays. + """ stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]] | None: + """Load dataset statistics and cast numerical values to numpy arrays. + + Returns None if the stats file doesn't exist. + + Args: + local_dir (Path): The root directory of the dataset. + + Returns: + A dictionary of statistics or None if the file is not found. + """ if not (local_dir / STATS_PATH).exists(): return None stats = load_json(local_dir / STATS_PATH) @@ -297,6 +391,18 @@ def backward_compatible_episodes_stats( def load_image_as_numpy( fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True ) -> np.ndarray: + """Load an image from a file into a numpy array. + + Args: + fpath (str | Path): Path to the image file. + dtype (np.dtype): The desired data type of the output array. If floating, + pixels are scaled to [0, 1]. + channel_first (bool): If True, converts the image to (C, H, W) format. + Otherwise, it remains in (H, W, C) format. + + Returns: + np.ndarray: The image as a numpy array. + """ img = PILImage.open(fpath).convert("RGB") img_array = np.array(img, dtype=dtype) if channel_first: # (H, W, C) -> (C, H, W) @@ -307,10 +413,19 @@ def load_image_as_numpy( def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]: - """Get a transform function that convert items from Hugging Face dataset (pyarrow) - to torch tensors. Importantly, images are converted from PIL, which corresponds to - a channel last representation (h w c) of uint8 type, to a torch image representation - with channel first (c h w) of float32 type in range [0,1]. + """Convert a batch from a Hugging Face dataset to torch tensors. + + This transform function converts items from Hugging Face dataset format (pyarrow) + to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8) + to a torch image representation (C, H, W, float32) in the range [0, 1]. Other + types are converted to torch.tensor. + + Args: + items_dict (dict): A dictionary representing a batch of data from a + Hugging Face dataset. + + Returns: + dict: The batch with items converted to torch tensors. """ for key in items_dict: first_item = items_dict[key][0] @@ -325,6 +440,14 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to def is_valid_version(version: str) -> bool: + """Check if a string is a valid PEP 440 version. + + Args: + version (str): The version string to check. + + Returns: + bool: True if the version string is valid, False otherwise. + """ try: packaging.version.parse(version) return True @@ -338,6 +461,18 @@ def check_version_compatibility( current_version: str | packaging.version.Version, enforce_breaking_major: bool = True, ) -> None: + """Check for version compatibility between a dataset and the current codebase. + + Args: + repo_id (str): The repository ID for logging purposes. + version_to_check (str | packaging.version.Version): The version of the dataset. + current_version (str | packaging.version.Version): The current version of the codebase. + enforce_breaking_major (bool): If True, raise an error on major version mismatch. + + Raises: + BackwardCompatibilityError: If the dataset version is from a newer, incompatible + major version of the codebase. + """ v_check = ( packaging.version.parse(version_to_check) if not isinstance(version_to_check, packaging.version.Version) @@ -355,7 +490,14 @@ def check_version_compatibility( def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: - """Returns available valid versions (branches and tags) on given repo.""" + """Return available valid versions (branches and tags) on a given Hub repo. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + + Returns: + list[packaging.version.Version]: A list of valid versions found. + """ api = HfApi() repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] @@ -368,9 +510,22 @@ def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: - """ - Returns the version if available on repo or the latest compatible one. - Otherwise, will throw a `CompatibilityError`. + """Return the specified version if available on repo, or the latest compatible one. + + If the exact version is not found, it looks for the latest version with the + same major version number that is less than or equal to the target minor version. + + Args: + repo_id (str): The repository ID on the Hugging Face Hub. + version (str | packaging.version.Version): The target version. + + Returns: + str: The safe version string (e.g., "v1.2.3") to use as a revision. + + Raises: + RevisionNotFoundError: If the repo has no version tags. + BackwardCompatibilityError: If only older major versions are available. + ForwardCompatibilityError: If only newer major versions are available. """ target_version = ( packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version @@ -412,6 +567,17 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> def get_hf_features_from_features(features: dict) -> datasets.Features: + """Convert a LeRobot features dictionary to a `datasets.Features` object. + + Args: + features (dict): A LeRobot-style feature dictionary. + + Returns: + datasets.Features: The corresponding Hugging Face `datasets.Features` object. + + Raises: + ValueError: If a feature has an unsupported shape. + """ hf_features = {} for key, ft in features.items(): if ft["dtype"] == "video": @@ -439,6 +605,14 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: def _validate_feature_names(features: dict[str, dict]) -> None: + """Validate that feature names do not contain invalid characters. + + Args: + features (dict): The LeRobot features dictionary. + + Raises: + ValueError: If any feature name contains '/'. + """ invalid_features = {name: ft for name, ft in features.items() if "/" in name} if invalid_features: raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") @@ -447,8 +621,28 @@ def _validate_feature_names(features: dict[str, dict]) -> None: def hw_to_dataset_features( hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True ) -> dict[str, dict]: + """Convert hardware-specific features to a LeRobot dataset feature dictionary. + + This function takes a dictionary describing hardware outputs (like joint states + or camera image shapes) and formats it into the standard LeRobot feature + specification. + + Args: + hw_features (dict): Dictionary mapping feature names to their type (float for + joints) or shape (tuple for images). + prefix (str): The prefix to add to the feature keys (e.g., "observation" + or "action"). + use_video (bool): If True, image features are marked as "video", otherwise "image". + + Returns: + dict: A LeRobot features dictionary. + """ features = {} - joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} + joint_fts = { + key: ftype + for key, ftype in hw_features.items() + if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL) + } cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} if joint_fts and prefix == "action": @@ -479,6 +673,20 @@ def hw_to_dataset_features( def build_dataset_frame( ds_features: dict[str, dict], values: dict[str, Any], prefix: str ) -> dict[str, np.ndarray]: + """Construct a single data frame from raw values based on dataset features. + + A "frame" is a dictionary containing all the data for a single timestep, + formatted as numpy arrays according to the feature specification. + + Args: + ds_features (dict): The LeRobot dataset features dictionary. + values (dict): A dictionary of raw values from the hardware/environment. + prefix (str): The prefix to filter features by (e.g., "observation" + or "action"). + + Returns: + dict: A dictionary representing a single frame of data. + """ frame = {} for key, ft in ds_features.items(): if key in DEFAULT_FEATURES or not key.startswith(prefix): @@ -492,6 +700,21 @@ def build_dataset_frame( def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + """Convert dataset features to policy features. + + This function transforms the dataset's feature specification into a format + that a policy can use, classifying features by type (e.g., visual, state, + action) and ensuring correct shapes (e.g., channel-first for images). + + Args: + features (dict): The LeRobot dataset features dictionary. + + Returns: + dict: A dictionary mapping feature keys to `PolicyFeature` objects. + + Raises: + ValueError: If an image feature does not have a 3D shape. + """ # TODO(aliberts): Implement "type" in dataset features and simplify this policy_features = {} for key, ft in features.items(): @@ -522,6 +745,58 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea return policy_features +def combine_feature_dicts(*dicts: dict) -> dict: + """Merge LeRobot grouped feature dicts. + + - For 1D numeric specs (dtype not image/video/string) with "names": we merge the names and recompute the shape. + - For others (e.g. `observation.images.*`), the last one wins (if they are identical). + + Args: + *dicts: A variable number of LeRobot feature dictionaries to merge. + + Returns: + dict: A single merged feature dictionary. + + Raises: + ValueError: If there's a dtype mismatch for a feature being merged. + """ + out: dict = {} + for d in dicts: + for key, value in d.items(): + if not isinstance(value, dict): + out[key] = value + continue + + dtype = value.get("dtype") + shape = value.get("shape") + is_vector = ( + dtype not in ("image", "video", "string") + and isinstance(shape, tuple) + and len(shape) == 1 + and "names" in value + ) + + if is_vector: + # Initialize or retrieve the accumulating dict for this feature key + target = out.setdefault(key, {"dtype": dtype, "names": [], "shape": (0,)}) + # Ensure consistent data types across merged entries + if "dtype" in target and dtype != target["dtype"]: + raise ValueError(f"dtype mismatch for '{key}': {target['dtype']} vs {dtype}") + + # Merge feature names: append only new ones to preserve order without duplicates + seen = set(target["names"]) + for n in value["names"]: + if n not in seen: + target["names"].append(n) + seen.add(n) + # Recompute the shape to reflect the updated number of features + target["shape"] = (len(target["names"]),) + else: + # For images/videos and non-1D entries: override with the latest definition + out[key] = value + return out + + def create_empty_dataset_info( codebase_version: str, fps: int, @@ -532,6 +807,18 @@ def create_empty_dataset_info( data_files_size_in_mb: int | None = None, video_files_size_in_mb: int | None = None, ) -> dict: + """Create a template dictionary for a new dataset's `info.json`. + + Args: + codebase_version (str): The version of the LeRobot codebase. + fps (int): The frames per second of the data. + features (dict): The LeRobot features dictionary for the dataset. + use_videos (bool): Whether the dataset will store videos. + robot_type (str | None): The type of robot used, if any. + + Returns: + dict: A dictionary with the initial dataset metadata. + """ return { "codebase_version": codebase_version, "robot_type": robot_type, @@ -552,9 +839,23 @@ def create_empty_dataset_info( def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: - """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. - This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be - actual timestamps from the dataset. + """Check if delta timestamps are multiples of 1/fps +/- tolerance. + + This ensures that adding these delta timestamps to any existing timestamp in + the dataset will result in a value that aligns with the dataset's frame rate. + + Args: + delta_timestamps (dict): A dictionary where values are lists of time + deltas in seconds. + fps (int): The frames per second of the dataset. + tolerance_s (float): The allowed tolerance in seconds. + raise_value_error (bool): If True, raises an error on failure. + + Returns: + bool: True if all deltas are valid, False otherwise. + + Raises: + ValueError: If any delta is outside the tolerance and `raise_value_error` is True. """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): @@ -580,6 +881,15 @@ def check_delta_timestamps( def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: + """Convert delta timestamps in seconds to delta indices in frames. + + Args: + delta_timestamps (dict): A dictionary of time deltas in seconds. + fps (int): The frames per second of the dataset. + + Returns: + dict: A dictionary of frame delta indices. + """ delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = [round(d * fps) for d in delta_ts] @@ -588,9 +898,17 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic def cycle(iterable: Any) -> Iterator[Any]: - """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + """Create a dataloader-safe cyclical iterator. - See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. + This is an equivalent of `itertools.cycle` but is safe for use with + PyTorch DataLoaders with multiple workers. + See https://github.com/pytorch/pytorch/issues/23900 for details. + + Args: + iterable: The iterable to cycle over. + + Yields: + Items from the iterable, restarting from the beginning when exhausted. """ iterator = iter(iterable) while True: @@ -601,8 +919,14 @@ def cycle(iterable: Any) -> Iterator[Any]: def create_branch(repo_id: str, *, branch: str, repo_type: str | None = None) -> None: - """Create a branch on a existing Hugging Face repo. Delete the branch if it already - exists before creating it. + """Create a branch on an existing Hugging Face repo. + + Deletes the branch if it already exists before creating it. + + Args: + repo_id (str): The ID of the repository. + branch (str): The name of the branch to create. + repo_type (str | None): The type of the repository (e.g., "dataset"). """ api = HfApi() @@ -620,9 +944,20 @@ def create_lerobot_dataset_card( dataset_info: dict | None = None, **kwargs, ) -> DatasetCard: - """ - Keyword arguments will be used to replace values in src/lerobot/datasets/card_template.md. - Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. + """Create a `DatasetCard` for a LeRobot dataset. + + Keyword arguments are used to replace values in the card template. + Note: If specified, `license` must be a valid license identifier from + https://huggingface.co/docs/hub/repositories-licenses. + + Args: + tags (list | None): A list of tags to add to the dataset card. + dataset_info (dict | None): The dataset's info dictionary, which will + be displayed on the card. + **kwargs: Additional keyword arguments to populate the card template. + + Returns: + DatasetCard: The generated dataset card object. """ card_tags = ["LeRobot"] @@ -675,6 +1010,15 @@ def validate_frame(frame: dict, features: dict) -> None: def validate_features_presence(actual_features: set[str], expected_features: set[str]) -> str: + """Check for missing or extra features in a frame. + + Args: + actual_features (set[str]): The set of feature names present in the frame. + expected_features (set[str]): The set of feature names expected in the frame. + + Returns: + str: An error message string if there's a mismatch, otherwise an empty string. + """ error_message = "" missing_features = expected_features - actual_features extra_features = actual_features - expected_features @@ -692,6 +1036,19 @@ def validate_features_presence(actual_features: set[str], expected_features: set def validate_feature_dtype_and_shape( name: str, feature: dict, value: np.ndarray | PILImage.Image | str ) -> str: + """Validate the dtype and shape of a single feature's value. + + Args: + name (str): The name of the feature. + feature (dict): The feature specification from the LeRobot features dictionary. + value: The value of the feature to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + + Raises: + NotImplementedError: If the feature dtype is not supported for validation. + """ expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): @@ -707,6 +1064,17 @@ def validate_feature_dtype_and_shape( def validate_feature_numpy_array( name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray ) -> str: + """Validate a feature that is expected to be a numpy array. + + Args: + name (str): The name of the feature. + expected_dtype (str): The expected numpy dtype as a string. + expected_shape (list[int]): The expected shape. + value (np.ndarray): The numpy array to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -726,6 +1094,18 @@ def validate_feature_numpy_array( def validate_feature_image_or_video( name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image ) -> str: + """Validate a feature that is expected to be an image or video frame. + + Accepts `np.ndarray` (channel-first or channel-last) or `PIL.Image.Image`. + + Args: + name (str): The name of the feature. + expected_shape (list[str]): The expected shape (C, H, W). + value: The image data to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -742,12 +1122,35 @@ def validate_feature_image_or_video( def validate_feature_string(name: str, value: str) -> str: + """Validate a feature that is expected to be a string. + + Args: + name (str): The name of the feature. + value (str): The value to validate. + + Returns: + str: An error message if validation fails, otherwise an empty string. + """ if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict) -> None: + """Validate the episode buffer before it's written to disk. + + Ensures the buffer has the required keys, contains at least one frame, and + has features consistent with the dataset's specification. + + Args: + episode_buffer (dict): The buffer containing data for a single episode. + total_episodes (int): The current total number of episodes in the dataset. + features (dict): The LeRobot features dictionary for the dataset. + + Raises: + ValueError: If the buffer is invalid. + NotImplementedError: If the episode index is manually set and doesn't match. + """ if "size" not in episode_buffer: raise ValueError("size key not found in episode_buffer") diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 35797c6ed..f71aca70d 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -161,35 +161,73 @@ class XarmEnv(EnvConfig): @dataclass -class VideoRecordConfig: - """Configuration for video recording in ManiSkill environments.""" - - enabled: bool = False - record_dir: str = "videos" - trajectory_name: str = "trajectory" +class ImagePreprocessingConfig: + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None @dataclass -class EnvTransformConfig: - """Configuration for environment wrappers.""" +class RewardClassifierConfig: + """Configuration for reward classification.""" + + pretrained_path: str | None = None + success_threshold: float = 0.5 + success_reward: float = 1.0 + + +@dataclass +class InverseKinematicsConfig: + """Configuration for inverse kinematics processing.""" + + urdf_path: str | None = None + target_frame_name: str | None = None + end_effector_bounds: dict[str, list[float]] | None = None + end_effector_step_sizes: dict[str, float] | None = None + + +@dataclass +class ObservationConfig: + """Configuration for observation processing.""" - # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) - control_mode: str = "gamepad" - display_cameras: bool = False add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False add_ee_pose_to_observation: bool = False - crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None - resize_size: tuple[int, int] | None = None - control_time_s: float = 20.0 - fixed_reset_joint_positions: Any | None = None - reset_time_s: float = 5.0 + display_cameras: bool = False + + +@dataclass +class GripperConfig: + """Configuration for gripper control and penalties.""" + use_gripper: bool = True - gripper_quantization_threshold: float | None = 0.8 gripper_penalty: float = 0.0 gripper_penalty_in_reward: bool = False +@dataclass +class ResetConfig: + """Configuration for environment reset behavior.""" + + fixed_reset_joint_positions: Any | None = None + reset_time_s: float = 5.0 + control_time_s: float = 20.0 + terminate_on_success: bool = True + + +@dataclass +class HILSerlProcessorConfig: + """Configuration for environment processing pipeline.""" + + control_mode: str = "gamepad" + observation: ObservationConfig | None = None + image_preprocessing: ImagePreprocessingConfig | None = None + gripper: GripperConfig | None = None + reset: ResetConfig | None = None + inverse_kinematics: InverseKinematicsConfig | None = None + reward_classifier: RewardClassifierConfig | None = None + max_gripper_pos: float | None = 100.0 + + @EnvConfig.register_subclass(name="gym_manipulator") @dataclass class HILSerlRobotEnvConfig(EnvConfig): @@ -197,77 +235,10 @@ class HILSerlRobotEnvConfig(EnvConfig): robot: RobotConfig | None = None teleop: TeleoperatorConfig | None = None - wrapper: EnvTransformConfig | None = None - fps: int = 10 + processor: HILSerlProcessorConfig = field(default_factory=HILSerlProcessorConfig) + name: str = "real_robot" - mode: str | None = None # Either "record", "replay", None - repo_id: str | None = None - dataset_root: str | None = None - task: str | None = "" - num_episodes: int = 10 # only for record mode - episode: int = 0 - device: str = "cuda" - push_to_hub: bool = True - pretrained_policy_name_or_path: str | None = None - reward_classifier_pretrained_path: str | None = None - # For the reward classifier, to record more positive examples after a success - number_of_steps_after_success: int = 0 @property def gym_kwargs(self) -> dict: return {} - - -@EnvConfig.register_subclass("hil") -@dataclass -class HILEnvConfig(EnvConfig): - """Configuration for the HIL environment.""" - - name: str = "PandaPickCube" - task: str | None = "PandaPickCubeKeyboard-v0" - use_viewer: bool = True - gripper_penalty: float = 0.0 - use_gamepad: bool = True - state_dim: int = 18 - action_dim: int = 4 - fps: int = 100 - episode_length: int = 100 - video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) - features: dict[str, PolicyFeature] = field( - default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)), - } - ) - features_map: dict[str, str] = field( - default_factory=lambda: { - "action": ACTION, - "observation.image": OBS_IMAGE, - "observation.state": OBS_STATE, - } - ) - ################# args from hilserlrobotenv - reward_classifier_pretrained_path: str | None = None - robot_config: RobotConfig | None = None - teleop_config: TeleoperatorConfig | None = None - wrapper: EnvTransformConfig | None = None - mode: str | None = None # Either "record", "replay", None - repo_id: str | None = None - dataset_root: str | None = None - num_episodes: int = 10 # only for record mode - episode: int = 0 - device: str = "cuda" - push_to_hub: bool = True - pretrained_policy_name_or_path: str | None = None - # For the reward classifier, to record more positive examples after a success - number_of_steps_after_success: int = 0 - ############################ - - @property - def gym_kwargs(self) -> dict: - return { - "use_viewer": self.use_viewer, - "use_gamepad": self.use_gamepad, - "gripper_penalty": self.gripper_penalty, - } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index dc6d96d61..af8f5eaf5 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,8 +27,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) - elif env_type == "hil": - return HILEnvConfig(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 00676a011..b4f65ee9c 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -127,9 +127,29 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: """Adds task feature to the observation dict with respect to the first environment attribute.""" if hasattr(env.envs[0], "task_description"): - observation["task"] = env.call("task_description") + task_result = env.call("task_description") + + if isinstance(task_result, tuple): + task_result = list(task_result) + + if not isinstance(task_result, list): + raise TypeError(f"Expected task_description to return a list, got {type(task_result)}") + if not all(isinstance(item, str) for item in task_result): + raise TypeError("All items in task_description result must be strings") + + observation["task"] = task_result elif hasattr(env.envs[0], "task"): - observation["task"] = env.call("task") + task_result = env.call("task") + + if isinstance(task_result, tuple): + task_result = list(task_result) + + if not isinstance(task_result, list): + raise TypeError(f"Expected task to return a list, got {type(task_result)}") + if not all(isinstance(item, str) for item in task_result): + raise TypeError("All items in task result must be strings") + + observation["task"] = task_result else: # For envs without language instructions, e.g. aloha transfer cube and etc. num_envs = observation[list(observation.keys())[0]].shape[0] observation["task"] = ["" for _ in range(num_envs)] diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 9cb0f6234..9b9de9931 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,6 +15,17 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi0.processor_pi0 import Pi0NewLineProcessor from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig +from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig + +__all__ = [ + "ACTConfig", + "DiffusionConfig", + "PI0Config", + "SmolVLAConfig", + "TDMPCConfig", + "VQBeTConfig", +] diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index cfd549b25..e0f3462cc 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,6 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.constants import ACTION, OBS_IMAGES from lerobot.policies.act.configuration_act import ACTConfig -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy @@ -51,27 +50,16 @@ class ACTPolicy(PreTrainedPolicy): def __init__( self, config: ACTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.model = ACT(config) if config.temporal_ensemble_coeff is not None: @@ -137,23 +125,19 @@ class ACTPolicy(PreTrainedPolicy): """Predict a chunk of actions given environment observations.""" self.eval() - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] actions = self.model(batch)[0] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] - batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py new file mode 100644 index 000000000..b0d2067e9 --- /dev/null +++ b/src/lerobot/policies/act/processor_act.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_act_pre_post_processors( + config: ACTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Creates the pre- and post-processing pipelines for the ACT policy. + + The pre-processing pipeline handles normalization, batching, and device placement for the model inputs. + The post-processing pipeline handles unnormalization and moves the model outputs back to the CPU. + + Args: + config (ACTConfig): The ACT policy configuration object. + dataset_stats (dict[str, dict[str, torch.Tensor]] | None): A dictionary containing dataset + statistics (e.g., mean and std) used for normalization. Defaults to None. + + Returns: + tuple[PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction]]: A tuple containing the + pre-processor pipeline and the post-processor pipeline. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + device=config.device, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 85d4d5981..747ead334 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -35,7 +35,6 @@ from torch import Tensor, nn from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import ( get_device_from_parameters, @@ -57,7 +56,6 @@ class DiffusionPolicy(PreTrainedPolicy): def __init__( self, config: DiffusionConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: @@ -70,14 +68,6 @@ class DiffusionPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None @@ -106,9 +96,6 @@ class DiffusionPolicy(PreTrainedPolicy): batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.diffusion.generate_actions(batch) - # TODO(rcadene): make above methods return output dictionary? - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] - return actions @torch.no_grad() @@ -137,7 +124,6 @@ class DiffusionPolicy(PreTrainedPolicy): if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) @@ -153,11 +139,9 @@ class DiffusionPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None return loss, None diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py new file mode 100644 index 000000000..4383ec950 --- /dev/null +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2024 Columbia Artificial Intelligence, Robotics Lab, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_diffusion_pre_post_processors( + config: DiffusionConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for a diffusion policy. + + The pre-processing pipeline prepares the input data for the model by: + 1. Renaming features. + 2. Normalizing the input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving the data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving the data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the diffusion policy, + containing feature definitions, normalization mappings, and device information. + dataset_stats: A dictionary of statistics used for normalization. + Defaults to None. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ef56bdb61..06c0c4ba5 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -14,12 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging +from __future__ import annotations -from torch import nn +import logging +from typing import Any, TypedDict + +import torch +from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig @@ -34,10 +39,32 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.processor.converters import ( + batch_to_transition, + policy_action_to_transition, + transition_to_batch, + transition_to_policy_action, +) -def get_policy_class(name: str) -> PreTrainedPolicy: - """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" +def get_policy_class(name: str) -> type[PreTrainedPolicy]: + """ + Retrieves a policy class by its registered name. + + This function uses dynamic imports to avoid loading all policy classes into memory + at once, improving startup time and reducing dependencies. + + Args: + name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", + "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla". + + Returns: + The policy class corresponding to the given name. + + Raises: + NotImplementedError: If the policy name is not recognized. + """ if name == "tdmpc": from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy @@ -79,6 +106,24 @@ def get_policy_class(name: str) -> PreTrainedPolicy: def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + """ + Instantiates a policy configuration object based on the policy type. + + This factory function simplifies the creation of policy configuration objects by + mapping a string identifier to the corresponding config class. + + Args: + policy_type: The type of the policy. Supported types include "tdmpc", + "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla", + "reward_classifier". + **kwargs: Keyword arguments to be passed to the configuration class constructor. + + Returns: + An instance of a `PreTrainedConfig` subclass. + + Raises: + ValueError: If the `policy_type` is not recognized. + """ if policy_type == "tdmpc": return TDMPCConfig(**kwargs) elif policy_type == "diffusion": @@ -101,30 +146,187 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: raise ValueError(f"Policy type '{policy_type}' is not available.") +class ProcessorConfigKwargs(TypedDict, total=False): + """ + A TypedDict defining the keyword arguments for processor configuration. + + This provides type hints for the optional arguments passed to `make_pre_post_processors`, + improving code clarity and enabling static analysis. + + Attributes: + preprocessor_config_filename: The filename for the preprocessor configuration. + postprocessor_config_filename: The filename for the postprocessor configuration. + preprocessor_overrides: A dictionary of overrides for the preprocessor configuration. + postprocessor_overrides: A dictionary of overrides for the postprocessor configuration. + dataset_stats: Dataset statistics for normalization. + """ + + preprocessor_config_filename: str | None + postprocessor_config_filename: str | None + preprocessor_overrides: dict[str, Any] | None + postprocessor_overrides: dict[str, Any] | None + dataset_stats: dict[str, dict[str, torch.Tensor]] | None + + +def make_pre_post_processors( + policy_cfg: PreTrainedConfig, + pretrained_path: str | None = None, + **kwargs: Unpack[ProcessorConfigKwargs], +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Create or load pre- and post-processor pipelines for a given policy. + + This function acts as a factory. It can either load existing processor pipelines + from a pretrained path or create new ones from scratch based on the policy + configuration. Each policy type has a dedicated factory function for its + processors (e.g., `make_tdmpc_pre_post_processors`). + + Args: + policy_cfg: The configuration of the policy for which to create processors. + pretrained_path: An optional path to load pretrained processor pipelines from. + If provided, pipelines are loaded from this path. + **kwargs: Keyword arguments for processor configuration, as defined in + `ProcessorConfigKwargs`. + + Returns: + A tuple containing the input (pre-processor) and output (post-processor) pipelines. + + Raises: + NotImplementedError: If a processor factory is not implemented for the given + policy configuration type. + """ + if pretrained_path: + return ( + PolicyProcessorPipeline.from_pretrained( + pretrained_model_name_or_path=pretrained_path, + config_filename=kwargs.get( + "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" + ), + overrides=kwargs.get("preprocessor_overrides", {}), + to_transition=batch_to_transition, + to_output=transition_to_batch, + ), + PolicyProcessorPipeline.from_pretrained( + pretrained_model_name_or_path=pretrained_path, + config_filename=kwargs.get( + "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" + ), + overrides=kwargs.get("postprocessor_overrides", {}), + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + # Create a new processor based on policy type + if isinstance(policy_cfg, TDMPCConfig): + from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors + + processors = make_tdmpc_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, DiffusionConfig): + from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors + + processors = make_diffusion_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, ACTConfig): + from lerobot.policies.act.processor_act import make_act_pre_post_processors + + processors = make_act_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, VQBeTConfig): + from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors + + processors = make_vqbet_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, PI0Config): + from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors + + processors = make_pi0_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, PI0FASTConfig): + from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors + + processors = make_pi0fast_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, SACConfig): + from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors + + processors = make_sac_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, RewardClassifierConfig): + from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor + + processors = make_classifier_processor( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, SmolVLAConfig): + from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors + + processors = make_smolvla_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + else: + raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") + + return processors + + def make_policy( cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata | None = None, env_cfg: EnvConfig | None = None, ) -> PreTrainedPolicy: - """Make an instance of a policy class. + """ + Instantiate a policy model. - This function exists because (for now) we need to parse features from either a dataset or an environment - in order to properly dimension and instantiate a policy for that dataset or environment. + This factory function handles the logic of creating a policy, which requires + determining the input and output feature shapes. These shapes can be derived + either from a `LeRobotDatasetMetadata` object or an `EnvConfig` object. The function + can either initialize a new policy from scratch or load a pretrained one. Args: - cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will - be loaded with the weights from that path. - ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and - statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. - env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be - provided if ds_meta is not. Defaults to None. - - Raises: - ValueError: Either ds_meta or env and env_cfg must be provided. - NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility) + cfg: The configuration for the policy to be created. If `cfg.pretrained_path` is + set, the policy will be loaded with weights from that path. + ds_meta: Dataset metadata used to infer feature shapes and types. Also provides + statistics for normalization layers. + env_cfg: Environment configuration used to infer feature shapes and types. + One of `ds_meta` or `env_cfg` must be provided. Returns: - PreTrainedPolicy: _description_ + An instantiated and device-placed policy model. + + Raises: + ValueError: If both or neither of `ds_meta` and `env_cfg` are provided. + NotImplementedError: If attempting to use an unsupported policy-backend + combination (e.g., VQBeT with 'mps'). """ if bool(ds_meta) == bool(env_cfg): raise ValueError("Either one of a dataset metadata or a sim env must be provided.") @@ -147,7 +349,6 @@ def make_policy( kwargs = {} if ds_meta is not None: features = dataset_to_policy_features(ds_meta.features) - kwargs["dataset_stats"] = ds_meta.stats else: if not cfg.pretrained_path: logging.warning( @@ -155,6 +356,8 @@ def make_policy( "rather than a dataset. Normalization modules inside the policy will have infinite values " "by default without stats from a dataset." ) + if env_cfg is None: + raise ValueError("env_cfg cannot be None when ds_meta is not provided") features = env_to_policy_features(env_cfg) cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} @@ -171,7 +374,7 @@ def make_policy( policy = policy_cls(**kwargs) policy.to(cfg.device) - assert isinstance(policy, nn.Module) + assert isinstance(policy, torch.nn.Module) # policy = torch.compile(policy, mode="reduce-overhead") diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py deleted file mode 100644 index 119055873..000000000 --- a/src/lerobot/policies/normalize.py +++ /dev/null @@ -1,420 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import torch -from torch import Tensor, nn - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - - -def create_stats_buffers( - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> dict[str, dict[str, nn.ParameterDict]]: - """ - Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max - statistics. - - Args: (see Normalize and Unnormalize) - - Returns: - dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing - `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation. - """ - stats_buffers = {} - - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - assert isinstance(norm_mode, NormalizationMode) - - shape = tuple(ft.shape) - - if ft.type is FeatureType.VISUAL: - # sanity checks - assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}" - c, h, w = shape - assert c < h and c < w, f"{key} is not channel first ({shape=})" - # override image shape to be invariant to height and width - shape = (c, 1, 1) - - # Note: we initialize mean, std, min, max to infinity. They should be overwritten - # downstream by `stats` or `policy.load_state_dict`, as expected. During forward, - # we assert they are not infinity anymore. - - buffer = {} - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.ones(shape, dtype=torch.float32) * torch.inf - std = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "mean": nn.Parameter(mean, requires_grad=False), - "std": nn.Parameter(std, requires_grad=False), - } - ) - elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.ones(shape, dtype=torch.float32) * torch.inf - max = torch.ones(shape, dtype=torch.float32) * torch.inf - buffer = nn.ParameterDict( - { - "min": nn.Parameter(min, requires_grad=False), - "max": nn.Parameter(max, requires_grad=False), - } - ) - - # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) - if stats: - if isinstance(stats[key]["mean"], np.ndarray): - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) - elif isinstance(stats[key]["mean"], torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) - buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) - else: - type_ = type(stats[key]["mean"]) - raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") - - stats_buffers[key] = buffer - return stats_buffers - - -def _no_stats_error_str(name: str) -> str: - return ( - f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a " - "pretrained model." - ) - - -class Normalize(nn.Module): - """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # TODO: Remove this shallow copy - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - # FIXME(aliberts, rcadene): This might lead to silent fail! - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - # normalize to [0,1] - batch[key] = (batch[key] - min) / (max - min + 1e-8) - # normalize to [-1, 1] - batch[key] = batch[key] * 2 - 1 - else: - raise ValueError(norm_mode) - return batch - - -class Unnormalize(nn.Module): - """ - Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their - original range used by the environment. - """ - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - shapes (dict): A dictionary where keys are input modalities (e.g. "observation.image") and values - are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing - mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape - is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format. - modes (dict): A dictionary where keys are output modalities (e.g. "observation.image") and values - are their normalization modes among: - - "mean_std": subtract the mean and divide by standard deviation. - - "min_max": map to [-1, 1] range. - stats (dict, optional): A dictionary where keys are output modalities (e.g. "observation.image") - and values are dictionaries of statistic types and their values (e.g. - `{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for - training the model for the first time, these statistics will overwrite the default buffers. If - not provided, as expected for finetuning or evaluation, the default buffers should to be - overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the - dataset is not needed to get the stats, since they are already in the policy state_dict. - """ - super().__init__() - self.features = features - self.norm_map = norm_map - self.stats = stats - # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - stats_buffers = create_stats_buffers(features, norm_map, stats) - for key, buffer in stats_buffers.items(): - setattr(self, "buffer_" + key.replace(".", "_"), buffer) - - # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) # shallow copy avoids mutating the input batch - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - buffer = getattr(self, "buffer_" + key.replace(".", "_")) - - if norm_mode is NormalizationMode.MEAN_STD: - mean = buffer["mean"] - std = buffer["std"] - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - elif norm_mode is NormalizationMode.MIN_MAX: - min = buffer["min"] - max = buffer["max"] - assert not torch.isinf(min).any(), _no_stats_error_str("min") - assert not torch.isinf(max).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max - min) + min - else: - raise ValueError(norm_mode) - return batch - - -# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization -# and remove the `Normalize` and `Unnormalize` classes. -def _initialize_stats_buffers( - module: nn.Module, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, -) -> None: - """Register statistics buffers (mean/std or min/max) on the given *module*. - - The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, - but is factored out so it can be reused by both classes and stay in sync. - """ - for key, ft in features.items(): - norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - shape: tuple[int, ...] = tuple(ft.shape) - if ft.type is FeatureType.VISUAL: - # reduce spatial dimensions, keep channel dimension only - c, *_ = shape - shape = (c, 1, 1) - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.full(shape, torch.inf, dtype=torch.float32) - std = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: - mean_data = stats[key]["mean"] - std_data = stats[key]["std"] - if isinstance(mean_data, torch.Tensor): - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - mean = mean_data.clone().to(dtype=torch.float32) - std = std_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_mean", mean) - module.register_buffer(f"{prefix}_std", std) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = torch.full(shape, torch.inf, dtype=torch.float32) - max_val = torch.full(shape, torch.inf, dtype=torch.float32) - - if stats and key in stats and "min" in stats[key] and "max" in stats[key]: - min_data = stats[key]["min"] - max_data = stats[key]["max"] - if isinstance(min_data, torch.Tensor): - min_val = min_data.clone().to(dtype=torch.float32) - max_val = max_data.clone().to(dtype=torch.float32) - else: - raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") - - module.register_buffer(f"{prefix}_min", min_val) - module.register_buffer(f"{prefix}_max", max_val) - continue - - raise ValueError(norm_mode) - - -class NormalizeBuffer(nn.Module): - """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = (batch[key] - mean) / (std + 1e-8) - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) - batch[key] = batch[key] * 2 - 1 - continue - - raise ValueError(norm_mode) - - return batch - - -class UnnormalizeBuffer(nn.Module): - """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics.""" - - def __init__( - self, - features: dict[str, PolicyFeature], - norm_map: dict[str, NormalizationMode], - stats: dict[str, dict[str, Tensor]] | None = None, - ): - super().__init__() - self.features = features - self.norm_map = norm_map - - _initialize_stats_buffers(self, features, norm_map, stats) - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - # batch = dict(batch) - for key, ft in self.features.items(): - if key not in batch: - continue - - norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) - if norm_mode is NormalizationMode.IDENTITY: - continue - - prefix = key.replace(".", "_") - - if norm_mode is NormalizationMode.MEAN_STD: - mean = getattr(self, f"{prefix}_mean") - std = getattr(self, f"{prefix}_std") - assert not torch.isinf(mean).any(), _no_stats_error_str("mean") - assert not torch.isinf(std).any(), _no_stats_error_str("std") - batch[key] = batch[key] * std + mean - continue - - if norm_mode is NormalizationMode.MIN_MAX: - min_val = getattr(self, f"{prefix}_min") - max_val = getattr(self, f"{prefix}_max") - assert not torch.isinf(min_val).any(), _no_stats_error_str("min") - assert not torch.isinf(max_val).any(), _no_stats_error_str("max") - batch[key] = (batch[key] + 1) / 2 - batch[key] = batch[key] * (max_val - min_val) + min_val - continue - - raise ValueError(norm_mode) - - return batch diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index de41e2bd4..66bd81e61 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -56,18 +56,15 @@ from collections import deque import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoTokenizer -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertConfig, PaliGemmaWithExpertModel, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import log_model_loading_keys -from lerobot.utils.utils import get_safe_dtype, init_logging +from lerobot.utils.utils import get_safe_dtype def create_sinusoidal_pos_embedding( @@ -223,28 +220,17 @@ class PI0Policy(PreTrainedPolicy): def __init__( self, config: PI0Config, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FlowMatching(config) self.reset() @@ -253,99 +239,6 @@ class PI0Policy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) - @classmethod - def _transform_state_dict_keys(cls, state_dict: dict) -> dict: - """ - Transform state dict keys to match expected model structure. - - Transformations: - - model.paligemma_with_expert.paligemma.language_model.lm_head -> - model.paligemma_with_expert.paligemma.lm_head - - model.paligemma_with_expert.paligemma.language_model.model -> - model.paligemma_with_expert.paligemma.model.language_model - - model.paligemma_with_expert.paligemma.vision_tower -> - model.paligemma_with_expert.paligemma.model.vision_tower - - model.paligemma_with_expert.paligemma.multi_modal_projector -> - model.paligemma_with_expert.paligemma.model.multi_modal_projector - - Also handles tied weights between lm_head.weight and - embed_tokens.weight. - """ - import re - - transformed_dict = {} - - transformations = [ - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"), - ".paligemma_with_expert.paligemma.lm_head", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"), - ".paligemma_with_expert.paligemma.model.language_model", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"), - ".paligemma_with_expert.paligemma.model.vision_tower", - ), - ( - re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"), - ".paligemma_with_expert.paligemma.model.multi_modal_projector", - ), - ] - - for key, value in state_dict.items(): - new_key = key - for pattern, replacement in transformations: - new_key = pattern.sub(replacement, new_key) - transformed_dict[new_key] = value - - # Handle tied weights: lm_head.weight and embed_tokens.weight share memory - lm_head_key = None - embed_tokens_key = None - - for key in transformed_dict: - if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"): - lm_head_key = key - elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"): - embed_tokens_key = key - if lm_head_key and embed_tokens_key: - break - - if lm_head_key and not embed_tokens_key: - embed_tokens_key = lm_head_key.replace( - ".lm_head.weight", ".model.language_model.embed_tokens.weight" - ) - transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key] - elif embed_tokens_key and not lm_head_key: - lm_head_key = embed_tokens_key.replace( - ".model.language_model.embed_tokens.weight", ".lm_head.weight" - ) - transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key] - - return transformed_dict - - @classmethod - def _load_as_safetensor( - cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool - ) -> "PI0Policy": - """Override to apply key transformations before loading.""" - from safetensors.torch import load_file - - init_logging() - # Load the state dict from file safely - state_dict = load_file(model_file, device=map_location) - - # Apply key transformations - transformed_state_dict = cls._transform_state_dict_keys(state_dict) - - # Load the transformed state dict - msg = model.load_state_dict(transformed_state_dict, strict=strict) - - # Log message - log_model_loading_keys(msg.missing_keys, msg.unexpected_keys) - return model - def get_optim_params(self) -> dict: return self.parameters() @@ -377,14 +270,13 @@ class PI0Policy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.model.sample_actions( images, img_masks, lang_tokens, lang_masks, state, noise=noise @@ -394,8 +286,6 @@ class PI0Policy(PreTrainedPolicy): original_action_dim = self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({"action": actions})["action"] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -410,12 +300,10 @@ class PI0Policy(PreTrainedPolicy): batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) actions_is_pad = batch.get("action_is_pad") @@ -482,26 +370,6 @@ class PI0Policy(PreTrainedPolicy): return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - - # PaliGemma prompt has to end with a new line - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding="max_length", - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: @@ -567,7 +435,7 @@ class PI0FlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: PI0Config): super().__init__() self.config = config diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py new file mode 100644 index 000000000..cd9712201 --- /dev/null +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +@ProcessorStepRegistry.register(name="pi0_new_line_processor") +class Pi0NewLineProcessor(ComplementaryDataProcessorStep): + """ + Ensures that the task description string ends with a newline character. + + This processing step is required for compatibility with the PaliGemma tokenizer, + which expects a newline at the end of the text prompt. It handles both single + strings and lists of strings for the 'task' key in complementary data. + """ + + def complementary_data(self, complementary_data): + """ + Adds a newline to the 'task' field if it doesn't already have one. + + Args: + complementary_data: A dictionary that may contain a 'task' key with a + string or list of strings. + + Returns: + A new dictionary with the modified 'task' field. + """ + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + + Args: + features: The input feature dictionary. + + Returns: + The unchanged feature dictionary. + """ + return features + + +def make_pi0_pre_post_processors( + config: PI0Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 88727b581..682a372f4 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -58,7 +58,6 @@ from transformers.cache_utils import HybridCache, StaticCache from transformers.models.auto import CONFIG_MAPPING from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy @@ -146,14 +145,6 @@ class PI0FASTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") self.model = PI0FAST(config) @@ -221,8 +212,6 @@ class PI0FASTPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: @@ -235,8 +224,6 @@ class PI0FASTPolicy(PreTrainedPolicy): ] # self.config.max_action_dim # self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({"action": actions})["action"] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -249,8 +236,6 @@ class PI0FASTPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) loss_dict = self.model.forward(batch) return loss_dict["loss"], loss_dict diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py new file mode 100644 index 000000000..81314aa37 --- /dev/null +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_pi0fast_pre_post_processors( + config: PI0FASTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0Fast policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0Fast policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index 878f3cdd8..fcaf02a4b 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -28,7 +28,6 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution -from lerobot.policies.normalize import NormalizeBuffer from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters @@ -45,7 +44,6 @@ class SACPolicy( def __init__( self, config: SACConfig | None = None, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): super().__init__(config) config.validate_features() @@ -53,7 +51,6 @@ class SACPolicy( # Determine action dimension and initialize all components continuous_action_dim = config.output_features["action"].shape[0] - self._init_normalization(dataset_stats) self._init_encoders() self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) @@ -88,8 +85,7 @@ class SACPolicy( observations_features = None if self.shared_encoder and self.actor.encoder.has_images: - # Cache and normalize image features - observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True) + observations_features = self.actor.encoder.get_cached_image_features(batch) actions, _, _ = self.actor(batch, observations_features) @@ -391,28 +387,12 @@ class SACPolicy( actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() return actor_loss - def _init_normalization(self, dataset_stats): - """Initialize input/output normalization modules.""" - self.normalize_inputs = nn.Identity() - self.normalize_targets = nn.Identity() - if self.config.dataset_stats is not None: - params = _convert_normalization_params_to_tensor(self.config.dataset_stats) - self.normalize_inputs = NormalizeBuffer( - self.config.input_features, self.config.normalization_mapping, params - ) - stats = dataset_stats or params - self.normalize_targets = NormalizeBuffer( - self.config.output_features, self.config.normalization_mapping, stats - ) - def _init_encoders(self): """Initialize shared or separate encoders for actor and critic.""" self.shared_encoder = self.config.shared_encoder - self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_critic = SACObservationEncoder(self.config) self.encoder_actor = ( - self.encoder_critic - if self.shared_encoder - else SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config) ) def _init_critics(self, continuous_action_dim): @@ -424,9 +404,7 @@ class SACPolicy( ) for _ in range(self.config.num_critics) ] - self.critic_ensemble = CriticEnsemble( - encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets - ) + self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads) target_heads = [ CriticHead( input_dim=self.encoder_critic.output_dim + continuous_action_dim, @@ -434,9 +412,7 @@ class SACPolicy( ) for _ in range(self.config.num_critics) ] - self.critic_target = CriticEnsemble( - encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets - ) + self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads) self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) if self.config.use_torch_compile: @@ -490,10 +466,9 @@ class SACPolicy( class SACObservationEncoder(nn.Module): """Encode image and/or state vector observations.""" - def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: + def __init__(self, config: SACConfig) -> None: super().__init__() self.config = config - self.input_normalization = input_normalizer self._init_image_layers() self._init_state_layers() self._compute_output_dim() @@ -568,11 +543,10 @@ class SACObservationEncoder(nn.Module): def forward( self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False ) -> Tensor: - obs = self.input_normalization(obs) parts = [] if self.has_images: if cache is None: - cache = self.get_cached_image_features(obs, normalize=False) + cache = self.get_cached_image_features(obs) parts.append(self._encode_images(cache, detach)) if self.has_env: parts.append(self.env_encoder(obs["observation.environment_state"])) @@ -585,7 +559,7 @@ class SACObservationEncoder(nn.Module): "No parts to concatenate, you should have at least one image or environment state or state" ) - def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]: + def get_cached_image_features(self, obs: dict[str, Tensor]) -> dict[str, Tensor]: """Extract and optionally cache image features from observations. This function processes image observations through the vision encoder once and returns @@ -597,26 +571,17 @@ class SACObservationEncoder(nn.Module): - The vision encoder forward pass is typically the main computational bottleneck during training and inference - Caching these features can provide 2-4x speedup in training and inference - Normalization behavior: - - When called from inside forward(): set normalize=False since inputs are already normalized - - When called from outside forward(): set normalize=True to ensure proper input normalization - Usage patterns: - - Called in select_action() with normalize=True + - Called in select_action() - Called in learner.py's get_observation_features() to pre-compute features for all policy components - - Called internally by forward() with normalize=False + - Called internally by forward() Args: obs: Dictionary of observation tensors containing image keys - normalize: Whether to normalize observations before encoding - Set to True when calling directly from outside the encoder's forward method - Set to False when calling from within forward() where inputs are already normalized Returns: Dictionary mapping image keys to their corresponding encoded features """ - if normalize: - obs = self.input_normalization(obs) batched = torch.cat([obs[k] for k in self.image_keys], dim=0) out = self.image_encoder(batched) chunks = torch.chunk(out, len(self.image_keys), dim=0) @@ -747,7 +712,6 @@ class CriticEnsemble(nn.Module): Args: encoder (SACObservationEncoder): encoder for observations. ensemble (List[CriticHead]): list of critic heads. - output_normalization (nn.Module): normalization layer for actions. init_final (float | None): optional initializer scale for final layers. Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. @@ -757,13 +721,11 @@ class CriticEnsemble(nn.Module): self, encoder: SACObservationEncoder, ensemble: list[CriticHead], - output_normalization: nn.Module, init_final: float | None = None, ): super().__init__() self.encoder = encoder self.init_final = init_final - self.output_normalization = output_normalization self.critics = nn.ModuleList(ensemble) def forward( @@ -775,11 +737,6 @@ class CriticEnsemble(nn.Module): device = get_device_from_parameters(self) # Move each tensor in observations to device observations = {k: v.to(device) for k, v in observations.items()} - # NOTE: We normalize actions it helps for sample efficiency - actions: dict[str, torch.tensor] = {"action": actions} - # NOTE: Normalization layer took dict in input and outputs a dict that why - actions = self.output_normalization(actions)["action"] - actions = actions.to(device) obs_enc = self.encoder(observations, cache=observation_features) diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py new file mode 100644 index 000000000..9e8013d31 --- /dev/null +++ b/src/lerobot/policies/sac/processor_sac.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_sac_pre_post_processors( + config: SACConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the SAC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the SAC policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py index cadd1c9f2..ca501c3a7 100644 --- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py +++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py @@ -20,7 +20,6 @@ import torch from torch import Tensor, nn from lerobot.constants import OBS_IMAGE, REWARD -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -108,22 +107,12 @@ class Classifier(PreTrainedPolicy): def __init__( self, config: RewardClassifierConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): from transformers import AutoModel super().__init__(config) self.config = config - # Initialize normalization (standardized with the policy framework) - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - # Set up encoder encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) # Extract vision model if we're given a multimodal model @@ -247,10 +236,6 @@ class Classifier(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: """Standard forward pass for training compatible with train.py.""" - # Normalize inputs if needed - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - # Extract images and labels images, labels = self.extract_images_and_labels(batch) diff --git a/src/lerobot/policies/sac/reward_model/processor_classifier.py b/src/lerobot/policies/sac/reward_model/processor_classifier.py new file mode 100644 index 000000000..c2a34eab2 --- /dev/null +++ b/src/lerobot/policies/sac/reward_model/processor_classifier.py @@ -0,0 +1,82 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.processor import ( + DeviceProcessorStep, + IdentityProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_classifier_processor( + config: RewardClassifierConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the reward classifier. + + The pre-processing pipeline prepares input data for the classifier by: + 1. Normalizing both input and output features based on dataset statistics. + 2. Moving the data to the specified device. + + The post-processing pipeline handles the classifier's output by: + 1. Moving the data to the CPU. + 2. Applying an identity step, as no unnormalization is needed for the output logits. + + Args: + config: The configuration object for the RewardClassifier. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + NormalizerProcessorStep( + features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + NormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device=config.device), + ] + output_steps = [DeviceProcessorStep(device="cpu"), IdentityProcessorStep()] + + return ( + PolicyProcessorPipeline( + steps=input_steps, + name="classifier_preprocessor", + ), + PolicyProcessorPipeline( + steps=output_steps, + name="classifier_postprocessor", + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 18f2fc58a..48d4b2315 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -53,21 +53,13 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") """ import math -import os -import re from collections import deque -import safetensors import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from transformers import AutoProcessor -from lerobot.constants import ACTION, OBS_STATE -from lerobot.policies.normalize import ( - Normalize, - Unnormalize, -) +from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel @@ -76,102 +68,6 @@ from lerobot.policies.utils import ( ) from lerobot.utils.utils import get_safe_dtype -# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker -_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") - - -def canonicalise(k: str) -> str: - """ - Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a - normalisation-buffer key. - """ - return _VARIANT_RE.sub(".buffer_", k) - - -def standardise_state_dict( - checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True -) -> tuple[dict[str, torch.Tensor], list[str]]: - """ - • Re-keys `checkpoint ` so that every entry matches the *reference* key set. - • If several variant keys collapse to the same canonical name we keep the - first one and log the collision. - • Returns the new dict + a list of entries that could not be matched. - """ - out, collisions, unmatched = {}, {}, [] - - for k, v in checkpoint.items(): - canon = canonicalise(k) - if canon in ref_keys: - if canon in out: # duplicate after collapsing - collisions.setdefault(canon, []).append(k) - else: - out[canon] = v - else: - unmatched.append(k) - - if verbose: - for canon, variants in collisions.items(): - print(f"[standardise_state_dict] '{canon}' ← {variants}") - if unmatched: - print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") - - out.update({k: checkpoint[k] for k in unmatched}) - return out, unmatched - - -def rename_checkpoint_keys(checkpoint: dict, rename_str: str): - """ - Renames keys in a checkpoint dictionary based on the given rename string. - - Args: - checkpoint (dict): The checkpoint dictionary. - rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". - - Returns: - dict: The modified checkpoint with renamed keys. - """ - - rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) - - new_checkpoint = {} - for k, v in checkpoint.items(): - for old_key, new_key in rename_dict.items(): - if old_key in k: - k = k.replace(old_key, new_key) - new_checkpoint[k] = v - return new_checkpoint - - -def load_smolvla( - model: torch.nn.Module, - filename: str | os.PathLike, - *, - device: str = "cpu", - checkpoint_keys_mapping: str = "", -) -> torch.nn.Module: - state_dict = safetensors.torch.load_file(filename, device=device) - - # Optional user-supplied renames (e.g. "model._orig_mod.//model.") - if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: - state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) - - state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) - - # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset - norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") - state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} - - missing, unexpected = model.load_state_dict(state_dict, strict=False) - - if not all(key.startswith(norm_keys) for key in missing) or unexpected: - raise RuntimeError( - "SmolVLA %d missing / %d unexpected keys", - len(missing), - len(unexpected), - ) - - return model - def create_sinusoidal_pos_embedding( time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" @@ -326,28 +222,17 @@ class SmolVLAPolicy(PreTrainedPolicy): def __init__( self, config: SmolVLAConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer self.model = VLAFlowMatching(config) self.reset() @@ -357,23 +242,6 @@ class SmolVLAPolicy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } - # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues - @classmethod - def _load_as_safetensor( - cls, - model: "SmolVLAPolicy", - model_file: str, - map_location: str, - strict: bool, - ): - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) - return load_smolvla( - model, - model_file, - device=map_location, - checkpoint_keys_mapping="model._orig_mod.//model.", - ) - def get_optim_params(self) -> dict: return self.parameters() @@ -389,7 +257,8 @@ class SmolVLAPolicy(PreTrainedPolicy): images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) @@ -397,8 +266,6 @@ class SmolVLAPolicy(PreTrainedPolicy): original_action_dim = self.config.action_feature.shape[0] actions = actions[:, :, :original_action_dim] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] - if self.config.adapt_to_pi_aloha: actions = self._pi_aloha_encode_actions(actions) @@ -408,8 +275,6 @@ class SmolVLAPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch = self.normalize_inputs(batch) - return batch @torch.no_grad() @@ -450,11 +315,11 @@ class SmolVLAPolicy(PreTrainedPolicy): if self.config.adapt_to_pi_aloha: batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) + images, img_masks = self.prepare_images(batch) state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] actions = self.prepare_action(batch) actions_is_pad = batch.get("actions_id_pad") loss_dict = {} @@ -518,30 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy): img_masks.append(mask) return images, img_masks - def prepare_language(self, batch) -> tuple[Tensor, Tensor]: - """Tokenize the text input""" - device = batch[OBS_STATE].device - tasks = batch["task"] - if isinstance(tasks, str): - tasks = [tasks] - - if len(tasks) == 1: - tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] - - tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] - - tokenized_prompt = self.language_tokenizer.__call__( - tasks, - padding=self.config.pad_language_to, - padding_side="right", - max_length=self.config.tokenizer_max_length, - return_tensors="pt", - ) - lang_tokens = tokenized_prompt["input_ids"].to(device=device) - lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) - - return lang_tokens, lang_masks - def _pi_aloha_decode_state(self, state): # Flip the joints. for motor_idx in [1, 2, 8, 9]: diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py new file mode 100644 index 000000000..ac3cd4626 --- /dev/null +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + ComplementaryDataProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_smolvla_pre_post_processors( + config: SmolVLAConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the SmolVLA policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Ensuring the language task description ends with a newline character. + 5. Tokenizing the language task description. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output actions to their original scale. + + Args: + config: The configuration object for the SmolVLA policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + SmolVLANewLineProcessor(), + TokenizerProcessorStep( + tokenizer_name=config.vlm_model_name, + padding=config.pad_language_to, + padding_side="right", + max_length=config.tokenizer_max_length, + ), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) + + +@ProcessorStepRegistry.register(name="smolvla_new_line_processor") +class SmolVLANewLineProcessor(ComplementaryDataProcessorStep): + """ + A processor step that ensures the 'task' description ends with a newline character. + + This step is necessary for certain tokenizers (e.g., PaliGemma) that expect a + newline at the end of the prompt. It handles both single string tasks and lists + of string tasks. + """ + + def complementary_data(self, complementary_data): + if "task" not in complementary_data: + return complementary_data + + task = complementary_data["task"] + if task is None: + return complementary_data + + new_complementary_data = dict(complementary_data) + + # Handle both string and list of strings + if isinstance(task, str): + # Single string: add newline if not present + if not task.endswith("\n"): + new_complementary_data["task"] = f"{task}\n" + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + # List of strings: add newline to each if not present + new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task] + # If task is neither string nor list of strings, leave unchanged + + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 7ba88e5e6..e160310b3 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -36,7 +36,6 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues @@ -63,26 +62,19 @@ class TDMPCPolicy(PreTrainedPolicy): config_class = TDMPCConfig name = "tdmpc" - def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + def __init__( + self, + config: TDMPCConfig, + ): """ Args: config: Policy configuration class instance or None, in which case the default instantiation of the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. """ super().__init__(config) config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.model = TDMPCTOLD(config) self.model_target = deepcopy(self.model) for param in self.model_target.parameters(): @@ -137,7 +129,6 @@ class TDMPCPolicy(PreTrainedPolicy): actions = torch.clamp(actions, -1, +1) - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @torch.no_grad() @@ -147,11 +138,12 @@ class TDMPCPolicy(PreTrainedPolicy): if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) - if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) self._queues = populate_queues(self._queues, batch) @@ -320,11 +312,9 @@ class TDMPCPolicy(PreTrainedPolicy): """ device = get_device_from_parameters(self) - batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] - batch = self.normalize_targets(batch) info = {} diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py new file mode 100644 index 000000000..75a7d4f7e --- /dev/null +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_tdmpc_pre_post_processors( + config: TDMPCConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the TDMPC policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the TDMPC policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index feb65bb4c..bb6040e90 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -28,7 +28,6 @@ import torchvision from torch import Tensor, nn from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -48,7 +47,6 @@ class VQBeTPolicy(PreTrainedPolicy): def __init__( self, config: VQBeTConfig | None = None, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: @@ -61,14 +59,6 @@ class VQBeTPolicy(PreTrainedPolicy): config.validate_features() self.config = config - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.vqbet = VQBeTModel(config) self.reset() @@ -128,7 +118,6 @@ class VQBeTPolicy(PreTrainedPolicy): def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] - actions = self.unnormalize_outputs({ACTION: actions})[ACTION] return actions @torch.no_grad() @@ -142,10 +131,12 @@ class VQBeTPolicy(PreTrainedPolicy): # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out if ACTION in batch: batch.pop(ACTION) - batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original # NOTE: It's important that this happens after stacking the images into a single key. batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) self._queues = populate_queues(self._queues, batch) @@ -165,10 +156,8 @@ class VQBeTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" - batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): # loss: total loss of training RVQ diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py new file mode 100644 index 000000000..1c741cd33 --- /dev/null +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python + +# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru +# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import torch + +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + RenameObservationsProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action + + +def make_vqbet_pre_post_processors( + config: VQBeTConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the VQ-BeT policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features, allowing customization to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the VQ-BeT policy. + dataset_stats: A dictionary of statistics for normalization. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + input_steps = [ + RenameObservationsProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=config.device), + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + ] + output_steps = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 8dd244c27..be11ac1af 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -14,41 +14,120 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .device_processor import DeviceProcessor -from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor -from .observation_processor import VanillaObservationProcessor -from .pipeline import ( - ActionProcessor, - DoneProcessor, +from .batch_processor import AddBatchDimensionProcessorStep +from .converters import ( + batch_to_transition, + create_transition, + transition_to_batch, +) +from .core import ( + EnvAction, EnvTransition, - IdentityProcessor, - InfoProcessor, - ObservationProcessor, + PolicyAction, + RobotAction, + RobotObservation, + TransitionKey, +) +from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep +from .device_processor import DeviceProcessorStep +from .factory import ( + make_default_processors, + make_default_robot_action_processor, + make_default_robot_observation_processor, + make_default_teleop_action_processor, +) +from .gym_action_processor import ( + Numpy2TorchActionProcessorStep, + Torch2NumpyActionProcessorStep, +) +from .hil_processor import ( + AddTeleopActionAsComplimentaryDataStep, + AddTeleopEventsAsInfoStep, + GripperPenaltyProcessorStep, + ImageCropResizeProcessorStep, + InterventionActionProcessorStep, + RewardClassifierProcessorStep, + TimeLimitProcessorStep, +) +from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep +from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats +from .observation_processor import VanillaObservationProcessorStep +from .pipeline import ( + ActionProcessorStep, + ComplementaryDataProcessorStep, + DataProcessorPipeline, + DoneProcessorStep, + IdentityProcessorStep, + InfoProcessorStep, + ObservationProcessorStep, + PolicyActionProcessorStep, + PolicyProcessorPipeline, + ProcessorKwargs, ProcessorStep, ProcessorStepRegistry, - RewardProcessor, - RobotProcessor, - TransitionKey, - TruncatedProcessor, + RewardProcessorStep, + RobotActionProcessorStep, + RobotProcessorPipeline, + TruncatedProcessorStep, ) -from .rename_processor import RenameProcessor +from .policy_robot_bridge import ( + PolicyActionToRobotActionProcessorStep, + RobotActionToPolicyActionProcessorStep, +) +from .rename_processor import RenameObservationsProcessorStep +from .tokenizer_processor import TokenizerProcessorStep __all__ = [ - "ActionProcessor", - "DeviceProcessor", - "DoneProcessor", + "ActionProcessorStep", + "AddTeleopActionAsComplimentaryDataStep", + "AddTeleopEventsAsInfoStep", + "ComplementaryDataProcessorStep", + "batch_to_transition", + "create_transition", + "DeviceProcessorStep", + "DoneProcessorStep", + "EnvAction", "EnvTransition", - "IdentityProcessor", - "InfoProcessor", - "NormalizerProcessor", - "UnnormalizerProcessor", - "ObservationProcessor", + "GripperPenaltyProcessorStep", + "hotswap_stats", + "IdentityProcessorStep", + "ImageCropResizeProcessorStep", + "InfoProcessorStep", + "InterventionActionProcessorStep", + "JointVelocityProcessorStep", + "make_default_processors", + "make_default_teleop_action_processor", + "make_default_robot_action_processor", + "make_default_robot_observation_processor", + "MapDeltaActionToRobotActionStep", + "MapTensorToDeltaActionDictStep", + "MotorCurrentProcessorStep", + "NormalizerProcessorStep", + "Numpy2TorchActionProcessorStep", + "ObservationProcessorStep", + "PolicyAction", + "PolicyActionProcessorStep", + "PolicyProcessorPipeline", + "ProcessorKwargs", "ProcessorStep", "ProcessorStepRegistry", - "RenameProcessor", - "RewardProcessor", - "RobotProcessor", + "RobotAction", + "RobotActionProcessorStep", + "RobotObservation", + "RenameObservationsProcessorStep", + "RewardClassifierProcessorStep", + "RewardProcessorStep", + "DataProcessorPipeline", + "TimeLimitProcessorStep", + "AddBatchDimensionProcessorStep", + "RobotProcessorPipeline", + "TokenizerProcessorStep", + "Torch2NumpyActionProcessorStep", + "RobotActionToPolicyActionProcessorStep", + "PolicyActionToRobotActionProcessorStep", + "transition_to_batch", "TransitionKey", - "TruncatedProcessor", - "VanillaObservationProcessor", + "TruncatedProcessorStep", + "UnnormalizerProcessorStep", + "VanillaObservationProcessorStep", ] diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py new file mode 100644 index 000000000..a563599cd --- /dev/null +++ b/src/lerobot/processor/batch_processor.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script defines processor steps for adding a batch dimension to various components of an environment transition. + +These steps are designed to process actions, observations, and complementary data, making them suitable for batch processing by adding a leading dimension. This is a common requirement before feeding data into a neural network model. +""" + +from dataclasses import dataclass, field + +from torch import Tensor + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE + +from .core import EnvTransition, PolicyAction +from .pipeline import ( + ComplementaryDataProcessorStep, + ObservationProcessorStep, + PolicyActionProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_action") +class AddBatchDimensionActionStep(PolicyActionProcessorStep): + """ + Processor step to add a batch dimension to a 1D tensor action. + + This is useful for creating a batch of size 1 from a single action sample. + """ + + def action(self, action: PolicyAction) -> PolicyAction: + """ + Adds a batch dimension to the action if it's a 1D tensor. + + Args: + action: The action tensor. + + Returns: + The action tensor with an added batch dimension. + """ + if action.dim() != 1: + return action + return action.unsqueeze(0) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_observation") +class AddBatchDimensionObservationStep(ObservationProcessorStep): + """ + Processor step to add a batch dimension to observations. + + It handles different types of observations: + - State vectors (1D tensors). + - Single images (3D tensors). + - Dictionaries of multiple images (3D tensors). + """ + + def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]: + """ + Adds a batch dimension to tensor-based observations in the observation dictionary. + + Args: + observation: The observation dictionary. + + Returns: + The observation dictionary with batch dimensions added to tensors. + """ + # Process state observations - add batch dim if 1D + for state_key in [OBS_STATE, OBS_ENV_STATE]: + if state_key in observation: + state_value = observation[state_key] + if isinstance(state_value, Tensor) and state_value.dim() == 1: + observation[state_key] = state_value.unsqueeze(0) + + # Process single image observation - add batch dim if 3D + if OBS_IMAGE in observation: + image_value = observation[OBS_IMAGE] + if isinstance(image_value, Tensor) and image_value.dim() == 3: + observation[OBS_IMAGE] = image_value.unsqueeze(0) + + # Process multiple image observations - add batch dim if 3D + for key, value in observation.items(): + if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3: + observation[key] = value.unsqueeze(0) + return observation + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor_complementary_data") +class AddBatchDimensionComplementaryDataStep(ComplementaryDataProcessorStep): + """ + Processor step to add a batch dimension to complementary data fields. + + Handles specific keys like 'task', 'index', and 'task_index' to make them batched. + - 'task' (str) is wrapped in a list. + - 'index' and 'task_index' (0D tensors) get a batch dimension. + """ + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Adds a batch dimension to specific fields in the complementary data dictionary. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The complementary data dictionary with batch dimensions added. + """ + # Process task field - wrap string in list to add batch dimension + if "task" in complementary_data: + task_value = complementary_data["task"] + if isinstance(task_value, str): + complementary_data["task"] = [task_value] + + # Process index field - add batch dim if 0D + if "index" in complementary_data: + index_value = complementary_data["index"] + if isinstance(index_value, Tensor) and index_value.dim() == 0: + complementary_data["index"] = index_value.unsqueeze(0) + + # Process task_index field - add batch dim if 0D + if "task_index" in complementary_data: + task_index_value = complementary_data["task_index"] + if isinstance(task_index_value, Tensor) and task_index_value.dim() == 0: + complementary_data["task_index"] = task_index_value.unsqueeze(0) + return complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + return features + + +@dataclass +@ProcessorStepRegistry.register(name="to_batch_processor") +class AddBatchDimensionProcessorStep(ProcessorStep): + """ + A composite processor step that adds a batch dimension to the entire environment transition. + + This step combines individual processors for actions, observations, and complementary data + to create a batched transition (batch size 1) from a single-instance transition. + + Attributes: + to_batch_action_processor: Processor for the action component. + to_batch_observation_processor: Processor for the observation component. + to_batch_complementary_data_processor: Processor for the complementary data component. + """ + + to_batch_action_processor: AddBatchDimensionActionStep = field( + default_factory=AddBatchDimensionActionStep + ) + to_batch_observation_processor: AddBatchDimensionObservationStep = field( + default_factory=AddBatchDimensionObservationStep + ) + to_batch_complementary_data_processor: AddBatchDimensionComplementaryDataStep = field( + default_factory=AddBatchDimensionComplementaryDataStep + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies the batching process to all relevant parts of an environment transition. + + Args: + transition: The environment transition to process. + + Returns: + The environment transition with a batch dimension added. + """ + if transition[TransitionKey.ACTION] is not None: + transition = self.to_batch_action_processor(transition) + if transition[TransitionKey.OBSERVATION] is not None: + transition = self.to_batch_observation_processor(transition) + if transition[TransitionKey.COMPLEMENTARY_DATA] is not None: + transition = self.to_batch_complementary_data_processor(transition) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Adding a batch dimension does not alter the feature definition. + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ + # NOTE: We ignore the batch dimension when transforming features + return features diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py new file mode 100644 index 000000000..440f8b1db --- /dev/null +++ b/src/lerobot/processor/converters.py @@ -0,0 +1,412 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from functools import singledispatch +from typing import Any + +import numpy as np +import torch + +from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey + + +@singledispatch +def to_tensor( + value: Any, + *, + dtype: torch.dtype | None = torch.float32, + device: torch.device | str | None = None, +) -> torch.Tensor: + """ + Convert various data types to PyTorch tensors with configurable options. + + This is a unified tensor conversion function using single dispatch to handle + different input types appropriately. + + Args: + value: Input value to convert (tensor, array, scalar, sequence, etc.). + dtype: Target tensor dtype. If None, preserves original dtype. + device: Target device for the tensor. + + Returns: + A PyTorch tensor. + + Raises: + TypeError: If the input type is not supported. + """ + raise TypeError(f"Unsupported type for tensor conversion: {type(value)}") + + +@to_tensor.register(torch.Tensor) +def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for existing PyTorch tensors.""" + if dtype is not None: + value = value.to(dtype=dtype) + if device is not None: + value = value.to(device=device) + return value + + +@to_tensor.register(np.ndarray) +def _( + value: np.ndarray, + *, + dtype=torch.float32, + device=None, + **kwargs, +) -> torch.Tensor: + """Handle conversion for numpy arrays.""" + # Check for numpy scalars (0-dimensional arrays) and treat them as scalars. + if value.ndim == 0: + # Numpy scalars should be converted to 0-dimensional tensors. + scalar_value = value.item() + return torch.tensor(scalar_value, dtype=dtype, device=device) + + # Create tensor from numpy array. + tensor = torch.from_numpy(value) + + # Apply dtype and device conversion if specified. + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if device is not None: + tensor = tensor.to(device=device) + + return tensor + + +@to_tensor.register(int) +@to_tensor.register(float) +@to_tensor.register(np.integer) +@to_tensor.register(np.floating) +def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for scalar values including numpy scalars.""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(list) +@to_tensor.register(tuple) +def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle conversion for sequences (lists, tuples).""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(dict) +def _(value: dict, *, device=None, **kwargs) -> dict: + """Handle conversion for dictionaries by recursively converting their values to tensors.""" + if not value: + return {} + + result = {} + for key, sub_value in value.items(): + if sub_value is None: + continue + + if isinstance(sub_value, dict): + # Recursively process nested dictionaries. + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + continue + + # Convert individual values to tensors. + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + return result + + +def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any: + """ + Convert a PyTorch tensor to a numpy array or scalar if applicable. + + If the input is not a tensor, it is returned unchanged. + + Args: + x: The input, which can be a tensor or any other type. + + Returns: + A numpy array, a scalar, or the original input. + """ + if isinstance(x, torch.Tensor): + return x.item() if x.numel() == 1 else x.detach().cpu().numpy() + return x + + +def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: + """ + Extract complementary data from a batch dictionary. + + This includes padding flags, task description, and indices. + + Args: + batch: The batch dictionary. + + Returns: + A dictionary with the extracted complementary data. + """ + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + index_key = {"index": batch["index"]} if "index" in batch else {} + task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} + + return {**pad_keys, **task_key, **index_key, **task_index_key} + + +def create_transition( + observation: dict[str, Any] | None = None, + action: PolicyAction | RobotAction | None = None, + reward: float = 0.0, + done: bool = False, + truncated: bool = False, + info: dict[str, Any] | None = None, + complementary_data: dict[str, Any] | None = None, +) -> EnvTransition: + """ + Create an `EnvTransition` dictionary with sensible defaults. + + Args: + observation: Observation dictionary. + action: Action dictionary. + reward: Scalar reward value. + done: Episode termination flag. + truncated: Episode truncation flag. + info: Additional info dictionary. + complementary_data: Complementary data dictionary. + + Returns: + A complete `EnvTransition` dictionary. + """ + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +def robot_action_observation_to_transition( + action_observation: tuple[RobotAction, RobotObservation], +) -> EnvTransition: + """ + Convert a raw robot action and observation dictionary into a standardized `EnvTransition`. + + Args: + action: The raw action dictionary from a teleoperation device or controller. + observation: The raw observation dictionary from the environment. + + Returns: + An `EnvTransition` containing the formatted observation. + """ + if not isinstance(action_observation, tuple): + raise ValueError("action_observation should be a tuple type with an action and observation") + + action, observation = action_observation + + if action is not None and not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") + + if observation is not None and not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}") + + return create_transition(action=action, observation=observation) + + +def robot_action_to_transition(action: RobotAction) -> EnvTransition: + """ + Convert a raw robot action dictionary into a standardized `EnvTransition`. + + Args: + action: The raw action dictionary from a teleoperation device or controller. + + Returns: + An `EnvTransition` containing the formatted action. + """ + if not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type got {type(action)}") + return create_transition(action=action) + + +def observation_to_transition(observation: RobotObservation) -> EnvTransition: + """ + Convert a raw robot observation dictionary into a standardized `EnvTransition`. + + Args: + observation: The raw observation dictionary from the environment. + + Returns: + An `EnvTransition` containing the formatted observation. + """ + if not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation type got {type(observation)}") + return create_transition(observation=observation) + + +def transition_to_robot_action(transition: EnvTransition) -> RobotAction: + """ + Extract a raw robot action dictionary for a robot from an `EnvTransition`. + + This function searches for keys in the format "action.*.pos" or "action.*.vel" + and converts them into a flat dictionary suitable for sending to a robot controller. + + Args: + transition: The `EnvTransition` containing the action. + + Returns: + A dictionary representing the raw robot action. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type (dict) got {type(action)}") + return transition.get(TransitionKey.ACTION) + + +def transition_to_policy_action(transition: EnvTransition) -> PolicyAction: + """ + Convert an `EnvTransition` to a `PolicyAction`. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return action + + +def transition_to_observation(transition: EnvTransition) -> RobotObservation: + """ + Convert an `EnvTransition` to a `RobotObservation`. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + observation = transition.get(TransitionKey.OBSERVATION) + if not isinstance(observation, dict): + raise ValueError(f"Observation should be a RobotObservation (dict) type got {type(observation)}") + return observation + + +def policy_action_to_transition(action: PolicyAction) -> EnvTransition: + """ + Convert a `PolicyAction` to an `EnvTransition`. + """ + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + return create_transition(action=action) + + +def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: + """ + Convert a batch dictionary from a dataset/dataloader into an `EnvTransition`. + + This function maps recognized keys from a batch to the `EnvTransition` structure, + filling in missing keys with sensible defaults. + + Args: + batch: A batch dictionary. + + Returns: + An `EnvTransition` dictionary. + + Raises: + ValueError: If the input is not a dictionary. + """ + + # Validate input type. + if not isinstance(batch, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") + + action = batch.get("action") + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + # Extract observation and complementary data keys. + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + complementary_data = _extract_complementary_data(batch) + + return create_transition( + observation=observation_keys if observation_keys else None, + action=batch.get("action"), + reward=batch.get("next.reward", 0.0), + done=batch.get("next.done", False), + truncated=batch.get("next.truncated", False), + info=batch.get("info", {}), + complementary_data=complementary_data if complementary_data else None, + ) + + +def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: + """ + Convert an `EnvTransition` back to the canonical batch format used in LeRobot. + + This is the inverse of `batch_to_transition`. + + Args: + transition: The `EnvTransition` to convert. + + Returns: + A batch dictionary with canonical LeRobot field names. + """ + if not isinstance(transition, dict): + raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") + + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add complementary data. + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if comp_data: + batch.update(comp_data) + + # Flatten observation dictionary. + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch + + +def identity_transition(transition: EnvTransition) -> EnvTransition: + """ + An identity function for transitions, returning the input unchanged. + + Useful as a default or placeholder in processing pipelines. + + Args: + tr: An `EnvTransition`. + + Returns: + The same `EnvTransition`. + """ + return transition diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py new file mode 100644 index 000000000..679ba8c54 --- /dev/null +++ b/src/lerobot/processor/core.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from enum import Enum +from typing import Any, TypeAlias, TypedDict + +import numpy as np +import torch + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +PolicyAction: TypeAlias = torch.Tensor +RobotAction: TypeAlias = dict[str, Any] +EnvAction: TypeAlias = np.ndarray +RobotObservation: TypeAlias = dict[str, Any] + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py new file mode 100644 index 000000000..949ae78d5 --- /dev/null +++ b/src/lerobot/processor/delta_action_processor.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature + +from .core import PolicyAction, RobotAction +from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep + + +@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict") +@dataclass +class MapTensorToDeltaActionDictStep(ActionProcessorStep): + """ + Maps a flat action tensor from a policy to a structured delta action dictionary. + + This step is typically used after a policy outputs a continuous action vector. + It decomposes the vector into named components for delta movements of the + end-effector (x, y, z) and optionally the gripper. + + Attributes: + use_gripper: If True, assumes the 4th element of the tensor is the + gripper action. + """ + + use_gripper: bool = True + + def action(self, action: PolicyAction) -> RobotAction: + if not isinstance(action, PolicyAction): + raise ValueError("Only PolicyAction is supported for this processor") + + if action.dim() > 1: + action = action.squeeze(0) + + # TODO (maractingi): add rotation + delta_action = { + "delta_x": action[0].item(), + "delta_y": action[1].item(), + "delta_z": action[2].item(), + } + if self.use_gripper: + delta_action["gripper"] = action[3].item() + return delta_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z"]: + features[PipelineFeatureType.ACTION][f"delta_{axis}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + if self.use_gripper: + features[PipelineFeatureType.ACTION]["gripper"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register("map_delta_action_to_robot_action") +@dataclass +class MapDeltaActionToRobotActionStep(RobotActionProcessorStep): + """ + Maps delta actions from teleoperators to robot target actions for inverse kinematics. + + This step converts a dictionary of delta movements (e.g., from a gamepad) + into a target action format that includes an "enabled" flag and target + end-effector positions. It also handles scaling and noise filtering. + + Attributes: + position_scale: A factor to scale the delta position inputs. + rotation_scale: A factor to scale the delta rotation inputs (currently unused). + noise_threshold: The magnitude below which delta inputs are considered noise + and do not trigger an "enabled" state. + """ + + # Scale factors for delta movements + position_scale: float = 1.0 + rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard + noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise + + def action(self, action: RobotAction) -> RobotAction: + # NOTE (maractingi): Action can be a dict from the teleop_devices or a tensor from the policy + # TODO (maractingi): changing this target_xyz naming convention from the teleop_devices + delta_x = action.pop("delta_x") + delta_y = action.pop("delta_y") + delta_z = action.pop("delta_z") + gripper = action.pop("gripper") + + # Determine if the teleoperator is actively providing input + # Consider enabled if any significant movement delta is detected + position_magnitude = (delta_x**2 + delta_y**2 + delta_z**2) ** 0.5 # Use Euclidean norm for position + enabled = position_magnitude > self.noise_threshold # Small threshold to avoid noise + + # Scale the deltas appropriately + scaled_delta_x = delta_x * self.position_scale + scaled_delta_y = delta_y * self.position_scale + scaled_delta_z = delta_z * self.position_scale + + # For gamepad/keyboard, we don't have rotation input, so set to 0 + # These could be extended in the future for more sophisticated teleoperators + target_wx = 0.0 + target_wy = 0.0 + target_wz = 0.0 + + # Update action with robot target format + action = { + "enabled": enabled, + "target_x": scaled_delta_x, + "target_y": scaled_delta_y, + "target_z": scaled_delta_z, + "target_wx": target_wx, + "target_wy": target_wy, + "target_wz": target_wz, + "gripper_vel": float(gripper), + } + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for axis in ["x", "y", "z", "gripper"]: + features[PipelineFeatureType.ACTION].pop(f"delta_{axis}", None) + + for feat in ["enabled", "target_x", "target_y", "target_z", "target_wx", "target_wy", "target_wz"]: + features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 0f00bb470..2d0dd0880 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -13,70 +13,182 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This script defines a processor step for moving environment transition data to a specific torch device and casting +its floating-point precision. +""" + from dataclasses import dataclass from typing import Any import torch -from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.utils.utils import get_safe_torch_device +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import ProcessorStep, ProcessorStepRegistry + +@ProcessorStepRegistry.register("device_processor") @dataclass -class DeviceProcessor: - """Processes transitions by moving tensors to the specified device. +class DeviceProcessorStep(ProcessorStep): + """ + Processor step to move all tensors within an `EnvTransition` to a specified device and optionally cast their + floating-point data type. - This processor ensures that all tensors in the transition are moved to the - specified device (CPU or GPU) before they are returned. + This is crucial for preparing data for model training or inference on hardware like GPUs. + + Attributes: + device: The target device for tensors (e.g., "cpu", "cuda", "cuda:0"). + float_dtype: The target floating-point dtype as a string (e.g., "float32", "float16", "bfloat16"). + If None, the dtype is not changed. """ - device: torch.device = "cpu" + device: str = "cpu" + float_dtype: str | None = None + + DTYPE_MAPPING = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "bfloat16": torch.bfloat16, + "half": torch.float16, + "float": torch.float32, + "double": torch.float64, + } def __post_init__(self): - self.device = get_safe_torch_device(self.device) + """ + Initializes the processor by converting string configurations to torch objects. + + This method sets up the `torch.device`, determines if transfers can be non-blocking, and validates the + `float_dtype` string, converting it to a `torch.dtype` object. + """ + self.tensor_device: torch.device = get_safe_torch_device(self.device) + # Update device string in case a specific GPU was selected (e.g., "cuda" -> "cuda:0") + self.device = self.tensor_device.type self.non_blocking = "cuda" in str(self.device) + # Validate and convert float_dtype string to torch dtype + if self.float_dtype is not None: + if self.float_dtype not in self.DTYPE_MAPPING: + raise ValueError( + f"Invalid float_dtype '{self.float_dtype}'. Available options: {list(self.DTYPE_MAPPING.keys())}" + ) + self._target_float_dtype = self.DTYPE_MAPPING[self.float_dtype] + else: + self._target_float_dtype = None + + def _process_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Moves a single tensor to the target device and casts its dtype. + + Handles multi-GPU scenarios by not moving a tensor if it's already on a different CUDA device than + the target, which is useful when using frameworks like Accelerate. + + Args: + tensor: The input torch.Tensor. + + Returns: + The processed tensor on the correct device and with the correct dtype. + """ + # Determine target device + if tensor.is_cuda and self.tensor_device.type == "cuda": + # Both tensor and target are on GPU - preserve tensor's GPU placement. + # This handles multi-GPU scenarios where Accelerate has already placed + # tensors on the correct GPU for each process. + target_device = tensor.device + else: + # Either tensor is on CPU, or we're configured for CPU. + # In both cases, use the configured device. + target_device = self.tensor_device + + # MPS workaround: Convert float64 to float32 since MPS doesn't support float64 + if target_device.type == "mps" and tensor.dtype == torch.float64: + tensor = tensor.to(dtype=torch.float32) + + # Only move if necessary + if tensor.device != target_device: + tensor = tensor.to(target_device, non_blocking=self.non_blocking) + + # Convert float dtype if specified and tensor is floating point + if self._target_float_dtype is not None and tensor.is_floating_point(): + tensor = tensor.to(dtype=self._target_float_dtype) + + return tensor + def __call__(self, transition: EnvTransition) -> EnvTransition: - # Create a copy of the transition + """ + Applies device and dtype conversion to all tensors in an environment transition. + + It iterates through the transition, finds all `torch.Tensor` objects (including those nested in + dictionaries like `observation`), and processes them. + + Args: + transition: The input `EnvTransition` object. + + Returns: + A new `EnvTransition` object with all tensors moved to the target device and dtype. + """ new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) - # Process observation tensors - observation = transition.get(TransitionKey.OBSERVATION) - if observation is not None: - new_observation = { - k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v - for k, v in observation.items() - } - new_transition[TransitionKey.OBSERVATION] = new_observation + if action is not None and not isinstance(action, PolicyAction): + raise ValueError(f"If action is not None should be a PolicyAction type got {type(action)}") - # Process action tensor - action = transition.get(TransitionKey.ACTION) - if action is not None and isinstance(action, torch.Tensor): - new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + simple_tensor_keys = [ + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + ] - # Process reward tensor - reward = transition.get(TransitionKey.REWARD) - if reward is not None and isinstance(reward, torch.Tensor): - new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + dict_tensor_keys = [ + TransitionKey.OBSERVATION, + TransitionKey.COMPLEMENTARY_DATA, + ] - # Process done tensor - done = transition.get(TransitionKey.DONE) - if done is not None and isinstance(done, torch.Tensor): - new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + # Process simple, top-level tensors + for key in simple_tensor_keys: + value = transition.get(key) + if isinstance(value, torch.Tensor): + new_transition[key] = self._process_tensor(value) - # Process truncated tensor - truncated = transition.get(TransitionKey.TRUNCATED) - if truncated is not None and isinstance(truncated, torch.Tensor): - new_transition[TransitionKey.TRUNCATED] = truncated.to( - self.device, non_blocking=self.non_blocking - ) + # Process tensors nested within dictionaries + for key in dict_tensor_keys: + data_dict = transition.get(key) + if data_dict is not None: + new_data_dict = { + k: self._process_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in data_dict.items() + } + new_transition[key] = new_data_dict return new_transition def get_config(self) -> dict[str, Any]: - """Return configuration for serialization.""" - return {"device": self.device} + """ + Returns the serializable configuration of the processor. - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + Returns: + A dictionary containing the device and float_dtype settings. + """ + return {"device": self.device, "float_dtype": self.float_dtype} + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Returns the input features unchanged. + + Device and dtype transformations do not alter the fundamental definition of the features (e.g., shape). + + Args: + features: A dictionary of policy features. + + Returns: + The original dictionary of policy features. + """ return features diff --git a/src/lerobot/processor/factory.py b/src/lerobot/processor/factory.py new file mode 100644 index 000000000..5a0c41072 --- /dev/null +++ b/src/lerobot/processor/factory.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .converters import ( + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from .core import RobotAction, RobotObservation +from .pipeline import IdentityProcessorStep, RobotProcessorPipeline + + +def make_default_teleop_action_processor() -> RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction +]: + teleop_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + return teleop_action_processor + + +def make_default_robot_action_processor() -> RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction +]: + robot_action_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[IdentityProcessorStep()], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + return robot_action_processor + + +def make_default_robot_observation_processor() -> RobotProcessorPipeline[RobotObservation, RobotObservation]: + robot_observation_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[IdentityProcessorStep()], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + return robot_observation_processor + + +def make_default_processors(): + teleop_action_processor = make_default_teleop_action_processor() + robot_action_processor = make_default_robot_action_processor() + robot_observation_processor = make_default_robot_observation_processor() + return (teleop_action_processor, robot_action_processor, robot_observation_processor) diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py new file mode 100644 index 000000000..8fa8cfd86 --- /dev/null +++ b/src/lerobot/processor/gym_action_processor.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature + +from .converters import to_tensor +from .core import EnvAction, EnvTransition, PolicyAction +from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry + + +@ProcessorStepRegistry.register("torch2numpy_action_processor") +@dataclass +class Torch2NumpyActionProcessorStep(ActionProcessorStep): + """ + Converts a PyTorch tensor action to a NumPy array. + + This step is useful when the output of a policy (typically a torch.Tensor) + needs to be passed to an environment or component that expects a NumPy array. + + Attributes: + squeeze_batch_dim: If True, removes the first dimension of the array + if it is of size 1. This is useful for converting a + batched action of size (1, D) to a single action of size (D,). + """ + + squeeze_batch_dim: bool = True + + def action(self, action: PolicyAction) -> EnvAction: + if not isinstance(action, PolicyAction): + raise TypeError( + f"Expected PolicyAction or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." + ) + + numpy_action = action.detach().cpu().numpy() + + # Remove batch dimensions but preserve action dimensions. + # Only squeeze if there's a batch dimension (first dim == 1). + if ( + self.squeeze_batch_dim + and numpy_action.shape + and len(numpy_action.shape) > 1 + and numpy_action.shape[0] == 1 + ): + numpy_action = numpy_action.squeeze(0) + + return numpy_action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("numpy2torch_action_processor") +@dataclass +class Numpy2TorchActionProcessorStep(ProcessorStep): + """Converts a NumPy array action to a PyTorch tensor when action is present.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Converts numpy action to torch tensor if action exists, otherwise passes through.""" + from .core import TransitionKey + + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is not None: + if not isinstance(action, EnvAction): + raise TypeError( + f"Expected np.ndarray or None, got {type(action).__name__}. " + "Use appropriate processor for non-tensor actions." + ) + torch_action = to_tensor(action, dtype=None) # Preserve original dtype + new_transition[TransitionKey.ACTION] = torch_action + + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py new file mode 100644 index 000000000..47f69a973 --- /dev/null +++ b/src/lerobot/processor/hil_processor.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +from dataclasses import dataclass +from typing import Any, Protocol, TypeVar, runtime_checkable + +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.teleoperators.utils import TeleopEvents + +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import ( + ComplementaryDataProcessorStep, + InfoProcessorStep, + ObservationProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + TruncatedProcessorStep, +) + +GRIPPER_KEY = "gripper" +DISCRETE_PENALTY_KEY = "discrete_penalty" +TELEOP_ACTION_KEY = "teleop_action" + + +@runtime_checkable +class HasTeleopEvents(Protocol): + """ + Minimal protocol for objects that provide teleoperation events. + + This protocol defines the `get_teleop_events()` method, allowing processor + steps to interact with teleoperators that support event-based controls + (like episode termination or success flagging) without needing to know the + teleoperator's specific class. + """ + + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the teleoperator. + + Returns: + A dictionary containing control events such as: + - `is_intervention`: bool - Whether the human is currently intervening. + - `terminate_episode`: bool - Whether to terminate the current episode. + - `success`: bool - Whether the episode was successful. + - `rerecord_episode`: bool - Whether to rerecord the episode. + """ + ... + + +# Type variable constrained to Teleoperator subclasses that also implement events +TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) + + +def _check_teleop_with_events(teleop: Teleoperator) -> None: + """ + Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. + + Args: + teleop: The teleoperator instance to check. + + Raises: + TypeError: If the teleoperator does not have a `get_teleop_events` method. + """ + if not isinstance(teleop, HasTeleopEvents): + raise TypeError( + f"Teleoperator {type(teleop).__name__} must implement get_teleop_events() method. " + f"Compatible teleoperators: GamepadTeleop, KeyboardEndEffectorTeleop" + ) + + +@ProcessorStepRegistry.register("add_teleop_action_as_complementary_data") +@dataclass +class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): + """ + Adds the raw action from a teleoperator to the transition's complementary data. + + This is useful for human-in-the-loop scenarios where the human's input needs to + be available to downstream processors, for example, to override a policy's action + during an intervention. + + Attributes: + teleop_device: The teleoperator instance to get the action from. + """ + + teleop_device: Teleoperator + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Retrieves the teleoperator's action and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data dictionary. + + Returns: + A new dictionary with the teleoperator action added under the + `teleop_action` key. + """ + new_complementary_data = dict(complementary_data) + new_complementary_data[TELEOP_ACTION_KEY] = self.teleop_device.get_action() + return new_complementary_data + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("add_teleop_action_as_info") +@dataclass +class AddTeleopEventsAsInfoStep(InfoProcessorStep): + """ + Adds teleoperator control events (e.g., terminate, success) to the transition's info. + + This step extracts control events from teleoperators that support event-based + interaction, making these signals available to other parts of the system. + + Attributes: + teleop_device: An instance of a teleoperator that implements the + `HasTeleopEvents` protocol. + """ + + teleop_device: TeleopWithEvents + + def __post_init__(self): + """Validates that the provided teleoperator supports events after initialization.""" + _check_teleop_with_events(self.teleop_device) + + def info(self, info: dict) -> dict: + """ + Retrieves teleoperator events and updates the info dictionary. + + Args: + info: The incoming info dictionary. + + Returns: + A new dictionary including the teleoperator events. + """ + new_info = dict(info) + + teleop_events = self.teleop_device.get_teleop_events() + new_info.update(teleop_events) + return new_info + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("image_crop_resize_processor") +@dataclass +class ImageCropResizeProcessorStep(ObservationProcessorStep): + """ + Crops and/or resizes image observations. + + This step iterates through all image keys in an observation dictionary and applies + the specified transformations. It handles device placement, moving tensors to the + CPU if necessary for operations not supported on certain accelerators like MPS. + + Attributes: + crop_params_dict: A dictionary mapping image keys to cropping parameters + (top, left, height, width). + resize_size: A tuple (height, width) to resize all images to. + """ + + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None + + def observation(self, observation: dict) -> dict: + """ + Applies cropping and resizing to all images in the observation dictionary. + + Args: + observation: The observation dictionary, potentially containing image tensors. + + Returns: + A new observation dictionary with transformed images. + """ + if self.resize_size is None and not self.crop_params_dict: + return observation + + new_observation = dict(observation) + + # Process all image keys in the observation + for key in observation: + if "image" not in key: + continue + + image = observation[key] + device = image.device + # NOTE (maractingi): No mps kernel for crop and resize, so we need to move to cpu + if device.type == "mps": + image = image.cpu() + # Crop if crop params are provided for this key + if self.crop_params_dict is not None and key in self.crop_params_dict: + crop_params = self.crop_params_dict[key] + image = F.crop(image, *crop_params) + if self.resize_size is not None: + image = F.resize(image, self.resize_size) + image = image.clamp(0.0, 1.0) + new_observation[key] = image.to(device) + + return new_observation + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary with the crop parameters and resize dimensions. + """ + return { + "crop_params_dict": self.crop_params_dict, + "resize_size": self.resize_size, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the image feature shapes in the policy features dictionary if resizing is applied. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary with new image shapes. + """ + if self.resize_size is None: + return features + for key in features[PipelineFeatureType.OBSERVATION]: + if "image" in key: + nb_channel = features[PipelineFeatureType.OBSERVATION][key].shape[0] + features[PipelineFeatureType.OBSERVATION][key] = PolicyFeature( + type=features[PipelineFeatureType.OBSERVATION][key].type, + shape=(nb_channel, *self.resize_size), + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("time_limit_processor") +class TimeLimitProcessorStep(TruncatedProcessorStep): + """ + Tracks episode steps and enforces a time limit by truncating the episode. + + Attributes: + max_episode_steps: The maximum number of steps allowed per episode. + current_step: The current step count for the active episode. + """ + + max_episode_steps: int + current_step: int = 0 + + def truncated(self, truncated: bool) -> bool: + """ + Increments the step counter and sets the truncated flag if the time limit is reached. + + Args: + truncated: The incoming truncated flag. + + Returns: + True if the episode step limit is reached, otherwise the incoming value. + """ + self.current_step += 1 + if self.current_step >= self.max_episode_steps: + truncated = True + # TODO (steven): missing an else truncated = False? + return truncated + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the `max_episode_steps`. + """ + return { + "max_episode_steps": self.max_episode_steps, + } + + def reset(self) -> None: + """Resets the step counter, typically called at the start of a new episode.""" + self.current_step = 0 + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("gripper_penalty_processor") +class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): + """ + Applies a penalty for inefficient gripper usage. + + This step penalizes actions that attempt to close an already closed gripper or + open an already open one, based on position thresholds. + + Attributes: + penalty: The negative reward value to apply. + max_gripper_pos: The maximum position value for the gripper, used for normalization. + """ + + penalty: float = -0.01 + max_gripper_pos: float = 30.0 + + def complementary_data(self, complementary_data: dict) -> dict: + """ + Calculates the gripper penalty and adds it to the complementary data. + + Args: + complementary_data: The incoming complementary data, which should contain + raw joint positions. + + Returns: + A new complementary data dictionary with the `discrete_penalty` key added. + """ + action = self.transition.get(TransitionKey.ACTION) + + raw_joint_positions = complementary_data.get("raw_joint_positions", None) + if raw_joint_positions is None: + return complementary_data + + current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) + if current_gripper_pos is None: + return complementary_data + + # Gripper action is a PolicyAction at this stage + gripper_action = action[-1].item() + gripper_action_normalized = gripper_action / self.max_gripper_pos + + # Normalize gripper state and action + gripper_state_normalized = current_gripper_pos / self.max_gripper_pos + + # Calculate penalty boolean as in original + gripper_penalty_bool = (gripper_state_normalized < 0.5 and gripper_action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and gripper_action_normalized < 0.5 + ) + + gripper_penalty = self.penalty * int(gripper_penalty_bool) + + # Create new complementary data with penalty info + new_complementary_data = dict(complementary_data) + new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty + + return new_complementary_data + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the penalty value and max gripper position. + """ + return { + "penalty": self.penalty, + "max_gripper_pos": self.max_gripper_pos, + } + + def reset(self) -> None: + """Resets the processor's internal state.""" + pass + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("intervention_action_processor") +class InterventionActionProcessorStep(ProcessorStep): + """ + Handles human intervention, overriding policy actions and managing episode termination. + + When an intervention is detected (via teleoperator events in the `info` dict), + this step replaces the policy's action with the human's teleoperated action. + It also processes signals to terminate the episode or flag success. + + Attributes: + use_gripper: Whether to include the gripper in the teleoperated action. + terminate_on_success: If True, automatically sets the `done` flag when a + `success` event is received. + """ + + use_gripper: bool = False + terminate_on_success: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes the transition to handle interventions. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition, potentially with an overridden action, updated + reward, and termination status. + """ + action = transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + # Get intervention signals from complementary data + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + teleop_action = complementary_data.get(TELEOP_ACTION_KEY, {}) + is_intervention = info.get(TeleopEvents.IS_INTERVENTION, False) + terminate_episode = info.get(TeleopEvents.TERMINATE_EPISODE, False) + success = info.get(TeleopEvents.SUCCESS, False) + rerecord_episode = info.get(TeleopEvents.RERECORD_EPISODE, False) + + new_transition = transition.copy() + + # Override action if intervention is active + if is_intervention and teleop_action is not None: + if isinstance(teleop_action, dict): + # Convert teleop_action dict to tensor format + action_list = [ + teleop_action.get("delta_x", 0.0), + teleop_action.get("delta_y", 0.0), + teleop_action.get("delta_z", 0.0), + ] + if self.use_gripper: + action_list.append(teleop_action.get(GRIPPER_KEY, 1.0)) + elif isinstance(teleop_action, np.ndarray): + action_list = teleop_action.tolist() + else: + action_list = teleop_action + + teleop_action_tensor = torch.tensor(action_list, dtype=action.dtype, device=action.device) + new_transition[TransitionKey.ACTION] = teleop_action_tensor + + # Handle episode termination + new_transition[TransitionKey.DONE] = bool(terminate_episode) or ( + self.terminate_on_success and success + ) + new_transition[TransitionKey.REWARD] = float(success) + + # Update info with intervention metadata + info = new_transition.get(TransitionKey.INFO, {}) + info[TeleopEvents.IS_INTERVENTION] = is_intervention + info[TeleopEvents.RERECORD_EPISODE] = rerecord_episode + info[TeleopEvents.SUCCESS] = success + new_transition[TransitionKey.INFO] = info + + # Update complementary data with teleop action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + complementary_data[TELEOP_ACTION_KEY] = new_transition.get(TransitionKey.ACTION) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return new_transition + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ + return { + "use_gripper": self.use_gripper, + "terminate_on_success": self.terminate_on_success, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@dataclass +@ProcessorStepRegistry.register("reward_classifier_processor") +class RewardClassifierProcessorStep(ProcessorStep): + """ + Applies a pretrained reward classifier to image observations to predict success. + + This step uses a model to determine if the current state is successful, updating + the reward and potentially terminating the episode. + + Attributes: + pretrained_path: Path to the pretrained reward classifier model. + device: The device to run the classifier on. + success_threshold: The probability threshold to consider a prediction as successful. + success_reward: The reward value to assign on success. + terminate_on_success: If True, terminates the episode upon successful classification. + reward_classifier: The loaded classifier model instance. + """ + + pretrained_path: str | None = None + device: str = "cpu" + success_threshold: float = 0.5 + success_reward: float = 1.0 + terminate_on_success: bool = True + + reward_classifier: Any = None + + def __post_init__(self): + """Initializes the reward classifier model after the dataclass is created.""" + if self.pretrained_path is not None: + from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + + self.reward_classifier = Classifier.from_pretrained(self.pretrained_path) + self.reward_classifier.to(self.device) + self.reward_classifier.eval() + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Processes a transition, applying the reward classifier to its image observations. + + Args: + transition: The incoming environment transition. + + Returns: + The modified transition with an updated reward and done flag based on the + classifier's prediction. + """ + new_transition = transition.copy() + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is None or self.reward_classifier is None: + return new_transition + + # Extract images from observation + images = {key: value for key, value in observation.items() if "image" in key} + + if not images: + return new_transition + + # Run reward classifier + start_time = time.perf_counter() + with torch.inference_mode(): + success = self.reward_classifier.predict_reward(images, threshold=self.success_threshold) + + classifier_frequency = 1 / (time.perf_counter() - start_time) + + # Calculate reward and termination + reward = new_transition.get(TransitionKey.REWARD, 0.0) + terminated = new_transition.get(TransitionKey.DONE, False) + + if math.isclose(success, 1, abs_tol=1e-2): + reward = self.success_reward + if self.terminate_on_success: + terminated = True + + # Update transition + new_transition[TransitionKey.REWARD] = reward + new_transition[TransitionKey.DONE] = terminated + + # Update info with classifier frequency + info = new_transition.get(TransitionKey.INFO, {}) + info["reward_classifier_frequency"] = classifier_frequency + new_transition[TransitionKey.INFO] = info + + return new_transition + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the step's configuration attributes. + """ + return { + "device": self.device, + "success_threshold": self.success_threshold, + "success_reward": self.success_reward, + "terminate_on_success": self.terminate_on_success, + } + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py new file mode 100644 index 000000000..ab3c6ecc1 --- /dev/null +++ b/src/lerobot/processor/joint_observations_processor.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_STATE +from lerobot.processor.pipeline import ( + ObservationProcessorStep, + ProcessorStepRegistry, +) +from lerobot.robots import Robot + + +@dataclass +@ProcessorStepRegistry.register("joint_velocity_processor") +class JointVelocityProcessorStep(ObservationProcessorStep): + """ + Calculates and appends joint velocity information to the observation state. + + This step computes the velocity of each joint by calculating the finite + difference between the current and the last observed joint positions. The + resulting velocity vector is then concatenated to the original state vector. + + Attributes: + dt: The time step (delta time) in seconds between observations, used for + calculating velocity. + last_joint_positions: Stores the joint positions from the previous step + to enable velocity calculation. + """ + + dt: float = 0.1 + + last_joint_positions: torch.Tensor | None = None + + def observation(self, observation: dict) -> dict: + """ + Computes joint velocities and adds them to the observation state. + + Args: + observation: The input observation dictionary, expected to contain + an `observation.state` key with joint positions. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include joint velocities. + + Raises: + ValueError: If `observation.state` is not found in the observation. + """ + # Get current joint positions (assuming they're in observation.state) + current_positions = observation.get(OBS_STATE) + if current_positions is None: + raise ValueError(f"{OBS_STATE} is not in observation") + + # Initialize last joint positions if not already set + if self.last_joint_positions is None: + self.last_joint_positions = current_positions.clone() + joint_velocities = torch.zeros_like(current_positions) + else: + # Compute velocities + joint_velocities = (current_positions - self.last_joint_positions) / self.dt + + self.last_joint_positions = current_positions.clone() + + # Extend observation with velocities + extended_state = torch.cat([current_positions, joint_velocities], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + new_observation[OBS_STATE] = extended_state + + return new_observation + + def get_config(self) -> dict[str, Any]: + """ + Returns the configuration of the step for serialization. + + Returns: + A dictionary containing the time step `dt`. + """ + return { + "dt": self.dt, + } + + def reset(self) -> None: + """Resets the internal state, clearing the last known joint positions.""" + self.last_joint_positions = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the `observation.state` feature to reflect the added velocities. + + This method doubles the size of the first dimension of the `observation.state` + shape to account for the concatenation of position and velocity vectors. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ + if OBS_STATE in features[PipelineFeatureType.OBSERVATION]: + original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE] + # Double the shape to account for positions + velocities + new_shape = (original_feature.shape[0] * 2,) + original_feature.shape[1:] + + features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature( + type=original_feature.type, shape=new_shape + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("current_processor") +class MotorCurrentProcessorStep(ObservationProcessorStep): + """ + Reads motor currents from a robot and appends them to the observation state. + + This step queries the robot's hardware interface to get the present current + for each motor and concatenates this information to the existing state vector. + + Attributes: + robot: An instance of a `lerobot` Robot class that provides access to + the hardware bus. + """ + + robot: Robot | None = None + + def observation(self, observation: dict) -> dict: + """ + Fetches motor currents and adds them to the observation state. + + Args: + observation: The input observation dictionary. + + Returns: + A new observation dictionary with the `observation.state` tensor + extended to include motor currents. + + Raises: + ValueError: If the `robot` attribute has not been set. + """ + # Get current values from robot state + if self.robot is None: + raise ValueError("Robot is not set") + + present_current_dict = self.robot.bus.sync_read("Present_Current") # type: ignore[attr-defined] + motor_currents = torch.tensor( + [present_current_dict[name] for name in self.robot.bus.motors], # type: ignore[attr-defined] + dtype=torch.float32, + ).unsqueeze(0) + + current_state = observation.get(OBS_STATE) + if current_state is None: + return observation + + extended_state = torch.cat([current_state, motor_currents], dim=-1) + + # Create new observation dict + new_observation = dict(observation) + new_observation[OBS_STATE] = extended_state + + return new_observation + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates the `observation.state` feature to reflect the added motor currents. + + This method increases the size of the first dimension of the `observation.state` + shape by the number of motors in the robot. + + Args: + features: The policy features dictionary. + + Returns: + The updated policy features dictionary. + """ + if OBS_STATE in features[PipelineFeatureType.OBSERVATION] and self.robot is not None: + original_feature = features[PipelineFeatureType.OBSERVATION][OBS_STATE] + # Add motor current dimensions to the original state shape + num_motors = 0 + if hasattr(self.robot, "bus") and hasattr(self.robot.bus, "motors"): # type: ignore[attr-defined] + num_motors = len(self.robot.bus.motors) # type: ignore[attr-defined] + + if num_motors > 0: + new_shape = (original_feature.shape[0] + num_motors,) + original_feature.shape[1:] + features[PipelineFeatureType.OBSERVATION][OBS_STATE] = PolicyFeature( + type=original_feature.type, shape=new_shape + ) + return features diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py new file mode 100644 index 000000000..131f799d6 --- /dev/null +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -0,0 +1,646 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A generic script to migrate LeRobot policies with built-in normalization layers to the new +pipeline-based processor system. + +This script performs the following steps: +1. Loads a pretrained policy model and its configuration from a local path or the + Hugging Face Hub. +2. Scans the model's state dictionary to extract normalization statistics (e.g., mean, + std, min, max) for all features. +3. Creates two new processor pipelines: + - A preprocessor that normalizes inputs (observations) and outputs (actions). + - A postprocessor that unnormalizes outputs (actions) for inference. +4. Removes the original normalization layers from the model's state dictionary, + creating a "clean" model. +5. Saves the new clean model, the preprocessor, the postprocessor, and a generated + model card to a new directory. +6. Optionally pushes all the new artifacts to the Hugging Face Hub. + +Usage: + python src/lerobot/processor/migrate_policy_normalization.py \ + --pretrained-path lerobot/act_aloha_sim_transfer_cube_human \ + --push-to-hub \ + --branch main + +Note: This script now uses the modern `make_pre_post_processors` and `make_policy_config` +factory functions from `lerobot.policies.factory` to create processors and configurations, +ensuring consistency with the current codebase. + +The script extracts normalization statistics from the old model's state_dict, creates clean +processor pipelines using the factory functions, and saves a migrated model that is compatible +with the new PolicyProcessorPipeline architecture. +""" + +import argparse +import json +import os +from pathlib import Path +from typing import Any + +import torch +from huggingface_hub import HfApi, hf_hub_download +from safetensors.torch import load_file as load_safetensors + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors + + +def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: + """ + Scans a model's state_dict to find and extract normalization statistics. + + This function identifies keys corresponding to normalization layers (e.g., those + for mean, std, min, max) based on a set of predefined patterns and organizes + them into a nested dictionary. + + Args: + state_dict: The state dictionary of a pretrained policy model. + + Returns: + A nested dictionary where outer keys are feature names (e.g., + 'observation.state') and inner keys are statistic types ('mean', 'std'), + mapping to their corresponding tensor values. + """ + stats = {} + + # Define patterns to match and their prefixes to remove + normalization_patterns = [ + "normalize_inputs.buffer_", + "unnormalize_outputs.buffer_", + "normalize_targets.buffer_", + "normalize.", # Must come after normalize_* patterns + "unnormalize.", # Must come after unnormalize_* patterns + "input_normalizer.", + "output_normalizer.", + "normalalize_inputs.", + "unnormalize_outputs.", + "normalize_targets.", + "unnormalize_targets.", + ] + + # Process each key in state_dict + for key, tensor in state_dict.items(): + # Try each pattern + for pattern in normalization_patterns: + if key.startswith(pattern): + # Extract the remaining part after the pattern + remaining = key[len(pattern) :] + parts = remaining.split(".") + + # Need at least feature name and stat type + if len(parts) >= 2: + # Last part is the stat type (mean, std, min, max, etc.) + stat_type = parts[-1] + # Everything else is the feature name + feature_name = ".".join(parts[:-1]).replace("_", ".") + + # Add to stats + if feature_name not in stats: + stats[feature_name] = {} + stats[feature_name][stat_type] = tensor.clone() + + # Only process the first matching pattern + break + + return stats + + +def detect_features_and_norm_modes( + config: dict[str, Any], stats: dict[str, dict[str, torch.Tensor]] +) -> tuple[dict[str, PolicyFeature], dict[FeatureType, NormalizationMode]]: + """ + Infers policy features and normalization modes from the model config and stats. + + This function first attempts to find feature definitions and normalization + mappings directly from the policy's configuration file. If this information is + not present, it infers it from the extracted normalization statistics, using + tensor shapes to determine feature shapes and the presence of specific stat + keys (e.g., 'mean'/'std' vs 'min'/'max') to determine the normalization mode. + It applies sensible defaults if inference is not possible. + + Args: + config: The policy's configuration dictionary from `config.json`. + stats: The normalization statistics extracted from the model's state_dict. + + Returns: + A tuple containing: + - A dictionary mapping feature names to `PolicyFeature` objects. + - A dictionary mapping `FeatureType` enums to `NormalizationMode` enums. + """ + features = {} + norm_modes = {} + + # First, check if there's a normalization_mapping in the config + if "normalization_mapping" in config: + print(f"Found normalization_mapping in config: {config['normalization_mapping']}") + # Extract normalization modes from config + for feature_type_str, mode_str in config["normalization_mapping"].items(): + # Convert string to FeatureType enum + try: + if feature_type_str == "VISUAL": + feature_type = FeatureType.VISUAL + elif feature_type_str == "STATE": + feature_type = FeatureType.STATE + elif feature_type_str == "ACTION": + feature_type = FeatureType.ACTION + else: + print(f"Warning: Unknown feature type '{feature_type_str}', skipping") + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse feature type '{feature_type_str}', skipping") + continue + + # Convert string to NormalizationMode enum + try: + if mode_str == "MEAN_STD": + mode = NormalizationMode.MEAN_STD + elif mode_str == "MIN_MAX": + mode = NormalizationMode.MIN_MAX + elif mode_str == "IDENTITY": + mode = NormalizationMode.IDENTITY + else: + print( + f"Warning: Unknown normalization mode '{mode_str}' for feature type '{feature_type_str}'" + ) + continue + except (AttributeError, ValueError): + print(f"Warning: Could not parse normalization mode '{mode_str}', skipping") + continue + + norm_modes[feature_type] = mode + + # Try to extract from config + if "features" in config: + for key, feature_config in config["features"].items(): + shape = feature_config.get("shape", feature_config.get("dim")) + shape = (shape,) if isinstance(shape, int) else tuple(shape) + + # Determine feature type + if "image" in key or "visual" in key: + feature_type = FeatureType.VISUAL + elif "state" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE # Default + + features[key] = PolicyFeature(feature_type, shape) + + # If no features in config, infer from stats + if not features: + for key, stat_dict in stats.items(): + # Get shape from any stat tensor + tensor = next(iter(stat_dict.values())) + shape = tuple(tensor.shape) + + # Determine feature type based on key + if "image" in key or "visual" in key or "pixels" in key: + feature_type = FeatureType.VISUAL + elif "state" in key or "joint" in key or "position" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE + + features[key] = PolicyFeature(feature_type, shape) + + # If normalization modes weren't in config, determine based on available stats + if not norm_modes: + for key, stat_dict in stats.items(): + if key in features: + if "mean" in stat_dict and "std" in stat_dict: + feature_type = features[key].type + if feature_type not in norm_modes: + norm_modes[feature_type] = NormalizationMode.MEAN_STD + elif "min" in stat_dict and "max" in stat_dict: + feature_type = features[key].type + if feature_type not in norm_modes: + norm_modes[feature_type] = NormalizationMode.MIN_MAX + + # Default normalization modes if not detected + if FeatureType.VISUAL not in norm_modes: + norm_modes[FeatureType.VISUAL] = NormalizationMode.MEAN_STD + if FeatureType.STATE not in norm_modes: + norm_modes[FeatureType.STATE] = NormalizationMode.MIN_MAX + if FeatureType.ACTION not in norm_modes: + norm_modes[FeatureType.ACTION] = NormalizationMode.MEAN_STD + + return features, norm_modes + + +def remove_normalization_layers(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Creates a new state_dict with all normalization-related layers removed. + + This function filters the original state dictionary, excluding any keys that + match a set of predefined patterns associated with normalization modules. + + Args: + state_dict: The original model state dictionary. + + Returns: + A new state dictionary containing only the core model weights, without + any normalization parameters. + """ + new_state_dict = {} + + # Patterns to remove + remove_patterns = [ + "normalize_inputs.", + "unnormalize_outputs.", + "normalize_targets.", # Added pattern for target normalization + "normalize.", + "unnormalize.", + "input_normalizer.", + "output_normalizer.", + "normalizer.", + ] + + for key, tensor in state_dict.items(): + should_remove = any(pattern in key for pattern in remove_patterns) + if not should_remove: + new_state_dict[key] = tensor + + return new_state_dict + + +def clean_state_dict( + state_dict: dict[str, torch.Tensor], remove_str: str = "._orig_mod" +) -> dict[str, torch.Tensor]: + """ + Remove a substring (e.g. '._orig_mod') from all keys in a state dict. + + Args: + state_dict (dict): The original state dict. + remove_str (str): The substring to remove from the keys. + + Returns: + dict: A new state dict with cleaned keys. + """ + new_state_dict = {} + for k, v in state_dict.items(): + new_k = k.replace(remove_str, "") + new_state_dict[new_k] = v + return new_state_dict + + +def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: + """ + Converts a feature dictionary from the old config format to the new `PolicyFeature` format. + + Args: + features_dict: The feature dictionary in the old format, where values are + simple dictionaries (e.g., `{"shape": [7]}`). + + Returns: + A dictionary mapping feature names to `PolicyFeature` dataclass objects. + """ + converted_features = {} + + for key, feature_dict in features_dict.items(): + # Determine feature type based on key + if "image" in key or "visual" in key: + feature_type = FeatureType.VISUAL + elif "state" in key: + feature_type = FeatureType.STATE + elif "action" in key: + feature_type = FeatureType.ACTION + else: + feature_type = FeatureType.STATE + + # Get shape from feature dict + shape = feature_dict.get("shape", feature_dict.get("dim")) + shape = (shape,) if isinstance(shape, int) else tuple(shape) if shape is not None else () + + converted_features[key] = PolicyFeature(feature_type, shape) + + return converted_features + + +def load_model_from_hub( + repo_id: str, revision: str | None = None +) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: + """ + Downloads and loads a model's state_dict and configs from the Hugging Face Hub. + + Args: + repo_id: The repository ID on the Hub (e.g., 'lerobot/aloha'). + revision: The specific git revision (branch, tag, or commit hash) to use. + + Returns: + A tuple containing the model's state dictionary, the policy configuration, + and the training configuration. + """ + # Download files. + safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) + + config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) + train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) + + # Load state_dict + state_dict = load_safetensors(safetensors_path) + + # Load config + with open(config_path) as f: + config = json.load(f) + + with open(train_config_path) as f: + train_config = json.load(f) + + return state_dict, config, train_config + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate policy models with normalization layers to new pipeline system" + ) + parser.add_argument( + "--pretrained-path", + type=str, + required=True, + help="Path to pretrained model (hub repo or local directory)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory for migrated model (default: same as pretrained-path)", + ) + parser.add_argument("--push-to-hub", action="store_true", help="Push migrated model to hub") + parser.add_argument( + "--hub-repo-id", + type=str, + default=None, + help="Hub repository ID for pushing (default: same as pretrained-path)", + ) + parser.add_argument("--revision", type=str, default=None, help="Revision of the model to load") + parser.add_argument("--private", action="store_true", help="Make the hub repository private") + parser.add_argument( + "--branch", + type=str, + default=None, + help="Git branch to use when pushing to hub. If specified, a PR will be created automatically (default: push directly to main)", + ) + + args = parser.parse_args() + + # Load model and config + print(f"Loading model from {args.pretrained_path}...") + if os.path.isdir(args.pretrained_path): + # Local directory + state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors")) + with open(os.path.join(args.pretrained_path, "config.json")) as f: + config = json.load(f) + with open(os.path.join(args.pretrained_path, "train_config.json")) as f: + train_config = json.load(f) + else: + # Hub repository + state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision) + + # Extract normalization statistics + print("Extracting normalization statistics...") + stats = extract_normalization_stats(state_dict) + + print(f"Found normalization statistics for: {list(stats.keys())}") + + # Detect input features and normalization modes + print("Detecting features and normalization modes...") + features, norm_map = detect_features_and_norm_modes(config, stats) + + print(f"Detected features: {list(features.keys())}") + print(f"Normalization modes: {norm_map}") + + # Remove normalization layers from state_dict + print("Removing normalization layers from model...") + new_state_dict = remove_normalization_layers(state_dict) + new_state_dict = clean_state_dict(new_state_dict, remove_str="._orig_mod") + + removed_keys = set(state_dict.keys()) - set(new_state_dict.keys()) + if removed_keys: + print(f"Removed {len(removed_keys)} normalization layer keys") + + # Determine output path + if args.output_dir: + output_dir = Path(args.output_dir) + else: + if os.path.isdir(args.pretrained_path): + output_dir = Path(args.pretrained_path).parent / f"{Path(args.pretrained_path).name}_migrated" + else: + output_dir = Path(f"./{args.pretrained_path.replace('/', '_')}_migrated") + + output_dir.mkdir(parents=True, exist_ok=True) + + # Extract policy type from config + if "type" not in config: + raise ValueError("Policy type not found in config.json. The config must contain a 'type' field.") + + policy_type = config["type"] + print(f"Detected policy type: {policy_type}") + + # Clean up config - remove fields that shouldn't be passed to config constructor + cleaned_config = dict(config) + + # Remove fields that are not part of the config class constructors + fields_to_remove = ["normalization_mapping", "type"] + for field in fields_to_remove: + if field in cleaned_config: + print(f"Removing '{field}' field from config") + del cleaned_config[field] + + # Convert input_features and output_features to PolicyFeature objects if they exist + if "input_features" in cleaned_config: + cleaned_config["input_features"] = convert_features_to_policy_features( + cleaned_config["input_features"] + ) + if "output_features" in cleaned_config: + cleaned_config["output_features"] = convert_features_to_policy_features( + cleaned_config["output_features"] + ) + + # Add normalization mapping to config + cleaned_config["normalization_mapping"] = norm_map + + # Create policy configuration using the factory + print(f"Creating {policy_type} policy configuration...") + policy_config = make_policy_config(policy_type, **cleaned_config) + + # Create policy instance using the factory + print(f"Instantiating {policy_type} policy...") + policy_class = get_policy_class(policy_type) + policy = policy_class(policy_config) + + # Load the cleaned state dict + policy.load_state_dict(new_state_dict, strict=True) + print("Successfully loaded cleaned state dict into policy model") + + # Create preprocessor and postprocessor using the factory + print("Creating preprocessor and postprocessor using make_pre_post_processors...") + preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats) + + # Determine hub repo ID if pushing to hub + hub_repo_id = None + if args.push_to_hub: + if args.hub_repo_id: + hub_repo_id = args.hub_repo_id + else: + if not os.path.isdir(args.pretrained_path): + # Use same repo with "_migrated" suffix + hub_repo_id = f"{args.pretrained_path}_migrated" + else: + raise ValueError("--hub-repo-id must be specified when pushing local model to hub") + + # Save all components to local directory first + print(f"Saving preprocessor to {output_dir}...") + preprocessor.save_pretrained(output_dir) + + print(f"Saving postprocessor to {output_dir}...") + postprocessor.save_pretrained(output_dir) + + print(f"Saving model to {output_dir}...") + policy.save_pretrained(output_dir) + + # Generate and save model card + print("Generating model card...") + # Get metadata from original config + dataset_repo_id = train_config.get("repo_id", "unknown") + license = config.get("license", "apache-2.0") + + tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] + tags = set(tags).union({"robotics", "lerobot", policy_type}) + tags = list(tags) + + # Generate model card + card = policy.generate_model_card( + dataset_repo_id=dataset_repo_id, model_type=policy_type, license=license, tags=tags + ) + + # Save model card locally + card.save(str(output_dir / "README.md")) + print(f"Model card saved to {output_dir / 'README.md'}") + # Push all files to hub in a single operation if requested + if args.push_to_hub and hub_repo_id: + api = HfApi() + + # Determine if we should create a PR (automatically if branch is specified) + create_pr = args.branch is not None + target_location = f"branch '{args.branch}'" if args.branch else "main branch" + + print(f"Pushing all migrated files to {hub_repo_id} on {target_location}...") + + # Upload all files in a single commit with automatic PR creation if branch specified + commit_message = "Migrate policy to PolicyProcessorPipeline system" + commit_description = None + + if create_pr: + # Separate commit description for PR body + commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline** + +This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture. + +## What Changed + +### ✨ **New Architecture - PolicyProcessorPipeline** +Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides: +- **Modularity**: Separate preprocessing and postprocessing pipelines +- **Flexibility**: Easy to swap, configure, and debug processing steps +- **Compatibility**: Works with the latest LeRobot ecosystem + +### 🔧 **Normalization Extraction** +We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers: +- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*` +- **Statistics preserved**: Mean, std, min, max values for all features +- **Clean model**: State dict now contains only core model weights + +### 📦 **Files Added** +- **preprocessor_config.json**: Configuration for input preprocessing pipeline +- **postprocessor_config.json**: Configuration for output postprocessing pipeline +- **model.safetensors**: Clean model weights without normalization layers +- **config.json**: Updated model configuration +- **train_config.json**: Training configuration +- **README.md**: Updated model card with migration information + +### 🚀 **Benefits** +- **Backward Compatible**: Your model behavior remains identical +- **Future Ready**: Compatible with latest LeRobot features and updates +- **Debuggable**: Easy to inspect and modify processing steps +- **Portable**: Processors can be shared and reused across models + +### 💻 **Usage** +```python +# Load your migrated model +from lerobot.policies import get_policy_class +from lerobot.processor import PolicyProcessorPipeline + +# The preprocessor and postprocessor are now external +preprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="preprocessor_config.json") +postprocessor = PolicyProcessorPipeline.from_pretrained("your-model-repo", config_filename="postprocessor_config.json") +policy = get_policy_class("your-policy-type").from_pretrained("your-model-repo") + +# Process data through the pipeline +processed_batch = preprocessor(raw_batch) +action = policy(processed_batch) +final_action = postprocessor(action) +``` + +*Generated automatically by the LeRobot policy migration script*""" + + upload_kwargs = { + "repo_id": hub_repo_id, + "folder_path": output_dir, + "repo_type": "model", + "commit_message": commit_message, + "revision": args.branch, + "create_pr": create_pr, + "allow_patterns": ["*.json", "*.safetensors", "*.md"], + "ignore_patterns": ["*.tmp", "*.log"], + } + + # Add commit_description for PR body if creating PR + if create_pr and commit_description: + upload_kwargs["commit_description"] = commit_description + + api.upload_folder(**upload_kwargs) + + if create_pr: + print("All files pushed and pull request created successfully!") + else: + print("All files pushed to main branch successfully!") + + print("\nMigration complete!") + print(f"Migrated model saved to: {output_dir}") + if args.push_to_hub and hub_repo_id: + if args.branch: + print( + f"Successfully pushed all files to branch '{args.branch}' and created PR on https://huggingface.co/{hub_repo_id}" + ) + else: + print(f"Successfully pushed to https://huggingface.co/{hub_repo_id}") + if args.branch: + print(f"\nView the branch at: https://huggingface.co/{hub_repo_id}/tree/{args.branch}") + print( + f"View the PR at: https://huggingface.co/{hub_repo_id}/discussions (look for the most recent PR)" + ) + else: + print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}") + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 14628727f..bece54f0b 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -1,67 +1,353 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations -from collections.abc import Mapping +from copy import deepcopy from dataclasses import dataclass, field from typing import Any -import numpy as np import torch from torch import Tensor -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + +from .converters import from_tensor_to_numpy, to_tensor +from .core import EnvTransition, PolicyAction, TransitionKey +from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry -def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: - """Convert numpy arrays and other types to torch tensors.""" - tensor_stats: dict[str, dict[str, Tensor]] = {} - for key, sub in stats.items(): - tensor_stats[key] = {} - for stat_name, value in sub.items(): - if isinstance(value, np.ndarray): - tensor_val = torch.from_numpy(value.astype(np.float32)) - elif isinstance(value, torch.Tensor): - tensor_val = value.to(dtype=torch.float32) - elif isinstance(value, (int, float, list, tuple)): - tensor_val = torch.tensor(value, dtype=torch.float32) - else: - raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") - tensor_stats[key][stat_name] = tensor_val - return tensor_stats +@dataclass +class _NormalizationMixin: + """ + A mixin class providing core functionality for normalization and unnormalization. + + This class manages normalization statistics (`stats`), converts them to tensors for + efficient computation, handles device placement, and implements the logic for + applying normalization transformations (mean/std and min/max). It is designed to + be inherited by concrete `ProcessorStep` implementations and should not be used + directly. + + **Stats Override Preservation:** + When stats are explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved even when + `load_state_dict()` is called. This allows users to override normalization + statistics from saved models while keeping the rest of the model state intact. + + Examples: + ```python + # Common use case: Override with dataset stats + from lerobot.datasets import LeRobotDataset + + dataset = LeRobotDataset("my_dataset") + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": dataset.meta.stats}} + ) + # dataset.meta.stats will be used, not the stats from the saved model + + # Custom stats override + custom_stats = {"action": {"mean": [0.0], "std": [1.0]}} + pipeline = DataProcessorPipeline.from_pretrained( + "model_path", overrides={"normalizer_processor": {"stats": custom_stats}} + ) + ``` + + Attributes: + features: A dictionary mapping feature names to `PolicyFeature` objects, defining + the data structure to be processed. + norm_map: A dictionary mapping `FeatureType` to `NormalizationMode`, specifying + which normalization method to use for each type of feature. + stats: A dictionary containing the normalization statistics (e.g., mean, std, + min, max) for each feature. + device: The PyTorch device on which to store and perform tensor operations. + eps: A small epsilon value to prevent division by zero in normalization + calculations. + normalize_observation_keys: An optional set of keys to selectively apply + normalization to specific observation features. + _tensor_stats: An internal dictionary holding the normalization statistics as + PyTorch tensors. + _stats_explicitly_provided: Internal flag tracking whether stats were explicitly + provided during construction (used for override preservation). + """ + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + device: torch.device | str | None = None + dtype: torch.dtype | None = None + eps: float = 1e-8 + normalize_observation_keys: set[str] | None = None + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + _stats_explicitly_provided: bool = field(default=False, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the mixin after dataclass construction. + + This method handles the robust deserialization of `features` and `norm_map` + from JSON-compatible formats (where enums become strings and tuples become + lists) and converts the provided `stats` dictionary into a dictionary of + tensors (`_tensor_stats`) on the specified device. + """ + # Track if stats were explicitly provided (not None and not empty) + self._stats_explicitly_provided = self.stats is not None and bool(self.stats) + # Robust JSON deserialization handling (guard empty maps). + if self.features: + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed + + if self.norm_map: + # if keys are strings (JSON), rebuild enum map + if all(isinstance(k, str) for k in self.norm_map.keys()): + reconstructed = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed + + # Convert stats to tensors and move to the target device once during initialization. + self.stats = self.stats or {} + if self.dtype is None: + self.dtype = torch.float32 + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ) -> _NormalizationMixin: + """ + Moves the processor's normalization stats to the specified device. + + Args: + device: The target PyTorch device. + + Returns: + The instance of the class, allowing for method chaining. + """ + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) + return self + + def state_dict(self) -> dict[str, Tensor]: + """ + Returns the normalization statistics as a flat state dictionary. + + All tensors are moved to the CPU before being returned, which is standard practice + for saving state dictionaries. + + Returns: + A flat dictionary mapping from `'feature_name.stat_name'` to the + corresponding statistics tensor on the CPU. + """ + flat: dict[str, Tensor] = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor.cpu() # Always save to CPU + return flat + + def load_state_dict(self, state: dict[str, Tensor]) -> None: + """ + Loads normalization statistics from a state dictionary. + + The loaded tensors are moved to the processor's configured device. + + **Stats Override Preservation:** + If stats were explicitly provided during construction (e.g., via overrides in + `DataProcessorPipeline.from_pretrained()`), they are preserved and the state + dictionary is ignored. This allows users to override normalization statistics + while still loading the rest of the model state. + + This behavior is crucial for scenarios where users want to adapt a pretrained + model to a new dataset with different statistics without retraining the entire + model. + + Args: + state: A flat state dictionary with keys in the format + `'feature_name.stat_name'`. + + Note: + When stats are preserved due to explicit provision, only the tensor + representation is updated to ensure consistency with the current device + and dtype settings. + """ + # If stats were explicitly provided during construction, preserve them + if self._stats_explicitly_provided and self.stats is not None: + # Don't load from state_dict, keep the explicitly provided stats + # But ensure _tensor_stats is properly initialized + self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment] + return + + # Normal behavior: load stats from state_dict + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + # Load to the processor's configured device. + self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to( + dtype=torch.float32, device=self.device + ) + + # Reconstruct the original stats dict from tensor stats for compatibility with to() method + # and other functions that rely on self.stats + self.stats = {} + for key, tensor_dict in self._tensor_stats.items(): + self.stats[key] = {} + for stat_name, tensor in tensor_dict.items(): + # Convert tensor back to python/numpy format + self.stats[key][stat_name] = from_tensor_to_numpy(tensor) + + def get_config(self) -> dict[str, Any]: + """ + Returns a serializable dictionary of the processor's configuration. + + This method is used when saving the processor to disk, ensuring that its + configuration can be reconstructed later. + + Returns: + A JSON-serializable dictionary containing the configuration. + """ + config = { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + if self.normalize_observation_keys is not None: + config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) + return config + + def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + """ + Applies (un)normalization to all relevant features in an observation dictionary. + + Args: + observation: The observation dictionary to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + A new observation dictionary with the transformed tensor values. + """ + new_observation = dict(observation) + for key, feature in self.features.items(): + if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: + continue + if feature.type != FeatureType.ACTION and key in new_observation: + # Convert to tensor but preserve original dtype for adaptation logic + tensor = torch.as_tensor(new_observation[key]) + new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse) + return new_observation + + def _normalize_action(self, action: Tensor, inverse: bool) -> Tensor: + # Convert to tensor but preserve original dtype for adaptation logic + """ + Applies (un)normalization to an action tensor. + + Args: + action: The action tensor to process. + inverse: If `True`, applies unnormalization; otherwise, applies normalization. + + Returns: + The transformed action tensor. + """ + processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse) + return processed_action + + def _apply_transform( + self, tensor: Tensor, key: str, feature_type: FeatureType, *, inverse: bool = False + ) -> Tensor: + """ + Core logic to apply a normalization or unnormalization transformation to a tensor. + + This method selects the appropriate normalization mode (e.g., mean/std, min/max) + based on the feature type and applies the corresponding mathematical operation. + + Args: + tensor: The input tensor to transform. + key: The feature key corresponding to the tensor. + feature_type: The `FeatureType` of the tensor. + inverse: If `True`, applies the inverse transformation (unnormalization). + + Returns: + The transformed tensor. + + Raises: + ValueError: If an unsupported normalization mode is encountered. + """ + norm_mode = self.norm_map.get(feature_type, NormalizationMode.IDENTITY) + if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: + return tensor + + if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + raise ValueError(f"Unsupported normalization mode: {norm_mode}") + + # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor + if self._tensor_stats and key in self._tensor_stats: + first_stat = next(iter(self._tensor_stats[key].values())) + if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype: + self.to(device=tensor.device, dtype=tensor.dtype) + + stats = self._tensor_stats[key] + + if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + # Avoid division by zero by adding a small epsilon. + denom = std + self.eps + if inverse: + return tensor * std + mean + return (tensor - mean) / denom + + if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + denom = max_val - min_val + # When min_val == max_val, substitute the denominator with a small epsilon + # to prevent division by zero. This consistently maps an input equal to + # min_val to -1, ensuring a stable transformation. + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + # Map from [-1, 1] back to [min, max] + return (tensor + 1) / 2 * denom + min_val + # Map from [min, max] to [-1, 1] + return 2 * (tensor - min_val) / denom - 1 + + # If necessary stats are missing, return input unchanged. + return tensor @dataclass @ProcessorStepRegistry.register(name="normalizer_processor") -class NormalizerProcessor: - """Normalizes observations and actions in a single processor step. - - This processor handles normalization of both observation and action tensors - using either mean/std normalization or min/max scaling to a [-1, 1] range. - - For each tensor key in the stats dictionary, the processor will: - - Use mean/std normalization if those statistics are provided: (x - mean) / std - - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 - - The processor can be configured to normalize only specific keys by setting - the normalize_keys parameter. +class NormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ + A processor step that applies normalization to observations and actions in a transition. - # Features and normalisation map are mandatory to match the design of normalize.py - features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - - # Pre-computed statistics coming from dataset.meta.stats for instance. - stats: dict[str, dict[str, Any]] | None = None - - # Explicit subset of keys to normalise. If ``None`` every key (except - # "action") found in ``stats`` will be normalised. Using a ``set`` makes - # membership checks O(1). - normalize_keys: set[str] | None = None - - eps: float = 1e-8 - - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + This class uses the logic from `_NormalizationMixin` to perform forward normalization + (e.g., scaling data to have zero mean and unit variance, or to the range [-1, 1]). + It is typically used in the pre-processing pipeline before feeding data to a policy. + """ @classmethod def from_lerobot_dataset( @@ -70,158 +356,73 @@ class NormalizerProcessor: features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], *, - normalize_keys: set[str] | None = None, + normalize_observation_keys: set[str] | None = None, eps: float = 1e-8, - ) -> NormalizerProcessor: - """Factory helper that pulls statistics from a :class:`LeRobotDataset`. - - The features and norm_map parameters are mandatory to match the design - pattern used in normalize.py. + device: torch.device | str | None = None, + ) -> NormalizerProcessorStep: """ + Creates a `NormalizerProcessorStep` instance using statistics from a `LeRobotDataset`. + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + normalize_observation_keys: An optional set of observation keys to normalize. + eps: A small epsilon value for numerical stability. + device: The target device for the processor. + + Returns: + A new instance of `NormalizerProcessorStep`. + """ return cls( features=features, norm_map=norm_map, stats=dataset.meta.stats, - normalize_keys=normalize_keys, + normalize_observation_keys=normalize_observation_keys, eps=eps, + device=device, ) - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features - - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - # Convert statistics once so we avoid repeated numpy→Tensor conversions - # during runtime. - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - # Ensure *normalize_keys* is a set for fast look-ups and compare by - # value later when returning the configuration. - if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): - self.normalize_keys = set(self.normalize_keys) - - def _normalize_obs(self, observation): - if observation is None: - return None - - # Decide which keys should be normalised for this call. - if self.normalize_keys is not None: - keys_to_norm = self.normalize_keys - else: - # Use feature map to skip action keys. - keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} - - processed = dict(observation) - for key in keys_to_norm: - if key not in processed or key not in self._tensor_stats: - continue - - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = (tensor - mean) / (std + self.eps) - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - return processed - - def _normalize_action(self, action): - if action is None or "action" not in self._tensor_stats: - return action - - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return (tensor - mean) / (std + self.eps) - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") - def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._normalize_action(transition.get(TransitionKey.ACTION)) - - # Create a new transition with normalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action + + # Handle observation normalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation( + observation, inverse=False + ) + + # Handle action normalization. + action = new_transition.get(TransitionKey.ACTION) + + if action is None: + return new_transition + + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=False) + return new_transition - def get_config(self) -> dict[str, Any]: - config = { - "eps": self.eps, - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - if self.normalize_keys is not None: - # Serialise as a list for YAML / JSON friendliness - config["normalize_keys"] = sorted(self.normalize_keys) - return config - - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat - - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor - - def reset(self): - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features @dataclass @ProcessorStepRegistry.register(name="unnormalizer_processor") -class UnnormalizerProcessor: - """Inverse normalisation for observations and actions. - - Exactly mirrors :class:`NormalizerProcessor` but applies the inverse - transform. +class UnnormalizerProcessorStep(_NormalizationMixin, ProcessorStep): """ + A processor step that applies unnormalization to observations and actions. - features: dict[str, PolicyFeature] - norm_map: dict[FeatureType, NormalizationMode] - stats: dict[str, dict[str, Any]] | None = None - - _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + This class inverts the normalization process, scaling data back to its original + range. It is typically used in the post-processing pipeline to convert a policy's + normalized action output into a format that can be executed by a robot or + environment. + """ @classmethod def from_lerobot_dataset( @@ -229,103 +430,72 @@ class UnnormalizerProcessor: dataset: LeRobotDataset, features: dict[str, PolicyFeature], norm_map: dict[FeatureType, NormalizationMode], - ) -> UnnormalizerProcessor: - return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) + *, + device: torch.device | str | None = None, + ) -> UnnormalizerProcessorStep: + """ + Creates an `UnnormalizerProcessorStep` using statistics from a `LeRobotDataset`. - def __post_init__(self): - # Handle deserialization from JSON config - if self.features and isinstance(list(self.features.values())[0], dict): - # Features came from JSON - need to reconstruct PolicyFeature objects - reconstructed_features = {} - for key, ft_dict in self.features.items(): - reconstructed_features[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed_features + Args: + dataset: The dataset from which to extract normalization statistics. + features: The feature definition for the processor. + norm_map: The mapping from feature types to normalization modes. + device: The target device for the processor. - if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): - # norm_map came from JSON - need to reconstruct enum keys and values - reconstructed_norm_map = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed_norm_map - - self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats) - - def _unnormalize_obs(self, observation): - if observation is None: - return None - keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] - processed = dict(observation) - for key in keys: - if key not in processed or key not in self._tensor_stats: - continue - orig_val = processed[key] - tensor = ( - orig_val.to(dtype=torch.float32) - if isinstance(orig_val, torch.Tensor) - else torch.as_tensor(orig_val, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - processed[key] = tensor * std + mean - elif "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val - return processed - - def _unnormalize_action(self, action): - if action is None or "action" not in self._tensor_stats: - return action - tensor = ( - action.to(dtype=torch.float32) - if isinstance(action, torch.Tensor) - else torch.as_tensor(action, dtype=torch.float32) - ) - stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} - if "mean" in stats and "std" in stats: - mean, std = stats["mean"], stats["std"] - return tensor * std + mean - if "min" in stats and "max" in stats: - min_val, max_val = stats["min"], stats["max"] - return (tensor + 1) / 2 * (max_val - min_val) + min_val - raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + Returns: + A new instance of `UnnormalizerProcessorStep`. + """ + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats, device=device) def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) - action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) - - # Create a new transition with unnormalized values new_transition = transition.copy() - new_transition[TransitionKey.OBSERVATION] = observation - new_transition[TransitionKey.ACTION] = action + + # Handle observation unnormalization. + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_transition[TransitionKey.OBSERVATION] = self._normalize_observation(observation, inverse=True) + + # Handle action unnormalization. + action = new_transition.get(TransitionKey.ACTION) + + if action is None: + return new_transition + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type got {type(action)}") + + new_transition[TransitionKey.ACTION] = self._normalize_action(action, inverse=True) + return new_transition - def get_config(self) -> dict[str, Any]: - return { - "features": { - key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() - }, - "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, - } - - def state_dict(self) -> dict[str, Tensor]: - flat = {} - for key, sub in self._tensor_stats.items(): - for stat_name, tensor in sub.items(): - flat[f"{key}.{stat_name}"] = tensor - return flat - - def load_state_dict(self, state: Mapping[str, Tensor]) -> None: - self._tensor_stats.clear() - for flat_key, tensor in state.items(): - key, stat_name = flat_key.rsplit(".", 1) - self._tensor_stats.setdefault(key, {})[stat_name] = tensor - - def reset(self): - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features + + +def hotswap_stats( + policy_processor: PolicyProcessorPipeline, stats: dict[str, dict[str, Any]] +) -> PolicyProcessorPipeline: + """ + Replaces normalization statistics in an existing `PolicyProcessorPipeline` instance. + + This function creates a deep copy of the provided pipeline and updates the + statistics of any `NormalizerProcessorStep` or `UnnormalizerProcessorStep` it + contains. This is useful for adapting a trained policy to a new environment or + dataset with different data distributions without having to reconstruct the entire + pipeline. + + Args: + policy_processor: The policy processor pipeline to modify. + stats: The new dictionary of normalization statistics to apply. + + Returns: + A new `PolicyProcessorPipeline` instance with the updated statistics. + """ + rp = deepcopy(policy_processor) + for step in rp.steps: + if isinstance(step, _NormalizationMixin): + step.stats = stats + # Re-initialize tensor_stats on the correct device. + step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype) # type: ignore[assignment] + return rp diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 7d63db238..71fdbbf0d 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -20,32 +20,54 @@ import numpy as np import torch from torch import Tensor -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass @ProcessorStepRegistry.register(name="observation_processor") -class VanillaObservationProcessor(ObservationProcessor): +class VanillaObservationProcessorStep(ObservationProcessorStep): """ - Processes environment observations into the LeRobot format by handling both images and states. + Processes standard Gymnasium observations into the LeRobot format. - Image processing: - - Converts channel-last (H, W, C) images to channel-first (C, H, W) - - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) - - Adds a batch dimension if missing - - Supports single images and image dictionaries + This step handles both image and state data from a typical observation dictionary, + preparing it for use in a LeRobot policy. - State processing: - - Maps 'environment_state' to observation.environment_state - - Maps 'agent_pos' to observation.state - - Converts numpy arrays to tensors - - Adds a batch dimension if missing + **Image Processing:** + - Converts channel-last (H, W, C), `uint8` images to channel-first (C, H, W), + `float32` tensors. + - Normalizes pixel values from the [0, 255] range to [0, 1]. + - Adds a batch dimension if one is not already present. + - Recognizes a single image under the key `"pixels"` and maps it to + `"observation.image"`. + - Recognizes a dictionary of images under the key `"pixels"` and maps them + to `"observation.images.{camera_name}"`. + + **State Processing:** + - Maps the `"environment_state"` key to `"observation.environment_state"`. + - Maps the `"agent_pos"` key to `"observation.state"`. + - Converts NumPy arrays to PyTorch tensors. + - Adds a batch dimension if one is not already present. """ def _process_single_image(self, img: np.ndarray) -> Tensor: - """Process a single image array.""" + """ + Processes a single NumPy image array into a channel-first, normalized tensor. + + Args: + img: A NumPy array representing the image, expected to be in channel-last + (H, W, C) format with a `uint8` dtype. + + Returns: + A `float32` PyTorch tensor in channel-first (B, C, H, W) format, with + pixel values normalized to the [0, 1] range. + + Raises: + ValueError: If the input image does not appear to be in channel-last + format or is not of `uint8` dtype. + """ # Convert to tensor img_tensor = torch.from_numpy(img) @@ -106,19 +128,32 @@ class VanillaObservationProcessor(ObservationProcessor): def observation(self, observation): return self._process_observation(observation) - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - """Transforms feature keys to a standardized contract. - - This method handles several renaming patterns: - - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). - - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). - - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). - - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). - - environment_state -> OBS_ENV_STATE, - - agent_pos -> OBS_STATE, - - observation.environment_state -> OBS_ENV_STATE, - - observation.agent_pos -> OBS_STATE + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """ + Transforms feature keys from the Gym standard to the LeRobot standard. + + This method standardizes the feature dictionary by renaming keys according + to LeRobot's conventions, ensuring that policies can be constructed correctly. + It handles various raw key formats, including those with an "observation." prefix. + + **Renaming Rules:** + - `pixels` or `observation.pixels` -> `observation.image` + - `pixels.{cam}` or `observation.pixels.{cam}` -> `observation.images.{cam}` + - `environment_state` or `observation.environment_state` -> `observation.environment_state` + - `agent_pos` or `observation.agent_pos` -> `observation.state` + + Args: + features: The policy features dictionary with Gym-style keys. + + Returns: + The policy features dictionary with standardized LeRobot keys. + """ + # Build a new features mapping keyed by the same FeatureType buckets + # We assume callers already placed features in the correct FeatureType. + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()} + exact_pairs = { "pixels": OBS_IMAGE, "environment_state": OBS_ENV_STATE, @@ -129,29 +164,43 @@ class VanillaObservationProcessor(ObservationProcessor): "pixels.": f"{OBS_IMAGES}.", } - for key in list(features.keys()): - matched_prefix = False - for old_prefix, new_prefix in prefix_pairs.items(): - prefixed_old = f"observation.{old_prefix}" - if key.startswith(prefixed_old): - suffix = key[len(prefixed_old) :] - features[f"{new_prefix}{suffix}"] = features.pop(key) - matched_prefix = True - break + # Iterate over all incoming feature buckets and normalize/move each entry + for src_ft, bucket in features.items(): + for key, feat in list(bucket.items()): + handled = False - if key.startswith(old_prefix): - suffix = key[len(old_prefix) :] - features[f"{new_prefix}{suffix}"] = features.pop(key) - matched_prefix = True - break - - if matched_prefix: - continue - - for old, new in exact_pairs.items(): - if key == old or key == f"observation.{old}": - if key in features: - features[new] = features.pop(key) + # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) + for old_prefix, new_prefix in prefix_pairs.items(): + prefixed_old = f"observation.{old_prefix}" + if key.startswith(prefixed_old): + suffix = key[len(prefixed_old) :] + new_key = f"{new_prefix}{suffix}" + new_features[src_ft][new_key] = feat + handled = True break - return features + if key.startswith(old_prefix): + suffix = key[len(old_prefix) :] + new_key = f"{new_prefix}{suffix}" + new_features[src_ft][new_key] = feat + handled = True + break + + if handled: + continue + + # Exact-name rules (pixels, environment_state, agent_pos) + for old, new in exact_pairs.items(): + if key == old or key == f"observation.{old}": + new_key = new + new_features[src_ft][new_key] = feat + handled = True + break + + if handled: + continue + + # Default: keep key in the same source FeatureType bucket + new_features[src_ft][key] = feat + + return new_features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 6e1b2a2cb..1c88cd741 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -13,72 +13,76 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +This module defines a generic, sequential data processing pipeline framework, primarily designed for +transforming robotics data (observations, actions, rewards, etc.). + +The core components are: +- ProcessorStep: An abstract base class for a single data transformation operation. +- ProcessorStepRegistry: A mechanism to register and retrieve ProcessorStep classes by name. +- DataProcessorPipeline: A class that chains multiple ProcessorStep instances together to form a complete + data processing workflow. It integrates with the Hugging Face Hub for easy sharing and versioning of + pipelines, including their configuration and state. +- Specialized abstract ProcessorStep subclasses (e.g., ObservationProcessorStep, ActionProcessorStep) + to simplify the creation of steps that target specific parts of a data transition. +""" + from __future__ import annotations import importlib import json import os +import re +from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence from copy import deepcopy from dataclasses import dataclass, field -from enum import Enum from pathlib import Path -from typing import Any, Protocol, TypedDict +from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, cast import torch -from huggingface_hub import ModelHubMixin, hf_hub_download -from huggingface_hub.errors import HfHubHTTPError +from huggingface_hub import hf_hub_download from safetensors.torch import load_file, save_file -from lerobot.configs.types import PolicyFeature +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.utils.hub import HubMixin +from .converters import batch_to_transition, create_transition, transition_to_batch +from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey -class TransitionKey(str, Enum): - """Keys for accessing EnvTransition dictionary components.""" - - # TODO(Steven): Use consts - OBSERVATION = "observation" - ACTION = "action" - REWARD = "reward" - DONE = "done" - TRUNCATED = "truncated" - INFO = "info" - COMPLEMENTARY_DATA = "complementary_data" - - -EnvTransition = TypedDict( - "EnvTransition", - { - TransitionKey.OBSERVATION.value: dict[str, Any] | None, - TransitionKey.ACTION.value: Any | torch.Tensor | None, - TransitionKey.REWARD.value: float | torch.Tensor | None, - TransitionKey.DONE.value: bool | torch.Tensor | None, - TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, - TransitionKey.INFO.value: dict[str, Any] | None, - TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, - }, -) +# Generic type variables for pipeline input and output. +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class ProcessorStepRegistry: - """Registry for processor steps that enables saving/loading by name instead of module path.""" + """A registry for ProcessorStep classes to allow instantiation from a string name. + + This class provides a way to map string identifiers to `ProcessorStep` classes, + which is useful for deserializing pipelines from configuration files without + + hardcoding class imports. + """ _registry: dict[str, type] = {} @classmethod - def register(cls, name: str = None): - """Decorator to register a processor step class. + def register(cls, name: str | None = None): + """A class decorator to register a ProcessorStep. Args: - name: Optional registration name. If not provided, uses class name. + name: The name to register the class under. If None, the class's `__name__` is used. - Example: - @ProcessorStepRegistry.register("adaptive_normalizer") - class AdaptiveObservationNormalizer: - ... + Returns: + A decorator function that registers the class and returns it. + + Raises: + ValueError: If a step with the same name is already registered. """ def decorator(step_class: type) -> type: + """The actual decorator that performs the registration.""" registration_name = name if name is not None else step_class.__name__ if registration_name in cls._registry: @@ -88,7 +92,7 @@ class ProcessorStepRegistry: ) cls._registry[registration_name] = step_class - # Store the registration name on the class for later reference + # Store the registration name on the class for easy lookup during serialization. step_class._registry_name = registration_name return step_class @@ -96,16 +100,16 @@ class ProcessorStepRegistry: @classmethod def get(cls, name: str) -> type: - """Get a registered processor step class by name. + """Retrieves a processor step class from the registry by its name. Args: - name: The registration name of the step. + name: The name of the step to retrieve. Returns: - The registered step class. + The processor step class corresponding to the given name. Raises: - KeyError: If the step is not registered. + KeyError: If the name is not found in the registry. """ if name not in cls._registry: available = list(cls._registry.keys()) @@ -118,310 +122,231 @@ class ProcessorStepRegistry: @classmethod def unregister(cls, name: str) -> None: - """Remove a step from the registry.""" + """Removes a processor step from the registry. + + Args: + name: The name of the step to unregister. + """ cls._registry.pop(name, None) @classmethod def list(cls) -> list[str]: - """List all registered step names.""" + """Returns a list of all registered processor step names.""" return list(cls._registry.keys()) @classmethod def clear(cls) -> None: - """Clear all registrations.""" + """Clears all processor steps from the registry.""" cls._registry.clear() -class ProcessorStep(Protocol): - """Structural typing interface for a single processor step. +class ProcessorStep(ABC): + """Abstract base class for a single step in a data processing pipeline. - A step is any callable accepting a full `EnvTransition` dict and - returning a (possibly modified) dict of the same structure. Implementers - are encouraged—but not required—to expose the optional helper methods - listed below. When present, these hooks let `RobotProcessor` - automatically serialise the step's configuration and learnable state using - a safe-to-share JSON + SafeTensors format. + Each step must implement the `__call__` method to perform its transformation + on a data transition and the `transform_features` method to describe how it + alters the shape or type of data features. - - **Required**: - - ``__call__(transition: EnvTransition) -> EnvTransition`` - - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` - - Optional helper protocol: - * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable - configuration and state. YOU decide what to save here. This is where all - non-tensor state goes (e.g., name, counter, threshold, window_size). - The config dict will be passed to your class constructor when loading. - * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. - This is exclusively for torch.Tensor objects (e.g., learned weights, - running statistics as tensors). Never put simple Python types here. - * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict - containing torch tensors only. - * ``reset()`` – Clear internal buffers at episode boundaries. - - Example separation: - - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} - - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + Subclasses can optionally be stateful by implementing `state_dict` and `load_state_dict`. """ - def __call__(self, transition: EnvTransition) -> EnvTransition: ... + _current_transition: EnvTransition | None = None - def get_config(self) -> dict[str, Any]: ... + @property + def transition(self) -> EnvTransition: + """Provides access to the most recent transition being processed. - def state_dict(self) -> dict[str, torch.Tensor]: ... + This is useful for steps that need to access other parts of the transition + data beyond their primary target (e.g., an action processing step that + needs to look at the observation). - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... + Raises: + ValueError: If accessed before the step has been called with a transition. + """ + if self._current_transition is None: + raise ValueError("Transition is not set. Make sure to call the step with a transition first.") + return self._current_transition - def reset(self) -> None: ... + @abstractmethod + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Processes an environment transition. - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + This method should contain the core logic of the processing step. + + Args: + transition: The input data transition to be processed. + + Returns: + The processed transition. + """ + return transition + + def get_config(self) -> dict[str, Any]: + """Returns the configuration of the step for serialization. + + Returns: + A JSON-serializable dictionary of configuration parameters. + """ + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + """Returns the state of the step (e.g., learned parameters, running means). + + Returns: + A dictionary mapping state names to tensors. + """ + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Loads the step's state from a state dictionary. + + Args: + state: A dictionary of state tensors. + """ + return None + + def reset(self) -> None: + """Resets the internal state of the processor step, if any.""" + return None + + @abstractmethod + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Defines how this step modifies the description of pipeline features. + + This method is used to track changes in data shapes, dtypes, or modalities + as data flows through the pipeline, without needing to process actual data. + + Args: + features: A dictionary describing the input features for observations, actions, etc. + + Returns: + A dictionary describing the output features after this step's transformation. + """ + return features -def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 - """Convert a *batch* dict coming from Learobot replay/dataset code into an - ``EnvTransition`` dictionary. +class ProcessorKwargs(TypedDict, total=False): + """A TypedDict for optional keyword arguments used in pipeline construction.""" - The function maps well known keys to the EnvTransition structure. Missing keys are - filled with sane defaults (``None`` or ``0.0``/``False``). - - Keys recognised (case-sensitive): - - * "observation.*" (keys starting with "observation." are grouped into observation dict) - * "action" - * "next.reward" - * "next.done" - * "next.truncated" - * "info" - - Additional keys are ignored so that existing dataloaders can carry extra - metadata without breaking the processor. - """ - - # Extract observation keys - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} - observation = observation_keys if observation_keys else None - - # Extract padding and task keys for complementary data - pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} - task_key = {"task": batch["task"]} if "task" in batch else {} - complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} - - transition: EnvTransition = { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: batch.get("action"), - TransitionKey.REWARD: batch.get("next.reward", 0.0), - TransitionKey.DONE: batch.get("next.done", False), - TransitionKey.TRUNCATED: batch.get("next.truncated", False), - TransitionKey.INFO: batch.get("info", {}), - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - return transition + to_transition: Callable[[dict[str, Any]], EnvTransition] | None + to_output: Callable[[EnvTransition], Any] | None + name: str | None + before_step_hooks: list[Callable[[int, EnvTransition], None]] | None + after_step_hooks: list[Callable[[int, EnvTransition], None]] | None -def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 - """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with - the canonical field names used throughout *LeRobot*. - """ +class ProcessorMigrationError(Exception): + """Raised when a model needs migration to the processor format""" - batch = { - "action": transition.get(TransitionKey.ACTION), - "next.reward": transition.get(TransitionKey.REWARD, 0.0), - "next.done": transition.get(TransitionKey.DONE, False), - "next.truncated": transition.get(TransitionKey.TRUNCATED, False), - "info": transition.get(TransitionKey.INFO, {}), - } - - # Add padding and task data from complementary_data - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data: - pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} - batch.update(pad_data) - - if "task" in complementary_data: - batch["task"] = complementary_data["task"] - - # Handle observation - flatten dict to observation.* keys if it's a dict - observation = transition.get(TransitionKey.OBSERVATION) - if isinstance(observation, dict): - batch.update(observation) - - return batch + def __init__(self, model_path: str | Path, migration_command: str, original_error: str): + self.model_path = model_path + self.migration_command = migration_command + self.original_error = original_error + super().__init__( + f"Model '{model_path}' requires migration to processor format. " + f"Run: {migration_command}\n\nOriginal error: {original_error}" + ) @dataclass -class RobotProcessor(ModelHubMixin): - """ - Composable, debuggable post-processing processor for robot transitions. +class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): + """A sequential pipeline for processing data, integrated with the Hugging Face Hub. - The class orchestrates an ordered collection of small, functional transforms—steps—executed - left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts - and batch dictionaries, automatically converting between formats as needed. + This class chains together multiple `ProcessorStep` instances to form a complete + data processing workflow. It's generic, allowing for custom input and output types, + which are handled by the `to_transition` and `to_output` converters. - Args: - steps: Ordered list of processing steps executed on every call. Defaults to empty list. - name: Human-readable identifier that is persisted inside the JSON config. - Defaults to "RobotProcessor". - to_transition: Function to convert batch dict to EnvTransition dict. - Defaults to _default_batch_to_transition. - to_output: Function to convert EnvTransition dict to the desired output format. - Usually it is a batch dict or EnvTransition dict. - Defaults to _default_transition_to_batch. - before_step_hooks: List of hooks called before each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - after_step_hooks: List of hooks called after each step. Each hook receives the step - index and transition, and can optionally return a modified transition. - - Hook Semantics: - - Hooks are executed sequentially in the order they were registered. There is no way to - reorder hooks after registration without creating a new pipeline. - - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called - with the step index and current transition for logging, debugging, or monitoring purposes. - - All hooks for a given type (before/after) are executed for every step, or none at all if - an error occurs. There is no partial execution of hooks. - - Hooks should generally be stateless to maintain predictable behavior. If you need stateful - processing, consider implementing a proper ProcessorStep instead. - - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. - - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format - passed to __call__. This ensures consistent hook behavior whether processing batch dicts - or EnvTransition objects. + Attributes: + steps: A sequence of `ProcessorStep` objects that make up the pipeline. + name: A descriptive name for the pipeline. + to_transition: A function to convert raw input data into the standardized `EnvTransition` format. + to_output: A function to convert the final `EnvTransition` into the desired output format. + before_step_hooks: A list of functions to be called before each step is executed. + after_step_hooks: A list of functions to be called after each step is executed. """ steps: Sequence[ProcessorStep] = field(default_factory=list) - name: str = "RobotProcessor" + name: str = "DataProcessorPipeline" - to_transition: Callable[[dict[str, Any]], EnvTransition] = field( - default_factory=lambda: _default_batch_to_transition, repr=False + to_transition: Callable[[TInput], EnvTransition] = field( + default_factory=lambda: cast(Callable[[TInput], EnvTransition], batch_to_transition), repr=False ) - to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field( - default_factory=lambda: _default_transition_to_batch, repr=False + to_output: Callable[[EnvTransition], TOutput] = field( + default_factory=lambda: cast(Callable[[EnvTransition], TOutput], transition_to_batch), + repr=False, ) - # Processor-level hooks for observation/monitoring - # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) - def __call__(self, data: EnvTransition | dict[str, Any]): - """Process data through all steps. - - The method accepts either the classic EnvTransition dict or a batch dictionary - (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied - it is first converted to the internal dict format using to_transition; after all - steps are executed the dict is transformed back into a batch dict with to_batch and the - result is returned – thereby preserving the caller's original data type. + def __call__(self, data: TInput) -> TOutput: + """Processes input data through the full pipeline. Args: - data: Either an EnvTransition dict or a batch dictionary to process. + data: The input data to process. Returns: - The processed data in the same format as the input (EnvTransition or batch dict). - - Raises: - ValueError: If the transition is not a valid EnvTransition format. + The processed data in the specified output format. """ - # Check if we need to convert back to batch format at the end - _, called_with_batch = self._prepare_transition(data) + transition = self.to_transition(data) + transformed_transition = self._forward(transition) + return self.to_output(transformed_transition) - # Use step_through to get the iterator - step_iterator = self.step_through(data) + def _forward(self, transition: EnvTransition) -> EnvTransition: + """Executes all processing steps and hooks in sequence. - # Get initial state (before any steps) - current_transition = next(step_iterator) + Args: + transition: The initial `EnvTransition` object. - # Process each step with hooks - for idx, next_transition in enumerate(step_iterator): - # Apply before hooks with current state (before step execution) + Returns: + The final `EnvTransition` after all steps have been applied. + """ + for idx, processor_step in enumerate(self.steps): + # Execute pre-hooks for hook in self.before_step_hooks: - hook(idx, current_transition) + hook(idx, transition) - # Move to next state (after step execution) - current_transition = next_transition + transition = processor_step(transition) - # Apply after hooks with updated state + # Execute post-hooks for hook in self.after_step_hooks: - hook(idx, current_transition) + hook(idx, transition) + return transition - # Convert back to original format if needed - return self.to_output(current_transition) if called_with_batch else current_transition + def step_through(self, data: TInput) -> Iterable[EnvTransition]: + """Processes data step-by-step, yielding the transition at each stage. - def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: - """Prepare and validate transition data for processing. + This is a generator method useful for debugging and inspecting the intermediate + state of the data as it passes through the pipeline. Args: - data: Either an EnvTransition dict or a batch dictionary to process. - - Returns: - A tuple of (prepared_transition, called_with_batch_flag) - - Raises: - ValueError: If the transition is not a valid EnvTransition format. - """ - # Check if data is already an EnvTransition or needs conversion - if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): - # It's a batch dict, convert it - called_with_batch = True - transition = self.to_transition(data) - else: - # It's already an EnvTransition - called_with_batch = False - transition = data - - # Basic validation - if not isinstance(transition, dict): - raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") - - return transition, called_with_batch - - def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: - """Yield the intermediate results after each processor step. - - This is a low-level method that does NOT apply hooks. It simply executes each step - and yields the intermediate results. This allows users to debug the pipeline or - apply custom logic between steps if needed. - - Note: This method always yields EnvTransition objects regardless of input format. - If you need the results in the original input format, you'll need to convert them - using `to_output()`. - - Args: - data: Either an EnvTransition dict or a batch dictionary to process. + data: The input data. Yields: - The intermediate EnvTransition results after each step. + The `EnvTransition` object, starting with the initial state and then after + each processing step. """ - transition, _ = self._prepare_transition(data) + transition = self.to_transition(data) - # Yield initial state + # Yield the initial state before any processing. yield transition - # Process each step WITHOUT hooks (low-level method) for processor_step in self.steps: transition = processor_step(transition) yield transition def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal save method for ModelHubMixin compatibility.""" - # Extract config_filename from kwargs if provided - config_filename = kwargs.pop("config_filename", None) - self.save_pretrained(save_directory, config_filename=config_filename) + """Internal method to comply with `HubMixin`'s saving mechanism. - def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): - """Serialize the processor definition and parameters to *save_directory*. - - Args: - save_directory: Directory where the processor will be saved. - config_filename: Optional custom config filename. If not provided, defaults to - "{self.name}.json" where self.name is sanitized for filesystem compatibility. + This method does the actual saving work and is called by HubMixin.save_pretrained. """ - os.makedirs(str(save_directory), exist_ok=True) + config_filename = kwargs.pop("config_filename", None) - # Sanitize processor name for use in filenames - import re - - # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + # Sanitize the pipeline name to create a valid filename prefix. sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) - # Use sanitized name for config if not provided if config_filename is None: config_filename = f"{sanitized_name}.json" @@ -430,40 +355,31 @@ class RobotProcessor(ModelHubMixin): "steps": [], } + # Iterate through each step to build its configuration entry. for step_index, processor_step in enumerate(self.steps): - # Check if step was registered registry_name = getattr(processor_step.__class__, "_registry_name", None) step_entry: dict[str, Any] = {} + # Prefer registry name for portability, otherwise fall back to full class path. if registry_name: - # Use registry name for registered steps step_entry["registry_name"] = registry_name else: - # Fall back to full module path for unregistered steps step_entry["class"] = ( f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" ) + # Save step configuration if `get_config` is implemented. if hasattr(processor_step, "get_config"): step_entry["config"] = processor_step.get_config() + # Save step state if `state_dict` is implemented and returns a non-empty dict. if hasattr(processor_step, "state_dict"): state = processor_step.state_dict() if state: - # Clone tensors to avoid shared memory issues - # This ensures each tensor has its own memory allocation - # The reason is to avoid the following error: - # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk - # and potential differences when loading them again - # ------------------------------------------------------------------------------ - # Since the state_dict of processor will be light, we can just clone the tensors - # and save them to the disk. - cloned_state = {} - for key, tensor in state.items(): - cloned_state[key] = tensor.clone() + # Clone tensors to avoid modifying the original state. + cloned_state = {key: tensor.clone() for key, tensor in state.items()} - # Include pipeline name and step index to ensure unique filenames - # This prevents conflicts when multiple processors are saved in the same directory + # Create a unique filename for the state file. if registry_name: state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" else: @@ -474,13 +390,69 @@ class RobotProcessor(ModelHubMixin): config["steps"].append(step_entry) + # Write the main configuration JSON file. with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: json.dump(config, file_pointer, indent=2) + def save_pretrained( + self, + save_directory: str | Path | None = None, + *, + repo_id: str | None = None, + push_to_hub: bool = False, + card_kwargs: dict[str, Any] | None = None, + config_filename: str | None = None, + **push_to_hub_kwargs, + ): + """Saves the pipeline's configuration and state to a directory. + + This method creates a JSON configuration file that defines the pipeline's structure + (name and steps). For each stateful step, it also saves a `.safetensors` file + containing its state dictionary. + + Args: + save_directory: The directory where the pipeline will be saved. If None, saves to + HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. + card_kwargs: Additional arguments passed to the card template to customize the card. + config_filename: The name of the JSON configuration file. If None, a name is + generated from the pipeline's `name` attribute. + **push_to_hub_kwargs: Additional key word arguments passed along to the push_to_hub method. + """ + if save_directory is None: + # Use default directory in HF_LEROBOT_HOME + from lerobot.constants import HF_LEROBOT_HOME + + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name + + # For direct saves (not through hub), handle config_filename + if not push_to_hub and config_filename is not None: + # Call _save_pretrained directly with config_filename + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + self._save_pretrained(save_directory, config_filename=config_filename) + return None + + # Pass config_filename through kwargs for _save_pretrained when using hub + if config_filename is not None: + push_to_hub_kwargs["config_filename"] = config_filename + + # Call parent's save_pretrained which will call our _save_pretrained + return super().save_pretrained( + save_directory=save_directory, + repo_id=repo_id, + push_to_hub=push_to_hub, + card_kwargs=card_kwargs, + **push_to_hub_kwargs, + ) + @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str | Path, + config_filename: str, *, force_download: bool = False, resume_download: bool | None = None, @@ -489,267 +461,798 @@ class RobotProcessor(ModelHubMixin): cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, - config_filename: str | None = None, overrides: dict[str, Any] | None = None, + to_transition: Callable[[TInput], EnvTransition] | None = None, + to_output: Callable[[EnvTransition], TOutput] | None = None, **kwargs, - ) -> RobotProcessor: - """Load a serialized processor from source (local path or Hugging Face Hub identifier). + ) -> DataProcessorPipeline[TInput, TOutput]: + """Loads a pipeline from a local directory, single file, or Hugging Face Hub repository. + + This method implements a simplified loading pipeline with intelligent migration detection: + + **Simplified Loading Strategy**: + 1. **Config Loading** (_load_config): + - **Directory**: Load specified config_filename from directory + - **Single file**: Load file directly (config_filename ignored) + - **Hub repository**: Download specified config_filename from Hub + + 2. **Config Validation** (_validate_loaded_config): + - Format validation: Ensure config is valid processor format + - Migration detection: Guide users to migrate old LeRobot models + - Clear errors: Provide actionable error messages + + 3. **Step Construction** (_build_steps_with_overrides): + - Class resolution: Registry lookup or dynamic imports + - Override merging: User parameters override saved config + - State loading: Load .safetensors files for stateful steps + + 4. **Override Validation** (_validate_overrides_used): + - Ensure all user overrides were applied (catch typos) + - Provide helpful error messages with available keys + + **Migration Detection**: + - **Smart detection**: Analyzes JSON files to detect old LeRobot models + - **Precise targeting**: Avoids false positives on other HuggingFace models + - **Clear guidance**: Provides exact migration command to run + - **Error mode**: Always raises ProcessorMigrationError for clear user action + + **Loading Examples**: + ```python + # Directory loading + pipeline = DataProcessorPipeline.from_pretrained("/models/my_model", config_filename="processor.json") + + # Single file loading + pipeline = DataProcessorPipeline.from_pretrained( + "/models/my_model/processor.json", config_filename="processor.json" + ) + + # Hub loading + pipeline = DataProcessorPipeline.from_pretrained("user/repo", config_filename="processor.json") + + # Multiple configs (preprocessor/postprocessor) + preprocessor = DataProcessorPipeline.from_pretrained( + "model", config_filename="policy_preprocessor.json" + ) + postprocessor = DataProcessorPipeline.from_pretrained( + "model", config_filename="policy_postprocessor.json" + ) + ``` + + **Override System**: + - **Key matching**: Use registry names or class names as override keys + - **Config merging**: User overrides take precedence over saved config + - **Validation**: Ensure all override keys match actual steps (catch typos) + - **Example**: overrides={"NormalizeStep": {"device": "cuda"}} Args: - pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier - (e.g., "username/processor-name"). - config_filename: Optional specific config filename to load. If not provided, will: - - For local paths: look for any .json file in the directory (error if multiple found) - - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") - overrides: Optional dictionary mapping step names to configuration overrides. - Keys must match exact step class names (for unregistered steps) or registry names - (for registered steps). Values are dictionaries containing parameter overrides - that will be merged with the saved configuration. This is useful for providing - non-serializable objects like environment instances. + pretrained_model_name_or_path: The identifier of the repository on the Hugging Face Hub, + a path to a local directory, or a path to a single config file. + config_filename: The name of the pipeline's JSON configuration file. Always required + to prevent ambiguity when multiple configs exist (e.g., preprocessor vs postprocessor). + force_download: Whether to force (re)downloading the files. + resume_download: Whether to resume a previously interrupted download. + proxies: A dictionary of proxy servers to use. + token: The token to use as HTTP bearer authorization for private Hub repositories. + cache_dir: The path to a specific cache folder to store downloaded files. + local_files_only: If True, avoid downloading files from the Hub. + revision: The specific model version to use (e.g., a branch name, tag name, or commit id). + overrides: A dictionary to override the configuration of specific steps. Keys should + match the step's class name or registry name. + to_transition: A custom function to convert input data to `EnvTransition`. + to_output: A custom function to convert the final `EnvTransition` to the output format. + **kwargs: Additional arguments (not used). Returns: - A RobotProcessor instance loaded from the saved configuration. + An instance of `DataProcessorPipeline` loaded with the specified configuration and state. Raises: - ImportError: If a processor step class cannot be loaded or imported. - ValueError: If a step cannot be instantiated with the provided configuration. - KeyError: If an override key doesn't match any step in the saved configuration. - - Examples: - Basic loading: - ```python - processor = RobotProcessor.from_pretrained("path/to/processor") - ``` - - Loading specific config file: - ```python - processor = RobotProcessor.from_pretrained( - "username/multi-processor-repo", config_filename="preprocessor.json" - ) - ``` - - Loading with overrides for non-serializable objects: - ```python - import gym - - env = gym.make("CartPole-v1") - processor = RobotProcessor.from_pretrained( - "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} - ) - ``` - - Multiple overrides: - ```python - processor = RobotProcessor.from_pretrained( - "path/to/processor", - overrides={ - "CustomStep": {"param1": "new_value"}, - "device_processor": {"device": "cuda:1"}, # For registered steps - }, - ) - ``` + FileNotFoundError: If the config file cannot be found. + ValueError: If configuration is ambiguous or instantiation fails. + ImportError: If a step's class cannot be imported. + KeyError: If an override key doesn't match any step in the pipeline. + ProcessorMigrationError: If the model requires migration to processor format. """ - # Use the local variable name 'source' for clarity - source = str(pretrained_model_name_or_path) + model_id = str(pretrained_model_name_or_path) + hub_download_kwargs = { + "force_download": force_download, + "resume_download": resume_download, + "proxies": proxies, + "token": token, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + "revision": revision, + } - if Path(source).is_dir(): - # Local path - use it directly - base_path = Path(source) + # 1. Load configuration using simplified 3-way logic + loaded_config, base_path = cls._load_config(model_id, config_filename, hub_download_kwargs) - if config_filename is None: - # Look for any .json file in the directory - json_files = list(base_path.glob("*.json")) - if len(json_files) == 0: - raise FileNotFoundError(f"No .json configuration files found in {source}") - elif len(json_files) > 1: - raise ValueError( - f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " - f"Please specify which one to load using the config_filename parameter." - ) - config_filename = json_files[0].name + # 2. Validate configuration and handle migration + cls._validate_loaded_config(model_id, loaded_config, config_filename) - with open(base_path / config_filename) as file_pointer: - loaded_config: dict[str, Any] = json.load(file_pointer) - else: - # Hugging Face Hub - download all required files - if config_filename is None: - # Try common config names - common_names = [ - "processor.json", - "preprocessor.json", - "postprocessor.json", - "robotprocessor.json", - ] - config_path = None - for name in common_names: - try: - config_path = hf_hub_download( - source, - name, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) - config_filename = name - break - except (FileNotFoundError, OSError, HfHubHTTPError): - # FileNotFoundError: local file issues - # OSError: network/system errors - # HfHubHTTPError: file not found on Hub (404) or other HTTP errors - continue + # 3. Build steps with overrides + steps, validated_overrides = cls._build_steps_with_overrides( + loaded_config, overrides or {}, model_id, base_path, hub_download_kwargs + ) - if config_path is None: - raise FileNotFoundError( - f"No processor configuration file found in {source}. " - f"Tried: {common_names}. Please specify the config_filename parameter." - ) - else: - # Download specific config file - config_path = hf_hub_download( - source, - config_filename, - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, + # 4. Validate that all overrides were used + cls._validate_overrides_used(validated_overrides, loaded_config) + + # 5. Construct and return the final pipeline instance + return cls( + steps=steps, + name=loaded_config.get("name", "DataProcessorPipeline"), + to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), + to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), + ) + + @classmethod + def _load_config( + cls, + model_id: str, + config_filename: str, + hub_download_kwargs: dict[str, Any], + ) -> tuple[dict[str, Any], Path]: + """Load configuration from local file or Hugging Face Hub. + + This method implements a super-simplified 3-way loading strategy: + + 1. **Local directory**: Load config_filename from directory + - Example: model_id="/models/my_model", config_filename="processor.json" + - Loads: "/models/my_model/processor.json" + + 2. **Single file**: Load file directly (ignore config_filename) + - Example: model_id="/models/my_model/processor.json" + - Loads: "/models/my_model/processor.json" (config_filename ignored) + + 3. **Hub repository**: Download config_filename from Hub + - Example: model_id="user/repo", config_filename="processor.json" + - Downloads and loads: config_filename from Hub repo + + **Benefits of Explicit config_filename**: + - No auto-detection complexity or edge cases + - No risk of loading wrong config (preprocessor vs postprocessor) + - Consistent behavior across local and Hub usage + - Clear, predictable errors + + Args: + model_id: The model identifier (Hub repo ID, local directory, or file path) + config_filename: The explicit config filename to load (always required) + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Returns: + Tuple of (loaded_config, base_path) + - loaded_config: Parsed JSON config dict (always loaded, never None) + - base_path: Directory containing config file (for state file resolution) + + Raises: + FileNotFoundError: If config file cannot be found locally or on Hub + """ + model_path = Path(model_id) + + if model_path.is_dir(): + # Directory: load specified config from directory + config_path = model_path / config_filename + if not config_path.exists(): + # Check for migration before giving clear error + if cls._should_suggest_migration(model_path): + cls._suggest_processor_migration(model_id, f"Config file '{config_filename}' not found") + raise FileNotFoundError( + f"Config file '{config_filename}' not found in directory '{model_id}'" ) - with open(config_path) as file_pointer: - loaded_config = json.load(file_pointer) + with open(config_path) as f: + return json.load(f), model_path - # Store downloaded files in the same directory as the config - base_path = Path(config_path).parent + elif model_path.is_file(): + # File: load file directly (config_filename is ignored for single files) + with open(model_path) as f: + return json.load(f), model_path.parent - # Handle None overrides - if overrides is None: - overrides = {} - - # Validate that all override keys will be matched - override_keys = set(overrides.keys()) - - steps: list[ProcessorStep] = [] - for step_entry in loaded_config["steps"]: - # Check if step uses registry name or module path - if "registry_name" in step_entry: - # Load from registry - try: - step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) - step_key = step_entry["registry_name"] - except KeyError as e: - raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e - else: - # Fall back to module path loading for backward compatibility - full_class_path = step_entry["class"] - module_path, class_name = full_class_path.rsplit(".", 1) - - # Import the module containing the step class - try: - module = importlib.import_module(module_path) - step_class = getattr(module, class_name) - step_key = class_name - except (ImportError, AttributeError) as e: - raise ImportError( - f"Failed to load processor step '{full_class_path}'. " - f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " - f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. " - f"Error: {str(e)}" - ) from e - - # Instantiate the step with its config + else: + # Hub: download specified config try: - saved_cfg = step_entry.get("config", {}) - step_overrides = overrides.get(step_key, {}) - merged_cfg = {**saved_cfg, **step_overrides} - step_instance: ProcessorStep = step_class(**merged_cfg) + config_path = hf_hub_download( + repo_id=model_id, + filename=config_filename, + repo_type="model", + **hub_download_kwargs, + ) - # Track which override keys were used - if step_key in override_keys: - override_keys.discard(step_key) + with open(config_path) as f: + return json.load(f), Path(config_path).parent except Exception as e: - step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) - raise ValueError( - f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " - f"Error: {str(e)}" + raise FileNotFoundError( + f"Could not find '{config_filename}' on the HuggingFace Hub at '{model_id}'" ) from e - # Load state if available - if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): - if Path(source).is_dir(): - # Local path - read directly - state_path = str(base_path / step_entry["state_file"]) - else: - # Hugging Face Hub - download the state file - state_path = hf_hub_download( - source, - step_entry["state_file"], - repo_type="model", - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - token=token, - cache_dir=cache_dir, - local_files_only=local_files_only, - revision=revision, - ) + @classmethod + def _validate_loaded_config( + cls, model_id: str, loaded_config: dict[str, Any], config_filename: str + ) -> None: + """Validate that a config was loaded and is a valid processor config. - step_instance.load_state_dict(load_file(state_path)) + This method validates processor config format with intelligent migration detection: + + **Config Format Validation**: + - Use _is_processor_config() to validate structure + - Must have "steps" field with list of step configurations + - Each step needs "class" or "registry_name" + - If validation fails AND local directory: Check for migration need + - If migration needed: Raise ProcessorMigrationError with command + - If no migration: Raise ValueError with helpful error message + + **Migration Detection Logic**: + - Only triggered for local directories (not Hub repos) + - Analyzes all JSON files in directory to detect old LeRobot models + - Provides exact migration command with model path + + Args: + model_id: The model identifier (used for migration detection) + loaded_config: The loaded config dictionary (guaranteed non-None) + config_filename: The config filename that was loaded (for error messages) + + Raises: + ValueError: If config format is invalid + ProcessorMigrationError: If model needs migration to processor format + """ + # Validate that this is actually a processor config + if not cls._is_processor_config(loaded_config): + if Path(model_id).is_dir() and cls._should_suggest_migration(Path(model_id)): + cls._suggest_processor_migration( + model_id, + f"Config file '{config_filename}' is not a valid processor configuration", + ) + raise ValueError( + f"Config file '{config_filename}' is not a valid processor configuration. " + f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}" + ) + + @classmethod + def _build_steps_with_overrides( + cls, + loaded_config: dict[str, Any], + overrides: dict[str, Any], + model_id: str, + base_path: Path | None, + hub_download_kwargs: dict[str, Any], + ) -> tuple[list[ProcessorStep], set[str]]: + """Build all processor steps with overrides and state loading. + + This method orchestrates the complete step construction pipeline: + + **For each step in loaded_config["steps"]**: + + 1. **Class Resolution** (via _resolve_step_class): + - **If "registry_name" exists**: Look up in ProcessorStepRegistry + Example: {"registry_name": "normalize_step"} -> Get registered class + - **Else use "class" field**: Dynamic import from full module path + Example: {"class": "lerobot.processor.normalize.NormalizeStep"} + - **Result**: (step_class, step_key) where step_key is used for overrides + + 2. **Step Instantiation** (via _instantiate_step): + - **Merge configs**: saved_config + user_overrides + - **Override priority**: User overrides take precedence over saved config + - **Example**: saved={"mean": 0.0}, override={"mean": 1.0} -> final={"mean": 1.0} + - **Result**: Instantiated ProcessorStep object + + 3. **State Loading** (via _load_step_state): + - **If step has "state_file"**: Load tensor state from .safetensors + - **Local first**: Check base_path/state_file.safetensors + - **Hub fallback**: Download state file if not found locally + - **Optional**: Only load if step has load_state_dict method + + 4. **Override Tracking**: + - **Track used overrides**: Remove step_key from remaining set + - **Purpose**: Validate all user overrides were applied (detect typos) + + **Error Handling**: + - Class resolution errors -> ImportError with helpful message + - Instantiation errors -> ValueError with config details + - State loading errors -> Propagated from load_state_dict + + Args: + loaded_config: The loaded processor configuration (must have "steps" field) + overrides: User-provided parameter overrides (keyed by class/registry name) + model_id: The model identifier (needed for Hub state file downloads) + base_path: Local directory path for finding state files + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Returns: + Tuple of (instantiated_steps_list, unused_override_keys) + - instantiated_steps_list: List of ready-to-use ProcessorStep instances + - unused_override_keys: Override keys that didn't match any step (for validation) + + Raises: + ImportError: If a step class cannot be imported or found in registry + ValueError: If a step cannot be instantiated with its configuration + """ + steps: list[ProcessorStep] = [] + override_keys = set(overrides.keys()) + + for step_entry in loaded_config["steps"]: + # 1. Get step class and key + step_class, step_key = cls._resolve_step_class(step_entry) + + # 2. Instantiate step with overrides + step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides) + + # 3. Load step state if available + cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs) + + # 4. Track used overrides + if step_key in override_keys: + override_keys.discard(step_key) steps.append(step_instance) - # Check for unused override keys - if override_keys: - available_keys = [] - for step_entry in loaded_config["steps"]: - if "registry_name" in step_entry: - available_keys.append(step_entry["registry_name"]) - else: - full_class_path = step_entry["class"] - class_name = full_class_path.rsplit(".", 1)[1] - available_keys.append(class_name) + return steps, override_keys - raise KeyError( - f"Override keys {list(override_keys)} do not match any step in the saved configuration. " - f"Available step keys: {available_keys}. " - f"Make sure override keys match exact step class names or registry names." + @classmethod + def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]: + """Resolve step class from registry or import path. + + This method implements a two-tier resolution strategy: + + **Tier 1: Registry-based resolution** (preferred): + - **If "registry_name" in step_entry**: Look up in ProcessorStepRegistry + - **Advantage**: Faster, no imports needed, guaranteed compatibility + - **Example**: {"registry_name": "normalize_step"} -> Get pre-registered class + - **Error**: KeyError if registry_name not found -> Convert to ImportError + + **Tier 2: Dynamic import fallback**: + - **Else use "class" field**: Full module.ClassName import path + - **Process**: Split "module.path.ClassName" into module + class parts + - **Import**: Use importlib.import_module() + getattr() + - **Example**: "lerobot.processor.normalize.NormalizeStep" + a. Import module: "lerobot.processor.normalize" + b. Get class: getattr(module, "NormalizeStep") + - **step_key**: Use class_name ("NormalizeStep") for overrides + + **Override Key Strategy**: + - Registry steps: Use registry_name ("normalize_step") + - Import steps: Use class_name ("NormalizeStep") + - This allows users to override with: {"normalize_step": {...}} or {"NormalizeStep": {...}} + + **Error Handling**: + - Registry KeyError -> ImportError with registry context + - Import/Attribute errors -> ImportError with helpful suggestions + - All errors include troubleshooting guidance + + Args: + step_entry: The step configuration dictionary (must have "registry_name" or "class") + + Returns: + Tuple of (step_class, step_key) + - step_class: The resolved ProcessorStep class (ready for instantiation) + - step_key: The key used for user overrides (registry_name or class_name) + + Raises: + ImportError: If step class cannot be loaded from registry or import path + """ + if "registry_name" in step_entry: + try: + step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + return step_class, step_entry["registry_name"] + except KeyError as e: + raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e + else: + # Fallback to dynamic import using the full class path + full_class_path = step_entry["class"] + module_path, class_name = full_class_path.rsplit(".", 1) + + try: + module = importlib.import_module(module_path) + step_class = getattr(module, class_name) + return step_class, class_name + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to load processor step '{full_class_path}'. " + f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " + f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. " + f"Error: {str(e)}" + ) from e + + @classmethod + def _instantiate_step( + cls, + step_entry: dict[str, Any], + step_class: type[ProcessorStep], + step_key: str, + overrides: dict[str, Any], + ) -> ProcessorStep: + """Instantiate a single processor step with config overrides. + + This method handles the configuration merging and instantiation logic: + + **Configuration Merging Strategy**: + 1. **Extract saved config**: Get step_entry.get("config", {}) from saved pipeline + - Example: {"config": {"mean": 0.0, "std": 1.0}} + 2. **Extract user overrides**: Get overrides.get(step_key, {}) for this step + - Example: overrides = {"NormalizeStep": {"mean": 2.0, "device": "cuda"}} + 3. **Merge with priority**: {**saved_cfg, **step_overrides} + - **Override priority**: User values override saved values + - **Result**: {"mean": 2.0, "std": 1.0, "device": "cuda"} + + **Instantiation Process**: + - **Call constructor**: step_class(**merged_cfg) + - **Example**: NormalizeStep(mean=2.0, std=1.0, device="cuda") + + **Error Handling**: + - **Any exception during instantiation**: Convert to ValueError + - **Include context**: step name, attempted config, original error + - **Purpose**: Help users debug configuration issues + - **Common causes**: + a. Invalid parameter types (str instead of float) + b. Missing required parameters + c. Incompatible parameter combinations + + Args: + step_entry: The step configuration from saved config (contains "config" dict) + step_class: The step class to instantiate (already resolved) + step_key: The key used for overrides ("registry_name" or class name) + overrides: User-provided parameter overrides (keyed by step_key) + + Returns: + The instantiated processor step (ready for use) + + Raises: + ValueError: If step cannot be instantiated, with detailed error context + """ + try: + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + return step_class(**merged_cfg) + except Exception as e: + step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) + raise ValueError( + f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " + f"Error: {str(e)}" + ) from e + + @classmethod + def _load_step_state( + cls, + step_instance: ProcessorStep, + step_entry: dict[str, Any], + model_id: str, + base_path: Path | None, + hub_download_kwargs: dict[str, Any], + ) -> None: + """Load state dictionary for a processor step if available. + + This method implements conditional state loading with local/Hub fallback: + + **Precondition Checks** (early return if not met): + 1. **"state_file" in step_entry**: Step config specifies a state file + - **If missing**: Step has no saved state (e.g., stateless transforms) + 2. **hasattr(step_instance, "load_state_dict")**: Step supports state loading + - **If missing**: Step doesn't implement state loading (rare) + + **State File Resolution Strategy**: + 1. **Local file priority**: Check base_path/state_filename exists + - **Advantage**: Faster, no network calls + - **Example**: "/models/my_model/normalize_step_0.safetensors" + - **Use case**: Loading from local saved model directory + + 2. **Hub download fallback**: Download state file from repository + - **When triggered**: Local file not found or base_path is None + - **Process**: Use hf_hub_download with same parameters as config + - **Example**: Download "normalize_step_0.safetensors" from "user/repo" + - **Result**: Downloaded to local cache, path returned + + **State Loading Process**: + - **Load tensors**: Use safetensors.torch.load_file() + - **Apply to step**: Call step_instance.load_state_dict(tensor_dict) + - **In-place modification**: Updates step's internal tensor state + + **Common state file examples**: + - "normalize_step_0.safetensors" - normalization statistics + - "custom_step_1.safetensors" - learned parameters + - "tokenizer_step_2.safetensors" - vocabulary embeddings + + Args: + step_instance: The step instance to load state into (must have load_state_dict) + step_entry: The step configuration dictionary (may contain "state_file") + model_id: The model identifier (used for Hub downloads if needed) + base_path: Local directory path for finding state files (None for Hub-only) + hub_download_kwargs: Parameters for hf_hub_download (tokens, cache, etc.) + + Note: + This method modifies step_instance in-place and returns None. + If state loading fails, exceptions from load_state_dict propagate. + """ + if "state_file" not in step_entry or not hasattr(step_instance, "load_state_dict"): + return + + state_filename = step_entry["state_file"] + + # Try local file first + if base_path and (base_path / state_filename).exists(): + state_path = str(base_path / state_filename) + else: + # Download from Hub + state_path = hf_hub_download( + repo_id=model_id, + filename=state_filename, + repo_type="model", + **hub_download_kwargs, ) - return cls(steps, loaded_config.get("name", "RobotProcessor")) + step_instance.load_state_dict(load_file(state_path)) + + @classmethod + def _validate_overrides_used( + cls, remaining_override_keys: set[str], loaded_config: dict[str, Any] + ) -> None: + """Validate that all provided overrides were used. + + This method ensures user overrides are valid to catch typos and configuration errors: + + **Validation Logic**: + 1. **If remaining_override_keys is empty**: All overrides were used -> Success + - **Early return**: No validation needed + - **Normal case**: User provided correct override keys + + 2. **If remaining_override_keys has entries**: Some overrides unused -> Error + - **Root cause**: User provided keys that don't match any step + - **Common issues**: + a. Typos in step names ("NormalizStep" vs "NormalizeStep") + b. Using wrong key type (class name vs registry name) + c. Step doesn't exist in saved pipeline + + **Helpful Error Generation**: + - **Extract available keys**: Build list of valid override keys from config + a. **Registry steps**: Use "registry_name" directly + b. **Import steps**: Extract class name from "class" field + - Example: "lerobot.processor.normalize.NormalizeStep" -> "NormalizeStep" + - **Error message includes**: + a. Invalid keys provided by user + b. List of valid keys they can use + c. Guidance about registry vs class names + + **Override Key Resolution Rules**: + - Steps with "registry_name": Use registry_name for overrides + - Steps with "class": Use final class name for overrides + - Users must match these exact keys in their overrides dict + + Args: + remaining_override_keys: Override keys that weren't matched to any step + loaded_config: The loaded processor configuration (contains "steps" list) + + Raises: + KeyError: If any override keys were not used, with helpful error message + """ + if not remaining_override_keys: + return + + available_keys = [ + step.get("registry_name") or step["class"].rsplit(".", 1)[1] for step in loaded_config["steps"] + ] + + raise KeyError( + f"Override keys {list(remaining_override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." + ) + + @classmethod + def _should_suggest_migration(cls, model_path: Path) -> bool: + """Check if directory has JSON files but no processor configs. + + This method implements smart migration detection to avoid false positives: + + **Decision Logic**: + 1. **No JSON files found**: Return False + - **Reason**: Empty directory or only non-config files + - **Example**: Directory with only .safetensors, .md files + - **Action**: No migration needed + + 2. **JSON files exist**: Analyze each file + - **Goal**: Determine if ANY file is a valid processor config + - **Process**: + a. Try to parse each .json file + b. Skip files with JSON parse errors (malformed) + c. Check if parsed config passes _is_processor_config() + - **If ANY valid processor found**: Return False (no migration) + - **If NO valid processors found**: Return True (migration needed) + + **Examples**: + - **No migration**: ["processor.json", "config.json"] where processor.json is valid + - **Migration needed**: ["config.json", "train.json"] where both are model configs + - **No migration**: [] (empty directory) + - **Migration needed**: ["old_model_config.json"] with old LeRobot format + + **Why this works**: + - **Precise detection**: Only suggests migration for actual old LeRobot models + - **Avoids false positives**: Won't trigger on other HuggingFace model types + - **Graceful handling**: Ignores malformed JSON files + + Args: + model_path: Path to local directory to analyze + + Returns: + True if directory has JSON configs but none are processor configs (migration needed) + False if no JSON files or at least one valid processor config exists + """ + json_files = list(model_path.glob("*.json")) + if len(json_files) == 0: + return False + + # Check if any JSON file is a processor config + for json_file in json_files: + try: + with open(json_file) as f: + config = json.load(f) + + if cls._is_processor_config(config): + return False # Found at least one processor config, no migration needed + + except (json.JSONDecodeError, OSError): + # Skip files that can't be parsed as JSON + continue + + # Have JSON files but no processor configs - suggest migration + return True + + @classmethod + def _is_processor_config(cls, config: dict) -> bool: + """Check if config follows DataProcessorPipeline format. + + This method validates the processor configuration structure: + + **Required Structure Validation**: + 1. **"steps" field existence**: Must have top-level "steps" key + - **If missing**: Not a processor config (e.g., model config, train config) + - **Example invalid**: {"type": "act", "hidden_dim": 256} + + 2. **"steps" field type**: Must be a list, not other types + - **If not list**: Invalid format + - **Example invalid**: {"steps": "some_string"} or {"steps": {"key": "value"}} + + 3. **Empty steps validation**: Empty list is valid + - **If len(steps) == 0**: Return True immediately + - **Use case**: Empty processor pipeline (no-op) + - **Example valid**: {"name": "EmptyProcessor", "steps": []} + + **Individual Step Validation** (for non-empty steps): + For each step in the steps list: + 1. **Step type**: Must be a dictionary + - **If not dict**: Invalid step format + - **Example invalid**: ["string_step", 123, true] + + 2. **Step identifier**: Must have either "class" OR "registry_name" + - **"registry_name"**: Registered step (preferred) + Example: {"registry_name": "normalize_step", "config": {...}} + - **"class"**: Full import path + Example: {"class": "lerobot.processor.normalize.NormalizeStep"} + - **If neither**: Invalid step (can't resolve class) + - **If both**: Also valid (registry_name takes precedence) + + **Valid Processor Config Examples**: + - {"steps": []} - Empty processor + - {"steps": [{"registry_name": "normalize"}]} - Registry step + - {"steps": [{"class": "my.module.Step"}]} - Import step + - {"name": "MyProcessor", "steps": [...]} - With name + + **Invalid Config Examples**: + - {"type": "act"} - Missing "steps" + - {"steps": "normalize"} - Steps not a list + - {"steps": [{}]} - Step missing class/registry_name + - {"steps": ["string"]} - Step not a dict + + Args: + config: The configuration dictionary to validate + + Returns: + True if config follows valid DataProcessorPipeline format, False otherwise + """ + # Must have a "steps" field with a list of step configurations + if not isinstance(config.get("steps"), list): + return False + + steps = config["steps"] + if len(steps) == 0: + return True # Empty processor is valid + + # Each step must be a dict with either "class" or "registry_name" + for step in steps: + if not isinstance(step, dict): + return False + if not ("class" in step or "registry_name" in step): + return False + + return True + + @classmethod + def _suggest_processor_migration(cls, model_path: str | Path, original_error: str) -> None: + """Raise migration error when we detect JSON files but no processor configs. + + This method is called when migration detection determines that a model + directory contains configuration files but none are valid processor configs. + This typically indicates an old LeRobot model that needs migration. + + **When this is called**: + - User tries to load DataProcessorPipeline from local directory + - Directory contains JSON configuration files + - None of the JSON files follow processor config format + - _should_suggest_migration() returned True + + **Migration Command Generation**: + - Constructs exact command user needs to run + - Uses the migration script: migrate_policy_normalization.py + - Includes the model path automatically + - Example: "python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path /models/old_model" + + **Error Structure**: + - **Always raises**: ProcessorMigrationError (never returns) + - **Includes**: model_path, migration_command, original_error + - **Purpose**: Force user attention to migration need + - **User experience**: Clear actionable error with exact command to run + + **Migration Process**: + The suggested command will: + 1. Extract normalization stats from old model + 2. Create new processor configs (preprocessor + postprocessor) + 3. Remove normalization layers from model + 4. Save migrated model with processor pipeline + + Args: + model_path: Path to the model directory needing migration + original_error: The error that triggered migration detection (for context) + + Raises: + ProcessorMigrationError: Always raised (this method never returns normally) + """ + migration_command = ( + f"python src/lerobot/processor/migrate_policy_normalization.py --pretrained-path {model_path}" + ) + + raise ProcessorMigrationError(model_path, migration_command, original_error) def __len__(self) -> int: - """Return the number of steps in the processor.""" + """Returns the number of steps in the pipeline.""" return len(self.steps) - def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor: - """Indexing helper exposing underlying steps. - * ``int`` – returns the idx-th ProcessorStep. - * ``slice`` – returns a new RobotProcessor with the sliced steps. + def __getitem__(self, idx: int | slice) -> ProcessorStep | DataProcessorPipeline[TInput, TOutput]: + """Retrieves a step or a sub-pipeline by index or slice. + + Args: + idx: An integer index or a slice object. + + Returns: + A `ProcessorStep` if `idx` is an integer, or a new `DataProcessorPipeline` + containing the sliced steps. """ if isinstance(idx, slice): - return RobotProcessor(self.steps[idx], self.name) + # Return a new pipeline instance with the sliced steps. + return DataProcessorPipeline( + steps=self.steps[idx], + name=self.name, + to_transition=self.to_transition, + to_output=self.to_output, + before_step_hooks=self.before_step_hooks.copy(), + after_step_hooks=self.after_step_hooks.copy(), + ) return self.steps[idx] def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed before every processor step.""" + """Registers a function to be called before each step. + + Args: + fn: A callable that accepts the step index and the current transition. + """ self.before_step_hooks.append(fn) def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered before_step hook. + """Unregisters a 'before_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.before_step_hooks.remove(fn) @@ -759,17 +1262,21 @@ class RobotProcessor(ModelHubMixin): ) from None def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Attach fn to be executed after every processor step.""" + """Registers a function to be called after each step. + + Args: + fn: A callable that accepts the step index and the current transition. + """ self.after_step_hooks.append(fn) def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): - """Remove a previously registered after_step hook. + """Unregisters an 'after_step' hook. Args: - fn: The exact function reference that was registered. Must be the same object. + fn: The exact function object that was previously registered. Raises: - ValueError: If the hook is not found in the registered hooks. + ValueError: If the hook is not found in the list. """ try: self.after_step_hooks.remove(fn) @@ -779,13 +1286,13 @@ class RobotProcessor(ModelHubMixin): ) from None def reset(self): - """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + """Resets the state of all stateful steps in the pipeline.""" for step in self.steps: if hasattr(step, "reset"): - step.reset() # type: ignore[attr-defined] + step.reset() def __repr__(self) -> str: - """Return a readable string representation of the processor.""" + """Provides a concise string representation of the pipeline.""" step_names = [step.__class__.__name__ for step in self.steps] if not step_names: @@ -793,472 +1300,417 @@ class RobotProcessor(ModelHubMixin): elif len(step_names) <= 3: steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]" else: - # Show first 2 and last 1 with ellipsis for long lists + # For long pipelines, show the first, second, and last steps. displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}" steps_repr = f"steps={len(step_names)}: [{displayed}]" parts = [f"name='{self.name}'", steps_repr] - return f"RobotProcessor({', '.join(parts)})" + return f"DataProcessorPipeline({', '.join(parts)})" def __post_init__(self): + """Validates that all provided steps are instances of `ProcessorStep`.""" for i, step in enumerate(self.steps): - if not callable(step): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" - ) + if not isinstance(step, ProcessorStep): + raise TypeError(f"Step {i} ({type(step).__name__}) must inherit from ProcessorStep") - fc = getattr(step, "feature_contract", None) - if not callable(fc): - raise TypeError( - f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" - ) + def transform_features( + self, initial_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Applies feature transformations from all steps sequentially. - def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + This method propagates a feature description dictionary through each step's + `transform_features` method, allowing the pipeline to statically determine + the output feature specification without processing any real data. + + Args: + initial_features: A dictionary describing the initial features. + + Returns: + The final feature description after all transformations. """ - Apply ALL steps in order. Each step must implement - feature_contract(features) and return a dict (full or incremental schema). - """ - features: dict[str, PolicyFeature] = deepcopy(initial_features) + features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = deepcopy(initial_features) for _, step in enumerate(self.steps): - out = step.feature_contract(features) - if not isinstance(out, dict): - raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + out = step.transform_features(features) features = out return features - -class ObservationProcessor: - """Base class for processors that modify only the observation component of a transition. - - Subclasses should override the `observation` method to implement custom observation processing. - This class handles the boilerplate of extracting and reinserting the processed observation - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class MyObservationScaler(ObservationProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def observation(self, observation): - return observation * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific observation processing logic. - """ - - def observation(self, observation): - """Process the observation component. + # Convenience methods for processing individual parts of a transition. + def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """Processes only the observation part of a transition through the pipeline. Args: - observation: The observation to process + observation: The observation dictionary. Returns: - The processed observation + The processed observation dictionary. """ - return observation + transition: EnvTransition = create_transition(observation=observation) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.OBSERVATION] + + def process_action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: + """Processes only the action part of a transition through the pipeline. + + Args: + action: The action data. + + Returns: + The processed action. + """ + transition: EnvTransition = create_transition(action=action) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.ACTION] + + def process_reward(self, reward: float | torch.Tensor) -> float | torch.Tensor: + """Processes only the reward part of a transition through the pipeline. + + Args: + reward: The reward value. + + Returns: + The processed reward. + """ + transition: EnvTransition = create_transition(reward=reward) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.REWARD] + + def process_done(self, done: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the done flag of a transition through the pipeline. + + Args: + done: The done flag. + + Returns: + The processed done flag. + """ + transition: EnvTransition = create_transition(done=done) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.DONE] + + def process_truncated(self, truncated: bool | torch.Tensor) -> bool | torch.Tensor: + """Processes only the truncated flag of a transition through the pipeline. + + Args: + truncated: The truncated flag. + + Returns: + The processed truncated flag. + """ + transition: EnvTransition = create_transition(truncated=truncated) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.TRUNCATED] + + def process_info(self, info: dict[str, Any]) -> dict[str, Any]: + """Processes only the info dictionary of a transition through the pipeline. + + Args: + info: The info dictionary. + + Returns: + The processed info dictionary. + """ + transition: EnvTransition = create_transition(info=info) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.INFO] + + def process_complementary_data(self, complementary_data: dict[str, Any]) -> dict[str, Any]: + """Processes only the complementary data part of a transition through the pipeline. + + Args: + complementary_data: The complementary data dictionary. + + Returns: + The processed complementary data dictionary. + """ + transition: EnvTransition = create_transition(complementary_data=complementary_data) + transformed_transition = self._forward(transition) + return transformed_transition[TransitionKey.COMPLEMENTARY_DATA] + + +# Type aliases for semantic clarity. +RobotProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] +PolicyProcessorPipeline: TypeAlias = DataProcessorPipeline[TInput, TOutput] + + +class ObservationProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the observation in a transition.""" + + @abstractmethod + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """Processes an observation dictionary. Subclasses must implement this method. + + Args: + observation: The input observation dictionary from the transition. + + Returns: + The processed observation dictionary. + """ + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - observation = transition.get(TransitionKey.OBSERVATION) - if observation is None: - return transition + """Applies the `observation` method to the transition's observation.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_observation = self.observation(observation) - # Create a new transition dict with the processed observation - new_transition = transition.copy() + observation = new_transition.get(TransitionKey.OBSERVATION) + if observation is None or not isinstance(observation, dict): + raise ValueError("ObservationProcessorStep requires an observation in the transition.") + + processed_observation = self.observation(observation.copy()) new_transition[TransitionKey.OBSERVATION] = processed_observation return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class ActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the action in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ActionProcessor: - """Base class for processors that modify only the action component of a transition. - - Subclasses should override the `action` method to implement custom action processing. - This class handles the boilerplate of extracting and reinserting the processed action - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class ActionClipping(ActionProcessor): - def __init__(self, min_val, max_val): - self.min_val = min_val - self.max_val = max_val - - def action(self, action): - return np.clip(action, self.min_val, self.max_val) - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific action processing logic. - """ - - def action(self, action): - """Process the action component. + @abstractmethod + def action( + self, action: PolicyAction | RobotAction | EnvAction + ) -> PolicyAction | RobotAction | EnvAction: + """Processes an action. Subclasses must implement this method. Args: - action: The action to process + action: The input action from the transition. Returns: - The processed action + The processed action. """ - return action + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - action = transition.get(TransitionKey.ACTION) + """Applies the `action` method to the transition's action.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) if action is None: - return transition + raise ValueError("ActionProcessorStep requires an action in the transition.") processed_action = self.action(action) - # Create a new transition dict with the processed action - new_transition = transition.copy() new_transition[TransitionKey.ACTION] = processed_action return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class RobotActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` for processing a `RobotAction` (a dictionary).""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class RewardProcessor: - """Base class for processors that modify only the reward component of a transition. - - Subclasses should override the `reward` method to implement custom reward processing. - This class handles the boilerplate of extracting and reinserting the processed reward - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class RewardScaler(RewardProcessor): - def __init__(self, scale_factor): - self.scale_factor = scale_factor - - def reward(self, reward): - return reward * self.scale_factor - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific reward processing logic. - """ - - def reward(self, reward): - """Process the reward component. + @abstractmethod + def action(self, action: RobotAction) -> RobotAction: + """Processes a `RobotAction`. Subclasses must implement this method. Args: - reward: The reward to process + action: The input `RobotAction` dictionary. Returns: - The processed reward + The processed `RobotAction`. """ - return reward + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - reward = transition.get(TransitionKey.REWARD) + """Applies the `action` method to the transition's action, ensuring it's a `RobotAction`.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is None or not isinstance(action, dict): + raise ValueError(f"Action should be a RobotAction type (dict), but got {type(action)}") + + processed_action = self.action(action.copy()) + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + +class PolicyActionProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` for processing a `PolicyAction` (a tensor or dict of tensors).""" + + @abstractmethod + def action(self, action: PolicyAction) -> PolicyAction: + """Processes a `PolicyAction`. Subclasses must implement this method. + + Args: + action: The input `PolicyAction`. + + Returns: + The processed `PolicyAction`. + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `action` method to the transition's action, ensuring it's a `PolicyAction`.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if not isinstance(action, PolicyAction): + raise ValueError(f"Action should be a PolicyAction type (tensor), but got {type(action)}") + + processed_action = self.action(action) + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + +class RewardProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the reward in a transition.""" + + @abstractmethod + def reward(self, reward) -> float | torch.Tensor: + """Processes a reward. Subclasses must implement this method. + + Args: + reward: The input reward from the transition. + + Returns: + The processed reward. + """ + ... + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Applies the `reward` method to the transition's reward.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + reward = new_transition.get(TransitionKey.REWARD) if reward is None: - return transition + raise ValueError("RewardProcessorStep requires a reward in the transition.") processed_reward = self.reward(reward) - # Create a new transition dict with the processed reward - new_transition = transition.copy() new_transition[TransitionKey.REWARD] = processed_reward return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class DoneProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'done' flag in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class DoneProcessor: - """Base class for processors that modify only the done flag of a transition. - - Subclasses should override the `done` method to implement custom done flag processing. - This class handles the boilerplate of extracting and reinserting the processed done flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class TimeoutDone(DoneProcessor): - def __init__(self, max_steps): - self.steps = 0 - self.max_steps = max_steps - - def done(self, done): - self.steps += 1 - return done or self.steps >= self.max_steps - - def reset(self): - self.steps = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific done flag processing logic. - """ - - def done(self, done): - """Process the done flag. + @abstractmethod + def done(self, done) -> bool | torch.Tensor: + """Processes a 'done' flag. Subclasses must implement this method. Args: - done: The done flag to process + done: The input 'done' flag from the transition. Returns: - The processed done flag + The processed 'done' flag. """ - return done + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - done = transition.get(TransitionKey.DONE) + """Applies the `done` method to the transition's 'done' flag.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + done = new_transition.get(TransitionKey.DONE) if done is None: - return transition + raise ValueError("DoneProcessorStep requires a done flag in the transition.") processed_done = self.done(done) - # Create a new transition dict with the processed done flag - new_transition = transition.copy() new_transition[TransitionKey.DONE] = processed_done return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class TruncatedProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'truncated' flag in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class TruncatedProcessor: - """Base class for processors that modify only the truncated flag of a transition. - - Subclasses should override the `truncated` method to implement custom truncated flag processing. - This class handles the boilerplate of extracting and reinserting the processed truncated flag - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class EarlyTruncation(TruncatedProcessor): - def __init__(self, threshold): - self.threshold = threshold - - def truncated(self, truncated): - # Additional truncation condition - return truncated or some_condition > self.threshold - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific truncated flag processing logic. - """ - - def truncated(self, truncated): - """Process the truncated flag. + @abstractmethod + def truncated(self, truncated) -> bool | torch.Tensor: + """Processes a 'truncated' flag. Subclasses must implement this method. Args: - truncated: The truncated flag to process + truncated: The input 'truncated' flag from the transition. Returns: - The processed truncated flag + The processed 'truncated' flag. """ - return truncated + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - truncated = transition.get(TransitionKey.TRUNCATED) + """Applies the `truncated` method to the transition's 'truncated' flag.""" + self._current_transition = transition.copy() + new_transition = self._current_transition + + truncated = new_transition.get(TransitionKey.TRUNCATED) if truncated is None: - return transition + raise ValueError("TruncatedProcessorStep requires a truncated flag in the transition.") processed_truncated = self.truncated(truncated) - # Create a new transition dict with the processed truncated flag - new_transition = transition.copy() new_transition[TransitionKey.TRUNCATED] = processed_truncated return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class InfoProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that specifically targets the 'info' dictionary in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class InfoProcessor: - """Base class for processors that modify only the info dictionary of a transition. - - Subclasses should override the `info` method to implement custom info processing. - This class handles the boilerplate of extracting and reinserting the processed info - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - - Example: - ```python - class InfoAugmenter(InfoProcessor): - def __init__(self): - self.step_count = 0 - - def info(self, info): - info = info.copy() # Create a copy to avoid modifying the original - info["steps"] = self.step_count - self.step_count += 1 - return info - - def reset(self): - self.step_count = 0 - ``` - - By inheriting from this class, you avoid writing repetitive code to handle transition dict - manipulation, focusing only on the specific info dictionary processing logic. - """ - - def info(self, info): - """Process the info dictionary. + @abstractmethod + def info(self, info) -> dict[str, Any]: + """Processes an 'info' dictionary. Subclasses must implement this method. Args: - info: The info dictionary to process + info: The input 'info' dictionary from the transition. Returns: - The processed info dictionary + The processed 'info' dictionary. """ - return info + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - info = transition.get(TransitionKey.INFO) - if info is None: - return transition + """Applies the `info` method to the transition's 'info' dictionary.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_info = self.info(info) - # Create a new transition dict with the processed info - new_transition = transition.copy() + info = new_transition.get(TransitionKey.INFO) + if info is None or not isinstance(info, dict): + raise ValueError("InfoProcessorStep requires an info dictionary in the transition.") + + processed_info = self.info(info.copy()) new_transition[TransitionKey.INFO] = processed_info return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class ComplementaryDataProcessorStep(ProcessorStep, ABC): + """An abstract `ProcessorStep` that targets the 'complementary_data' in a transition.""" - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class ComplementaryDataProcessor: - """Base class for processors that modify only the complementary data of a transition. - - Subclasses should override the `complementary_data` method to implement custom complementary data processing. - This class handles the boilerplate of extracting and reinserting the processed complementary data - into the transition dict, eliminating the need to implement the `__call__` method in subclasses. - """ - - def complementary_data(self, complementary_data): - """Process the complementary data. + @abstractmethod + def complementary_data(self, complementary_data) -> dict[str, Any]: + """Processes a 'complementary_data' dictionary. Subclasses must implement this method. Args: - complementary_data: The complementary data to process + complementary_data: The input 'complementary_data' from the transition. Returns: - The processed complementary data + The processed 'complementary_data' dictionary. """ - return complementary_data + ... def __call__(self, transition: EnvTransition) -> EnvTransition: - complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) - if complementary_data is None: - return transition + """Applies the `complementary_data` method to the transition's data.""" + self._current_transition = transition.copy() + new_transition = self._current_transition - processed_complementary_data = self.complementary_data(complementary_data) - # Create a new transition dict with the processed complementary data - new_transition = transition.copy() + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None or not isinstance(complementary_data, dict): + raise ValueError("ComplementaryDataProcessorStep requires complementary data in the transition.") + + processed_complementary_data = self.complementary_data(complementary_data.copy()) new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data return new_transition - def get_config(self) -> dict[str, Any]: - return {} - def state_dict(self) -> dict[str, torch.Tensor]: - return {} +class IdentityProcessorStep(ProcessorStep): + """A no-op processor step that returns the input transition and features unchanged. - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - return features - - -class IdentityProcessor: - """Identity processor that does nothing.""" + This can be useful as a placeholder or for debugging purposes. + """ def __call__(self, transition: EnvTransition) -> EnvTransition: + """Returns the transition without modification.""" return transition - def get_config(self) -> dict[str, Any]: - return {} - - def state_dict(self) -> dict[str, torch.Tensor]: - return {} - - def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: - pass - - def reset(self) -> None: - pass - - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Returns the features without modification.""" return features diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py new file mode 100644 index 000000000..74c534998 --- /dev/null +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -0,0 +1,52 @@ +from dataclasses import asdict, dataclass +from typing import Any + +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction + + +@dataclass +@ProcessorStepRegistry.register("robot_action_to_policy_action_processor") +class RobotActionToPolicyActionProcessorStep(ActionProcessorStep): + """Processor step to map a dictionary to a tensor action.""" + + motor_names: list[str] + + def action(self, action: RobotAction) -> PolicyAction: + if len(self.motor_names) != len(action): + raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}") + return torch.tensor([action[f"{name}.pos"] for name in self.motor_names]) + + def get_config(self) -> dict[str, Any]: + return asdict(self) + + def transform_features(self, features): + features[PipelineFeatureType.ACTION]["action"] = PolicyFeature( + type=FeatureType.ACTION, shape=(len(self.motor_names),) + ) + return features + + +@dataclass +@ProcessorStepRegistry.register("policy_action_to_robot_action_processor") +class PolicyActionToRobotActionProcessorStep(ActionProcessorStep): + """Processor step to map a policy action to a robot action.""" + + motor_names: list[str] + + def action(self, action: PolicyAction) -> RobotAction: + if len(self.motor_names) != len(action): + raise ValueError(f"Action must have {len(self.motor_names)} elements, got {len(action)}") + return {f"{name}.pos": action[i] for i, name in enumerate(self.motor_names)} + + def get_config(self) -> dict[str, Any]: + return asdict(self) + + def transform_features(self, features): + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 4fe4105a5..6cae5921f 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -13,20 +13,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from dataclasses import dataclass, field from typing import Any -from lerobot.configs.types import PolicyFeature -from lerobot.processor.pipeline import ( - ObservationProcessor, - ProcessorStepRegistry, -) +from lerobot.configs.types import PipelineFeatureType, PolicyFeature + +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @dataclass -@ProcessorStepRegistry.register(name="rename_processor") -class RenameProcessor(ObservationProcessor): - """Rename processor that renames keys in the observation.""" +@ProcessorStepRegistry.register(name="rename_observations_processor") +class RenameObservationsProcessorStep(ObservationProcessorStep): + """ + A processor step that renames keys in an observation dictionary. + + This step is useful for creating a standardized data interface by mapping keys + from an environment's format to the format expected by a LeRobot policy or + other downstream components. + + Attributes: + rename_map: A dictionary mapping from old key names to new key names. + Keys present in an observation that are not in this map will + be kept with their original names. + """ rename_map: dict[str, str] = field(default_factory=dict) @@ -43,9 +53,41 @@ class RenameProcessor(ObservationProcessor): def get_config(self) -> dict[str, Any]: return {"rename_map": self.rename_map} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: """Transforms: - Each key in the observation that appears in `rename_map` is renamed to its value. - Keys not in `rename_map` remain unchanged. """ - return {self.rename_map.get(k, k): v for k, v in features.items()} + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = features.copy() + new_features[PipelineFeatureType.OBSERVATION] = { + self.rename_map.get(k, k): v for k, v in features[PipelineFeatureType.OBSERVATION].items() + } + return new_features + + +def rename_stats(stats: dict[str, dict[str, Any]], rename_map: dict[str, str]) -> dict[str, dict[str, Any]]: + """ + Renames the top-level keys in a statistics dictionary using a provided mapping. + + This is a helper function typically used to keep normalization statistics + consistent with renamed observation or action features. It performs a defensive + deep copy to avoid modifying the original `stats` dictionary. + + Args: + stats: A nested dictionary of statistics, where top-level keys are + feature names (e.g., `{"observation.state": {"mean": 0.5}}`). + rename_map: A dictionary mapping old feature names to new feature names. + + Returns: + A new statistics dictionary with its top-level keys renamed. Returns an + empty dictionary if the input `stats` is empty. + """ + if not stats: + return {} + renamed: dict[str, dict[str, Any]] = {} + for old_key, sub_stats in stats.items(): + new_key = rename_map.get(old_key, old_key) + renamed[new_key] = deepcopy(sub_stats) if sub_stats is not None else {} + return renamed diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py new file mode 100644 index 000000000..23db7b5e3 --- /dev/null +++ b/src/lerobot/processor/tokenizer_processor.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script defines a processor for tokenizing natural language instructions from an environment transition. + +It uses a tokenizer from the Hugging Face `transformers` library to convert task descriptions (text) into +token IDs and attention masks, which are then added to the observation dictionary. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.import_utils import _transformers_available + +from .core import EnvTransition, TransitionKey +from .pipeline import ObservationProcessorStep, ProcessorStepRegistry + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer +else: + AutoTokenizer = None + + +@dataclass +@ProcessorStepRegistry.register(name="tokenizer_processor") +class TokenizerProcessorStep(ObservationProcessorStep): + """ + Processor step to tokenize a natural language task description. + + This step extracts a task string from the `complementary_data` of an `EnvTransition`, + tokenizes it using a Hugging Face `transformers` tokenizer, and adds the resulting + token IDs and attention mask to the `observation` dictionary. + + Requires the `transformers` library to be installed. + + Attributes: + tokenizer_name: The name of a pretrained tokenizer from the Hugging Face Hub (e.g., "bert-base-uncased"). + tokenizer: A pre-initialized tokenizer object. If provided, `tokenizer_name` is ignored. + max_length: The maximum length to pad or truncate sequences to. + task_key: The key in `complementary_data` where the task string is stored. + padding_side: The side to pad on ('left' or 'right'). + padding: The padding strategy ('max_length', 'longest', etc.). + truncation: Whether to truncate sequences longer than `max_length`. + input_tokenizer: The internal tokenizer instance, loaded during initialization. + """ + + tokenizer_name: str | None = None + tokenizer: Any | None = None # Use `Any` for compatibility without a hard dependency + max_length: int = 512 + task_key: str = "task" + padding_side: str = "right" + padding: str = "max_length" + truncation: bool = True + + # Internal tokenizer instance (not part of the config) + input_tokenizer: Any = field(default=None, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the tokenizer after the dataclass is created. + + It checks for the availability of the `transformers` library and loads the tokenizer + either from a provided object or by name from the Hugging Face Hub. + + Raises: + ImportError: If the `transformers` library is not installed. + ValueError: If neither `tokenizer` nor `tokenizer_name` is provided. + """ + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use TokenizerProcessorStep." + ) + + if self.tokenizer is not None: + # Use provided tokenizer object directly + self.input_tokenizer = self.tokenizer + elif self.tokenizer_name is not None: + if AutoTokenizer is None: + raise ImportError("AutoTokenizer is not available") + self.input_tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + else: + raise ValueError( + "Either 'tokenizer' or 'tokenizer_name' must be provided. " + "Pass a tokenizer object directly or a tokenizer name to auto-load." + ) + + def get_task(self, transition: EnvTransition) -> list[str] | None: + """ + Extracts the task description(s) from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of task strings, or None if the task key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + raise ValueError("Complementary data is None so no task can be extracted from it") + + task = complementary_data[self.task_key] + if task is None: + raise ValueError("Task extracted from Complementary data is None") + + # Standardize to a list of strings for the tokenizer + if isinstance(task, str): + return [task] + elif isinstance(task, list) and all(isinstance(t, str) for t in task): + return task + + return None + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + """ + Tokenizes the task description and adds it to the observation dictionary. + + This method retrieves the task, tokenizes it, moves the resulting tensors to the + same device as other data in the transition, and updates the observation. + + Args: + observation: The original observation dictionary. + + Returns: + The updated observation dictionary including token IDs and an attention mask. + """ + task = self.get_task(self.transition) + if task is None: + raise ValueError("Task cannot be None") + + # Tokenize the task (this will create CPU tensors) + tokenized_prompt = self._tokenize_text(task) + + # Detect the device from existing tensors in the transition to ensure consistency + target_device = self._detect_device(self.transition) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_prompt = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_prompt.items() + } + + # Create a new observation dict to avoid modifying the original in place + new_observation = dict(observation) + + # Add tokenized data to the observation + new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] + new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + + return new_observation + + def _detect_device(self, transition: EnvTransition) -> torch.device | None: + """ + Detects the torch.device from existing tensors in the transition. + + It checks tensors in the observation dictionary first, then the action tensor. + + Args: + transition: The environment transition. + + Returns: + The detected `torch.device`, or None if no tensors are found. + """ + # Check observation tensors first (most likely place to find tensors) + observation = transition.get(TransitionKey.OBSERVATION) + if observation: + for value in observation.values(): + if isinstance(value, torch.Tensor): + return value.device + + # Fallback to checking the action tensor + action = transition.get(TransitionKey.ACTION) + if isinstance(action, torch.Tensor): + return action.device + + return None # No tensors found, default will be CPU + + def _tokenize_text(self, text: str | list[str]) -> dict[str, torch.Tensor]: + """ + A wrapper around the tokenizer call. + + Args: + text: A string or list of strings to tokenize. + + Returns: + A dictionary containing tokenized 'input_ids' and 'attention_mask' as PyTorch tensors. + """ + return self.input_tokenizer( + text, + max_length=self.max_length, + truncation=self.truncation, + padding=self.padding, + padding_side=self.padding_side, + return_tensors="pt", + ) + + def get_config(self) -> dict[str, Any]: + """ + Returns the serializable configuration of the processor. + + Note: The tokenizer object itself is not serialized. If the processor was initialized + with a tokenizer name, that name will be included in the config. + + Returns: + A dictionary with the processor's configuration parameters. + """ + config = { + "max_length": self.max_length, + "task_key": self.task_key, + "padding_side": self.padding_side, + "padding": self.padding, + "truncation": self.truncation, + } + + # Only save tokenizer_name if it was used to create the tokenizer + if self.tokenizer_name is not None and self.tokenizer is None: + config["tokenizer_name"] = self.tokenizer_name + + return config + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Adds feature definitions for the language tokens and attention mask. + + This updates the policy features dictionary to include the new data added to the + observation, ensuring downstream components are aware of their shape and type. + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + # Add a feature for the token IDs if it doesn't already exist + if OBS_LANGUAGE_TOKENS not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_TOKENS] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + # Add a feature for the attention mask if it doesn't already exist + if OBS_LANGUAGE_ATTENTION_MASK not in features[PipelineFeatureType.OBSERVATION]: + features[PipelineFeatureType.OBSERVATION][OBS_LANGUAGE_ATTENTION_MASK] = PolicyFeature( + type=FeatureType.LANGUAGE, shape=(self.max_length,) + ) + + return features diff --git a/src/lerobot/record.py b/src/lerobot/record.py index f39a05fb5..d09b017e4 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -21,11 +21,12 @@ Example: lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \ + --robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/ \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ + --display_data=true # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ # --teleop.port=/dev/tty.usbmodem58760431551 \ @@ -59,9 +60,10 @@ lerobot-record \ import logging import time -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path from pprint import pformat +from typing import Any from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 @@ -72,10 +74,20 @@ from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features +from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import ( + PolicyAction, + PolicyProcessorPipeline, + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_processors, +) +from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -149,6 +161,8 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 + # Rename map for the observation to override the image and state keys + rename_map: dict[str, str] = field(default_factory=dict) def __post_init__(self): if self.single_task is None: @@ -187,14 +201,55 @@ class RecordConfig: return ["policy"] +""" --------------- record_loop() data flow -------------------------- + [ Robot ] + V + [ robot.get_observation() ] ---> raw_obs + V + [ robot_observation_processor ] ---> processed_obs + V + .-----( ACTION LOGIC )------------------. + V V + [ From Teleoperator ] [ From Policy ] + | | + | [teleop.get_action] -> raw_action | [predict_action] + | | | | + | V | V + | [teleop_action_processor] | | + | | | | + '---> processed_teleop_action '---> processed_policy_action + | | + '-------------------------.-------------' + V + [ robot_action_processor ] --> robot_action_to_send + V + [ robot.send_action() ] -- (Robot Executes) + V + ( Save to Dataset ) + V + ( Rerun Log / Loop Wait ) +""" + + @safe_stop_image_writer def record_loop( robot: Robot, events: dict, fps: int, + teleop_action_processor: RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction + ], # runs after teleop + robot_action_processor: RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction + ], # runs before robot + robot_observation_processor: RobotProcessorPipeline[ + RobotObservation, RobotObservation + ], # runs after robot dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None, + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None, control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, @@ -226,9 +281,11 @@ def record_loop( "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot." ) - # if policy is given it needs cleaning up - if policy is not None: + # Reset policy and processor if they are provided + if policy is not None and preprocessor is not None and postprocessor is not None: policy.reset() + preprocessor.reset() + postprocessor.reset() timestamp = 0 start_episode_t = time.perf_counter() @@ -239,32 +296,46 @@ def record_loop( events["exit_early"] = False break - observation = robot.get_observation() + # Get robot observation + obs = robot.get_observation() + + # Applies a pipeline to the raw robot observation, default is IdentityProcessor + obs_processed = robot_observation_processor(obs) if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") - if policy is not None: + # Get action from either policy or teleop + if policy is not None and preprocessor is not None and postprocessor is not None: action_values = predict_action( - observation_frame, - policy, - get_safe_torch_device(policy.config.device), - policy.config.use_amp, + observation=observation_frame, + policy=policy, + device=get_safe_torch_device(policy.config.device), + preprocessor=preprocessor, + postprocessor=postprocessor, + use_amp=policy.config.use_amp, task=single_task, robot_type=robot.robot_type, ) - action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} + + action_names = dataset.features["action"]["names"] + act_processed_policy: RobotAction = { + f"{name}": float(action_values[i]) for i, name in enumerate(action_names) + } + elif policy is None and isinstance(teleop, Teleoperator): - action = teleop.get_action() + act = teleop.get_action() + + # Applies a pipeline to the raw teleop action, default is IdentityProcessor + act_processed_teleop = teleop_action_processor((act, obs)) + elif policy is None and isinstance(teleop, list): - # TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline) arm_action = teleop_arm.get_action() arm_action = {f"arm_{k}": v for k, v in arm_action.items()} - keyboard_action = teleop_keyboard.get_action() base_action = robot._from_keyboard_to_base_action(keyboard_action) - - action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + act_processed_teleop = teleop_action_processor((act, obs)) else: logging.info( "No policy or teleoperator provided, skipping action generation." @@ -273,17 +344,28 @@ def record_loop( ) continue - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - sent_action = robot.send_action(action) + # Applies a pipeline to the action, default is IdentityProcessor + if policy is not None and act_processed_policy is not None: + action_values = act_processed_policy + robot_action_to_send = robot_action_processor((act_processed_policy, obs)) + else: + action_values = act_processed_teleop + robot_action_to_send = robot_action_processor((act_processed_teleop, obs)) + # Send action to robot + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. action = postprocessor.process(action) + # TODO(steven, pepijn, adil): we should use a pipeline step to clip the action, so the sent action is the action that we input to the robot. + _sent_action = robot.send_action(robot_action_to_send) + + # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") + action_frame = build_dataset_frame(dataset.features, action_values, prefix="action") frame = {**observation_frame, **action_frame, "task": single_task} dataset.add_frame(frame) if display_data: - log_rerun_data(observation, action) + log_rerun_data(observation=obs_processed, action=action_values) dt_s = time.perf_counter() - start_loop_t busy_wait(1 / fps - dt_s) @@ -301,9 +383,22 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None - action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video) - obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video) - dataset_features = {**action_features, **obs_features} + teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + + dataset_features = combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=teleop_action_processor, + initial_features=create_initial_features( + action=robot.action_features + ), # TODO(steven, pepijn): in future this should be come from teleop or policy + use_videos=cfg.dataset.video, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_observation_processor, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=cfg.dataset.video, + ), + ) if cfg.resume: dataset = LeRobotDataset( @@ -335,6 +430,18 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + preprocessor = None + postprocessor = None + if cfg.policy is not None: + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.policy.device}, + "rename_observations_processor": {"rename_map": cfg.dataset.rename_map}, + }, + ) robot.connect() if teleop is not None: @@ -350,8 +457,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot=robot, events=events, fps=cfg.dataset.fps, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, @@ -368,6 +480,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot=robot, events=events, fps=cfg.dataset.fps, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, teleop=teleop, control_time_s=cfg.dataset.reset_time_s, single_task=cfg.dataset.single_task, diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index cd76d114e..6761e3f4f 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -23,7 +23,7 @@ lerobot-replay \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ --dataset.repo_id=aliberts/record-test \ - --dataset.episode=2 + --dataset.episode=0 ``` Example replay with bimanual so100: @@ -45,9 +45,11 @@ from dataclasses import asdict, dataclass from pathlib import Path from pprint import pformat -import draccus - +from lerobot.configs import parser from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor import ( + make_default_robot_action_processor, +) from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -55,7 +57,6 @@ from lerobot.robots import ( # noqa: F401 hope_jr, koch_follower, make_robot_from_config, - reachy2, so100_follower, so101_follower, ) @@ -86,11 +87,13 @@ class ReplayConfig: play_sounds: bool = True -@draccus.wrap() +@parser.wrap() def replay(cfg: ReplayConfig): init_logging() logging.info(pformat(asdict(cfg))) + robot_action_processor = make_default_robot_action_processor() + robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) @@ -109,7 +112,11 @@ def replay(cfg: ReplayConfig): for i, name in enumerate(dataset.features["action"]["names"]): action[name] = action_array[i] - robot.send_action(action) + robot_obs = robot.get_observation() + + processed_action = robot_action_processor((action, robot_obs)) + + _ = robot.send_action(processed_action) dt_s = time.perf_counter() - start_episode_t busy_wait(1 / dataset.fps - dt_s) diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/robots/so100_follower/__init__.py index b995aab13..5dc43ac3b 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/robots/so100_follower/__init__.py @@ -14,6 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig +from .config_so100_follower import SO100FollowerConfig from .so100_follower import SO100Follower -from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py index 561790e77..272b8c43f 100644 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ b/src/lerobot/robots/so100_follower/config_so100_follower.py @@ -39,35 +39,3 @@ class SO100FollowerConfig(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False - - -@RobotConfig.register_subclass("so100_follower_end_effector") -@dataclass -class SO100FollowerEndEffectorConfig(SO100FollowerConfig): - """Configuration for the SO100FollowerEndEffector robot.""" - - # Path to URDF file for kinematics - # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: - # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf - urdf_path: str | None = None - - # End-effector frame name in the URDF - target_frame_name: str = "gripper_frame_link" - - # Default bounds for the end-effector position (in meters) - end_effector_bounds: dict[str, list[float]] = field( - default_factory=lambda: { - "min": [-1.0, -1.0, -1.0], # min x, y, z - "max": [1.0, 1.0, 1.0], # max x, y, z - } - ) - - max_gripper_pos: float = 50 - - end_effector_step_sizes: dict[str, float] = field( - default_factory=lambda: { - "x": 0.02, - "y": 0.02, - "z": 0.02, - } - ) diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py new file mode 100644 index 000000000..56686d447 --- /dev/null +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + EnvTransition, + ObservationProcessorStep, + ProcessorStep, + ProcessorStepRegistry, + RobotAction, + RobotActionProcessorStep, + TransitionKey, +) +from lerobot.utils.rotation import Rotation + + +@ProcessorStepRegistry.register("ee_reference_and_delta") +@dataclass +class EEReferenceAndDelta(RobotActionProcessorStep): + """ + Computes a target end-effector pose from a relative delta command. + + This step takes a desired change in position and orientation (`target_*`) and applies it to a + reference end-effector pose to calculate an absolute target pose. The reference pose is derived + from the current robot joint positions using forward kinematics. + + The processor can operate in two modes: + 1. `use_latched_reference=True`: The reference pose is "latched" or saved at the moment the action + is first enabled. Subsequent commands are relative to this fixed reference. + 2. `use_latched_reference=False`: The reference pose is updated to the robot's current pose at + every step. + + Attributes: + kinematics: The robot's kinematic model for forward kinematics. + end_effector_step_sizes: A dictionary scaling the input delta commands. + motor_names: A list of motor names required for forward kinematics. + use_latched_reference: If True, latch the reference pose on enable; otherwise, always use the + current pose as the reference. + reference_ee_pose: Internal state storing the latched reference pose. + _prev_enabled: Internal state to detect the rising edge of the enable signal. + _command_when_disabled: Internal state to hold the last command while disabled. + """ + + kinematics: RobotKinematics + end_effector_step_sizes: dict + motor_names: list[str] + use_latched_reference: bool = ( + True # If True, latch reference on enable; if False, always use current pose + ) + use_ik_solution: bool = False + + reference_ee_pose: np.ndarray | None = field(default=None, init=False, repr=False) + _prev_enabled: bool = field(default=False, init=False, repr=False) + _command_when_disabled: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.use_ik_solution and "IK_solution" in self.transition.get(TransitionKey.COMPLEMENTARY_DATA): + q_raw = self.transition.get(TransitionKey.COMPLEMENTARY_DATA)["IK_solution"] + else: + q_raw = np.array( + [ + float(v) + for k, v in observation.items() + if isinstance(k, str) + and k.endswith(".pos") + and k.removesuffix(".pos") in self.motor_names + ], + dtype=float, + ) + + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + # Current pose from FK on measured joints + t_curr = self.kinematics.forward_kinematics(q_raw) + + enabled = bool(action.pop("enabled")) + tx = float(action.pop("target_x")) + ty = float(action.pop("target_y")) + tz = float(action.pop("target_z")) + wx = float(action.pop("target_wx")) + wy = float(action.pop("target_wy")) + wz = float(action.pop("target_wz")) + gripper_vel = float(action.pop("gripper_vel")) + + desired = None + + if enabled: + ref = t_curr + if self.use_latched_reference: + # Latched reference mode: latch reference at the rising edge + if not self._prev_enabled or self.reference_ee_pose is None: + self.reference_ee_pose = t_curr.copy() + ref = self.reference_ee_pose if self.reference_ee_pose is not None else t_curr + + delta_p = np.array( + [ + tx * self.end_effector_step_sizes["x"], + ty * self.end_effector_step_sizes["y"], + tz * self.end_effector_step_sizes["z"], + ], + dtype=float, + ) + r_abs = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + desired = np.eye(4, dtype=float) + desired[:3, :3] = ref[:3, :3] @ r_abs + desired[:3, 3] = ref[:3, 3] + delta_p + + self._command_when_disabled = desired.copy() + else: + # While disabled, keep sending the same command to avoid drift. + if self._command_when_disabled is None: + # If we've never had an enabled command yet, freeze current FK pose once. + self._command_when_disabled = t_curr.copy() + desired = self._command_when_disabled.copy() + + # Write action fields + pos = desired[:3, 3] + tw = Rotation.from_matrix(desired[:3, :3]).as_rotvec() + action["ee.x"] = float(pos[0]) + action["ee.y"] = float(pos[1]) + action["ee.z"] = float(pos[2]) + action["ee.wx"] = float(tw[0]) + action["ee.wy"] = float(tw[1]) + action["ee.wz"] = float(tw[2]) + action["ee.gripper_vel"] = gripper_vel + + self._prev_enabled = enabled + return action + + def reset(self): + """Resets the internal state of the processor.""" + self._prev_enabled = False + self.reference_ee_pose = None + self._command_when_disabled = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in [ + "enabled", + "target_x", + "target_y", + "target_z", + "target_wx", + "target_wy", + "target_wz", + "gripper_vel", + ]: + features[PipelineFeatureType.ACTION].pop(f"{feat}", None) + + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_vel"]: + features[PipelineFeatureType.ACTION][f"ee.{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + +@ProcessorStepRegistry.register("ee_bounds_and_safety") +@dataclass +class EEBoundsAndSafety(RobotActionProcessorStep): + """ + Clips the end-effector pose to predefined bounds and checks for unsafe jumps. + + This step ensures that the target end-effector pose remains within a safe operational workspace. + It also moderates the command to prevent large, sudden movements between consecutive steps. + + Attributes: + end_effector_bounds: A dictionary with "min" and "max" keys for position clipping. + max_ee_step_m: The maximum allowed change in position (in meters) between steps. + max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps. + _last_pos: Internal state storing the last commanded position. + _last_twist: Internal state storing the last commanded orientation. + """ + + end_effector_bounds: dict + max_ee_step_m: float = 0.05 + max_ee_twist_step_rad: float = 0.20 + _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) + _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + x = action["ee.x"] + y = action["ee.y"] + z = action["ee.z"] + wx = action["ee.wx"] + wy = action["ee.wy"] + wz = action["ee.wz"] + # TODO(Steven): ee.gripper_vel does not need to be bounded + + if None in (x, y, z, wx, wy, wz): + raise ValueError( + "Missing required end-effector pose components: x, y, z, wx, wy, wz must all be present in action" + ) + + pos = np.array([x, y, z], dtype=float) + twist = np.array([wx, wy, wz], dtype=float) + + # Clip position + pos = np.clip(pos, self.end_effector_bounds["min"], self.end_effector_bounds["max"]) + + # Check for jumps in position + if self._last_pos is not None: + dpos = pos - self._last_pos + n = float(np.linalg.norm(dpos)) + if n > self.max_ee_step_m and n > 0: + pos = self._last_pos + dpos * (self.max_ee_step_m / n) + raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m") + + self._last_pos = pos + self._last_twist = twist + + action["ee.x"] = float(pos[0]) + action["ee.y"] = float(pos[1]) + action["ee.z"] = float(pos[2]) + action["ee.wx"] = float(twist[0]) + action["ee.wy"] = float(twist[1]) + action["ee.wz"] = float(twist[2]) + return action + + def reset(self): + """Resets the last known position and orientation.""" + self._last_pos = None + self._last_twist = None + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_ee_to_joints") +@dataclass +class InverseKinematicsEEToJoints(RobotActionProcessorStep): + """ + Computes desired joint positions from a target end-effector pose using inverse kinematics (IK). + + This step translates a Cartesian command (position and orientation of the end-effector) into + the corresponding joint-space commands for each motor. + + Attributes: + kinematics: The robot's kinematic model for inverse kinematics. + motor_names: A list of motor names for which to compute joint positions. + q_curr: Internal state storing the last joint positions, used as an initial guess for the IK solver. + initial_guess_current_joints: If True, use the robot's current joint state as the IK guess. + If False, use the solution from the previous step. + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def action(self, action: RobotAction) -> RobotAction: + x = action.pop("ee.x") + y = action.pop("ee.y") + z = action.pop("ee.z") + wx = action.pop("ee.wx") + wy = action.pop("ee.wy") + wz = action.pop("ee.wz") + gripper_pos = action.pop("ee.gripper_pos") + + if None in (x, y, z, wx, wy, wz, gripper_pos): + raise ValueError( + "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action" + ) + + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = q_raw + else: # Use previous ik solution as initial guess + if self.q_curr is None: + self.q_curr = q_raw + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + + # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + # TODO: This is sentitive to order of motor_names = q_target mapping + for i, name in enumerate(self.motor_names): + if name != "gripper": + action[f"{name}.pos"] = float(q_target[i]) + else: + action["gripper.pos"] = float(gripper_pos) + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None) + + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + def reset(self): + """Resets the initial guess for the IK solver.""" + self.q_curr = None + + +@ProcessorStepRegistry.register("gripper_velocity_to_joint") +@dataclass +class GripperVelocityToJoint(RobotActionProcessorStep): + """ + Converts a gripper velocity command into a target gripper joint position. + + This step integrates a normalized velocity command over time to produce a position command, + taking the current gripper position as a starting point. It also supports a discrete mode + where integer actions map to open, close, or no-op. + + Attributes: + motor_names: A list of motor names, which must include 'gripper'. + speed_factor: A scaling factor to convert the normalized velocity command to a position change. + clip_min: The minimum allowed gripper joint position. + clip_max: The maximum allowed gripper joint position. + discrete_gripper: If True, treat the input action as discrete (0: open, 1: close, 2: stay). + """ + + speed_factor: float = 20.0 + clip_min: float = 0.0 + clip_max: float = 100.0 + discrete_gripper: bool = False + + def action(self, action: RobotAction) -> RobotAction: + observation = self.transition.get(TransitionKey.OBSERVATION).copy() + + gripper_vel = action.pop("ee.gripper_vel") + + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.discrete_gripper: + # Discrete gripper actions are in [0, 1, 2] + # 0: open, 1: close, 2: stay + # We need to shift them to [-1, 0, 1] and then scale them to clip_max + gripper_vel = (gripper_vel - 1) * self.clip_max + + # Compute desired gripper position + delta = gripper_vel * float(self.speed_factor) + # TODO: This assumes gripper is the last specified joint in the robot + gripper_pos = float(np.clip(q_raw[-1] + delta, self.clip_min, self.clip_max)) + action["ee.gripper_pos"] = gripper_pos + + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.ACTION].pop("ee.gripper_vel", None) + features[PipelineFeatureType.ACTION]["ee.gripper_pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + +def compute_forward_kinematics_joints_to_ee( + joints: dict[str, Any], kinematics: RobotKinematics, motor_names: list[str] +) -> dict[str, Any]: + motor_joint_values = [joints[f"{n}.pos"] for n in motor_names] + + q = np.array(motor_joint_values, dtype=float) + t = kinematics.forward_kinematics(q) + pos = t[:3, 3] + tw = Rotation.from_matrix(t[:3, :3]).as_rotvec() + gripper_pos = joints["gripper.pos"] + for n in motor_names: + joints.pop(f"{n}.pos") + joints["ee.x"] = float(pos[0]) + joints["ee.y"] = float(pos[1]) + joints["ee.z"] = float(pos[2]) + joints["ee.wx"] = float(tw[0]) + joints["ee.wy"] = float(tw[1]) + joints["ee.wz"] = float(tw[2]) + joints["ee.gripper_pos"] = float(gripper_pos) + return joints + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_observation") +@dataclass +class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep): + """ + Computes the end-effector pose from joint positions using forward kinematics (FK). + + This step is typically used to add the robot's Cartesian pose to the observation space, + which can be useful for visualization or as an input to a policy. + + Attributes: + kinematics: The robot's kinematic model. + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.OBSERVATION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.OBSERVATION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register("forward_kinematics_joints_to_ee_action") +@dataclass +class ForwardKinematicsJointsToEEAction(RobotActionProcessorStep): + """ + Computes the end-effector pose from joint positions using forward kinematics (FK). + + This step is typically used to add the robot's Cartesian pose to the observation space, + which can be useful for visualization or as an input to a policy. + + Attributes: + kinematics: The robot's kinematic model. + """ + + kinematics: RobotKinematics + motor_names: list[str] + + def action(self, action: RobotAction) -> RobotAction: + return compute_forward_kinematics_joints_to_ee(action, self.kinematics, self.motor_names) + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We only use the ee pose in the dataset, so we don't need the joint positions + for n in self.motor_names: + features[PipelineFeatureType.ACTION].pop(f"{n}.pos", None) + # We specify the dataset features of this step that we want to be stored in the dataset + for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION][f"ee.{k}"] = PolicyFeature( + type=FeatureType.STATE, shape=(1,) + ) + return features + + +@ProcessorStepRegistry.register(name="forward_kinematics_joints_to_ee") +@dataclass +class ForwardKinematicsJointsToEE(ProcessorStep): + kinematics: RobotKinematics + motor_names: list[str] + + def __post_init__(self): + self.joints_to_ee_action_processor = ForwardKinematicsJointsToEEAction( + kinematics=self.kinematics, motor_names=self.motor_names + ) + self.joints_to_ee_observation_processor = ForwardKinematicsJointsToEEObservation( + kinematics=self.kinematics, motor_names=self.motor_names + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + if transition.get(TransitionKey.ACTION) is not None: + transition = self.joints_to_ee_action_processor(transition) + if transition.get(TransitionKey.OBSERVATION) is not None: + transition = self.joints_to_ee_observation_processor(transition) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + if features[PipelineFeatureType.ACTION] is not None: + features = self.joints_to_ee_action_processor.transform_features(features) + if features[PipelineFeatureType.OBSERVATION] is not None: + features = self.joints_to_ee_observation_processor.transform_features(features) + return features + + +@ProcessorStepRegistry.register("inverse_kinematics_rl_step") +@dataclass +class InverseKinematicsRLStep(ProcessorStep): + """ + Computes desired joint positions from a target end-effector pose using inverse kinematics (IK). + + This is modified from the InverseKinematicsEEToJoints step to be used in the RL pipeline. + """ + + kinematics: RobotKinematics + motor_names: list[str] + q_curr: np.ndarray | None = field(default=None, init=False, repr=False) + initial_guess_current_joints: bool = True + + def __call__(self, transition: EnvTransition) -> EnvTransition: + new_transition = dict(transition) + action = new_transition.get(TransitionKey.ACTION) + if action is None: + raise ValueError("Action is required for InverseKinematicsEEToJoints") + action = dict(action) + + x = action.pop("ee.x") + y = action.pop("ee.y") + z = action.pop("ee.z") + wx = action.pop("ee.wx") + wy = action.pop("ee.wy") + wz = action.pop("ee.wz") + gripper_pos = action.pop("ee.gripper_pos") + + if None in (x, y, z, wx, wy, wz, gripper_pos): + raise ValueError( + "Missing required end-effector pose components: ee.x, ee.y, ee.z, ee.wx, ee.wy, ee.wz, ee.gripper_pos must all be present in action" + ) + + observation = new_transition.get(TransitionKey.OBSERVATION).copy() + if observation is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + q_raw = np.array( + [float(v) for k, v in observation.items() if isinstance(k, str) and k.endswith(".pos")], + dtype=float, + ) + if q_raw is None: + raise ValueError("Joints observation is require for computing robot kinematics") + + if self.initial_guess_current_joints: # Use current joints as initial guess + self.q_curr = q_raw + else: # Use previous ik solution as initial guess + if self.q_curr is None: + self.q_curr = q_raw + + # Build desired 4x4 transform from pos + rotvec (twist) + t_des = np.eye(4, dtype=float) + t_des[:3, :3] = Rotation.from_rotvec([wx, wy, wz]).as_matrix() + t_des[:3, 3] = [x, y, z] + + # Compute inverse kinematics + q_target = self.kinematics.inverse_kinematics(self.q_curr, t_des) + self.q_curr = q_target + + # TODO: This is sentitive to order of motor_names = q_target mapping + for i, name in enumerate(self.motor_names): + if name != "gripper": + action[f"{name}.pos"] = float(q_target[i]) + else: + action["gripper.pos"] = float(gripper_pos) + + new_transition[TransitionKey.ACTION] = action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + complementary_data["IK_solution"] = q_target + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]: + features[PipelineFeatureType.ACTION].pop(f"ee.{feat}", None) + + for name in self.motor_names: + features[PipelineFeatureType.ACTION][f"{name}.pos"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features + + def reset(self): + """Resets the initial guess for the IK solver.""" + self.q_curr = None diff --git a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py b/src/lerobot/robots/so100_follower/so100_follower_end_effector.py deleted file mode 100644 index 5fe2993cb..000000000 --- a/src/lerobot/robots/so100_follower/so100_follower_end_effector.py +++ /dev/null @@ -1,200 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from typing import Any - -import numpy as np - -from lerobot.cameras import make_cameras_from_configs -from lerobot.errors import DeviceNotConnectedError -from lerobot.model.kinematics import RobotKinematics -from lerobot.motors import Motor, MotorNormMode -from lerobot.motors.feetech import FeetechMotorsBus - -from . import SO100Follower -from .config_so100_follower import SO100FollowerEndEffectorConfig - -logger = logging.getLogger(__name__) - - -class SO100FollowerEndEffector(SO100Follower): - """ - SO100Follower robot with end-effector space control. - - This robot inherits from SO100Follower but transforms actions from - end-effector space to joint space before sending them to the motors. - """ - - config_class = SO100FollowerEndEffectorConfig - name = "so100_follower_end_effector" - - def __init__(self, config: SO100FollowerEndEffectorConfig): - super().__init__(config) - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - - self.cameras = make_cameras_from_configs(config.cameras) - - self.config = config - - # Initialize the kinematics module for the so100 robot - if self.config.urdf_path is None: - raise ValueError( - "urdf_path must be provided in the configuration for end-effector control. " - "Please set urdf_path in your SO100FollowerEndEffectorConfig." - ) - - self.kinematics = RobotKinematics( - urdf_path=self.config.urdf_path, - target_frame_name=self.config.target_frame_name, - ) - - # Store the bounds for end-effector position - self.end_effector_bounds = self.config.end_effector_bounds - - self.current_ee_pos = None - self.current_joint_pos = None - - @property - def action_features(self) -> dict[str, Any]: - """ - Define action features for end-effector control. - Returns dictionary with dtype, shape, and names. - """ - return { - "dtype": "float32", - "shape": (4,), - "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, - } - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """ - Transform action from end-effector space to joint space and send to motors. - - Args: - action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control - or a numpy array with [delta_x, delta_y, delta_z] - - Returns: - The joint-space action that was sent to the motors - """ - - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Convert action to numpy array if not already - if isinstance(action, dict): - if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): - delta_ee = np.array( - [ - action["delta_x"] * self.config.end_effector_step_sizes["x"], - action["delta_y"] * self.config.end_effector_step_sizes["y"], - action["delta_z"] * self.config.end_effector_step_sizes["z"], - ], - dtype=np.float32, - ) - if "gripper" not in action: - action["gripper"] = [1.0] - action = np.append(delta_ee, action["gripper"]) - else: - logger.warning( - f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" - ) - action = np.zeros(4, dtype=np.float32) - - if self.current_joint_pos is None: - # Read current joint positions - current_joint_pos = self.bus.sync_read("Present_Position") - self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) - - # Calculate current end-effector position using forward kinematics - if self.current_ee_pos is None: - self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos) - - # Set desired end-effector position by adding delta - desired_ee_pos = np.eye(4) - desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation - - # Add delta to position and clip to bounds - desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] - if self.end_effector_bounds is not None: - desired_ee_pos[:3, 3] = np.clip( - desired_ee_pos[:3, 3], - self.end_effector_bounds["min"], - self.end_effector_bounds["max"], - ) - - # Compute inverse kinematics to get joint positions - target_joint_values_in_degrees = self.kinematics.inverse_kinematics( - self.current_joint_pos, desired_ee_pos - ) - - # Create joint space action dictionary - joint_action = { - f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) - } - - # Handle gripper separately if included in action - # Gripper delta action is in the range 0 - 2, - # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos - joint_action["gripper.pos"] = np.clip( - self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, - 5, - self.config.max_gripper_pos, - ) - - self.current_ee_pos = desired_ee_pos.copy() - self.current_joint_pos = target_joint_values_in_degrees.copy() - self.current_joint_pos[-1] = joint_action["gripper.pos"] - - # Send joint space action to parent class - return super().send_action(joint_action) - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Read arm position - start = time.perf_counter() - obs_dict = self.bus.sync_read("Present_Position") - obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read state: {dt_ms:.1f}ms") - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - - return obs_dict - - def reset(self): - self.current_ee_pos = None - self.current_joint_pos = None diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 261e59a32..0455bce3f 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -29,10 +29,6 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .so100_follower import SO100Follower return SO100Follower(config) - elif config.type == "so100_follower_end_effector": - from .so100_follower import SO100FollowerEndEffector - - return SO100FollowerEndEffector(config) elif config.type == "so101_follower": from .so101_follower import SO101Follower @@ -73,6 +69,7 @@ def make_robot_from_config(config: RobotConfig) -> Robot: raise ValueError(config.type) +# TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset def ensure_safe_goal_position( goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[str, float] ) -> dict[str, float]: diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 13d30c686..bf398a0a9 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -56,6 +56,7 @@ from copy import deepcopy from dataclasses import asdict from pathlib import Path from pprint import pformat +from typing import Any import einops import gymnasium as gym @@ -69,9 +70,9 @@ from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import get_device_from_parameters +from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -84,6 +85,8 @@ from lerobot.utils.utils import ( def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], seeds: list[int] | None = None, return_observations: bool = False, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, @@ -120,7 +123,6 @@ def rollout( The dictionary described above. """ assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." - device = get_device_from_parameters(policy) # Reset the policy and environments. policy.reset() @@ -151,23 +153,20 @@ def rollout( if return_observations: all_observations.append(deepcopy(observation)) - observation = { - key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation - } - # Infer "task" from attributes of environments. # TODO: works with SyncVectorEnv but not AsyncVectorEnv observation = add_envs_task(env, observation) - + observation = preprocessor(observation) with torch.inference_mode(): action = policy.select_action(observation) + action = postprocessor(action) # Convert to CPU / numpy. - action = action.to("cpu").numpy() - assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" + action_numpy: np.ndarray = action.to("cpu").numpy() + assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" # Apply the next action. - observation, reward, terminated, truncated, info = env.step(action) + observation, reward, terminated, truncated, info = env.step(action_numpy) if render_callback is not None: render_callback(env) @@ -181,7 +180,7 @@ def rollout( # Keep track of which environments are done so far. done = terminated | truncated | done - all_actions.append(torch.from_numpy(action)) + all_actions.append(torch.from_numpy(action_numpy)) all_rewards.append(torch.from_numpy(reward)) all_dones.append(torch.from_numpy(done)) all_successes.append(torch.tensor(successes)) @@ -220,6 +219,8 @@ def rollout( def eval_policy( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], n_episodes: int, max_episodes_rendered: int = 0, videos_dir: Path | None = None, @@ -296,8 +297,10 @@ def eval_policy( start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) ) rollout_data = rollout( - env, - policy, + env=env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, @@ -479,13 +482,22 @@ def eval_main(cfg: EvalPipelineConfig): cfg=cfg.policy, env_cfg=cfg.env, ) + policy.eval() + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. + preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, + ) with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): info = eval_policy( - env, - policy, - cfg.eval.n_episodes, + env=env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, max_episodes_rendered=10, videos_dir=Path(cfg.output_dir) / "videos", start_seed=cfg.seed, diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/scripts/rl/actor.py index 1c8f9286b..baa284c4a 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/scripts/rl/actor.py @@ -62,9 +62,16 @@ from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.processor import TransitionKey from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl.gym_manipulator import make_robot_env +from lerobot.scripts.rl.gym_manipulator import ( + create_transition, + make_processors, + make_robot_env, + step_env_and_process_transition, +) from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import ( bytes_to_state_dict, @@ -91,10 +98,7 @@ from lerobot.utils.utils import ( ACTOR_SHUTDOWN_TIMEOUT = 30 - -################################################# -# Main entry point # -################################################# +# Main entry point @parser.wrap() @@ -201,9 +205,7 @@ def actor_cli(cfg: TrainRLServerPipelineConfig): logging.info("[ACTOR] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def act_with_policy( @@ -236,7 +238,8 @@ def act_with_policy( logging.info("make_env online") - online_env = make_robot_env(cfg=cfg.env) + online_env, teleop_device = make_robot_env(cfg=cfg.env) + env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device) set_seed(cfg.seed) device = get_safe_torch_device(cfg.policy.device, log=True) @@ -257,6 +260,12 @@ def act_with_policy( assert isinstance(policy, nn.Module) obs, info = online_env.reset() + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) # NOTE: For the moment we will solely handle the case of a single environment sum_reward_episode = 0 @@ -274,45 +283,71 @@ def act_with_policy( logging.info("[ACTOR] Shutting down act_with_policy") return - if interaction_step >= cfg.policy.online_step_before_learning: - # Time policy inference and check if it meets FPS requirement - with policy_timer: - action = policy.select_action(batch=obs) - policy_fps = policy_timer.fps_last + observation = { + k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features + } - log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + # Time policy inference and check if it meets FPS requirement + with policy_timer: + # Extract observation from transition for policy + action = policy.select_action(batch=observation) + policy_fps = policy_timer.fps_last - else: - action = online_env.action_space.sample() + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) - next_obs, reward, done, truncated, info = online_env.step(action) + # Use the new step function + new_transition = step_env_and_process_transition( + env=online_env, + transition=transition, + action=action, + env_processor=env_processor, + action_processor=action_processor, + ) + + # Extract values from processed transition + next_observation = { + k: v + for k, v in new_transition[TransitionKey.OBSERVATION].items() + if k in cfg.policy.input_features + } + + # Teleop action is the action that was executed in the environment + # It is either the action from the teleop device or the action from the policy + executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"] + + reward = new_transition[TransitionKey.REWARD] + done = new_transition.get(TransitionKey.DONE, False) + truncated = new_transition.get(TransitionKey.TRUNCATED, False) sum_reward_episode += float(reward) - # Increment total steps counter for intervention rate episode_total_steps += 1 - # NOTE: We override the action if the intervention is True, because the action applied is the intervention action - if "is_intervention" in info and info["is_intervention"]: - # NOTE: The action space for demonstration before hand is with the full action space - # but sometimes for example we want to deactivate the gripper - action = info["action_intervention"] + # Check for intervention from transition info + intervention_info = new_transition[TransitionKey.INFO] + if intervention_info.get(TeleopEvents.IS_INTERVENTION, False): episode_intervention = True - # Increment intervention steps counter episode_intervention_steps += 1 + complementary_info = { + "discrete_penalty": torch.tensor( + [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)] + ), + } + # Create transition for learner (convert to old format) list_transition_to_send_to_learner.append( Transition( - state=obs, - action=action, + state=observation, + action=executed_action, reward=reward, - next_state=next_obs, + next_state=next_observation, done=done, - truncated=truncated, # TODO: (azouitine) Handle truncation properly - complementary_info=info, + truncated=truncated, + complementary_info=complementary_info, ) ) - # assign obs to the next obs and continue the rollout - obs = next_obs + + # Update transition for next iteration + transition = new_transition if done or truncated: logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") @@ -347,21 +382,27 @@ def act_with_policy( ) ) - # Reset intervention counters + # Reset intervention counters and environment sum_reward_episode = 0.0 episode_intervention = False episode_intervention_steps = 0 episode_total_steps = 0 + + # Reset environment and processors obs, info = online_env.reset() + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time busy_wait(1 / cfg.env.fps - dt_time) -################################################# -# Communication Functions - Group all gRPC/messaging functions # -################################################# +# Communication Functions - Group all gRPC/messaging functions def establish_learner_connection( @@ -606,9 +647,7 @@ def interactions_stream( return services_pb2.Empty() -################################################# -# Policy functions # -################################################# +# Policy functions def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): @@ -640,9 +679,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device) logging.info("[ACTOR] Loaded discrete critic parameters from Learner.") -################################################# -# Utilities functions # -################################################# +# Utilities functions def push_transitions_to_transport_queue(transitions: list, transitions_queue): diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/scripts/rl/gym_manipulator.py index 046be03e8..f91d077f4 100644 --- a/src/lerobot/scripts/rl/gym_manipulator.py +++ b/src/lerobot/scripts/rl/gym_manipulator.py @@ -14,65 +14,95 @@ # See the License for the specific language governing permissions and # limitations under the License. - -""" -Robot Environment for LeRobot Manipulation Tasks - -This module provides a comprehensive gym-compatible environment for robot manipulation -with support for: -- Multiple robot types (SO100, SO101, Koch and Moss) -- Human intervention via leader-follower control or gamepad - -- End-effector and joint space control -- Image processing (cropping and resizing) - -The environment is built using a composable wrapper pattern where each wrapper -adds specific functionality to the base RobotEnv. - -Example: - env = make_robot_env(cfg) - obs, info = env.reset() - action = policy.select_action(obs) - obs, reward, terminated, truncated, info = env.step(action) -""" - import logging import time -from collections import deque -from collections.abc import Sequence -from threading import Lock -from typing import Annotated, Any +from dataclasses import dataclass +from typing import Any import gymnasium as gym import numpy as np import torch -import torchvision.transforms.functional as F # noqa: N812 from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser -from lerobot.envs.configs import EnvConfig -from lerobot.envs.utils import preprocess_observation +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.envs.configs import HILSerlRobotEnvConfig from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + AddTeleopActionAsComplimentaryDataStep, + AddTeleopEventsAsInfoStep, + DataProcessorPipeline, + DeviceProcessorStep, + EnvTransition, + GripperPenaltyProcessorStep, + ImageCropResizeProcessorStep, + InterventionActionProcessorStep, + JointVelocityProcessorStep, + MapDeltaActionToRobotActionStep, + MapTensorToDeltaActionDictStep, + MotorCurrentProcessorStep, + Numpy2TorchActionProcessorStep, + RewardClassifierProcessorStep, + RobotActionToPolicyActionProcessorStep, + TimeLimitProcessorStep, + Torch2NumpyActionProcessorStep, + TransitionKey, + VanillaObservationProcessorStep, + create_transition, +) +from lerobot.processor.converters import identity_transition from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, so100_follower, ) +from lerobot.robots.robot import Robot +from lerobot.robots.so100_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEEObservation, + GripperVelocityToJoint, + InverseKinematicsRLStep, +) from lerobot.teleoperators import ( gamepad, # noqa: F401 keyboard, # noqa: F401 make_teleoperator_from_config, so101_leader, # noqa: F401 ) -from lerobot.teleoperators.gamepad.teleop_gamepad import GamepadTeleop -from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardEndEffectorTeleop +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.teleoperators.utils import TeleopEvents from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say logging.basicConfig(level=logging.INFO) -def reset_follower_position(robot_arm, target_position): +@dataclass +class DatasetConfig: + """Configuration for dataset creation and management.""" + + repo_id: str + task: str + root: str | None = None + num_episodes_to_record: int = 5 + replay_episode: int | None = None + push_to_hub: bool = False + + +@dataclass +class GymManipulatorConfig: + """Main configuration for gym manipulator environment.""" + + env: HILSerlRobotEnvConfig + dataset: DatasetConfig + mode: str | None = None # Either "record", "replay", None + device: str = "cpu" + + +def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None: + """Reset robot arm to target position using smooth trajectory.""" current_position_dict = robot_arm.bus.sync_read("Present_Position") current_position = np.array( [current_position_dict[name] for name in current_position_dict], dtype=np.float32 @@ -86,158 +116,25 @@ def reset_follower_position(robot_arm, target_position): busy_wait(0.015) -class TorchBox(gym.spaces.Box): - """ - A version of gym.spaces.Box that handles PyTorch tensors. - - This class extends gym.spaces.Box to work with PyTorch tensors, - providing compatibility between NumPy arrays and PyTorch tensors. - """ - - def __init__( - self, - low: float | Sequence[float] | np.ndarray, - high: float | Sequence[float] | np.ndarray, - shape: Sequence[int] | None = None, - np_dtype: np.dtype | type = np.float32, - torch_dtype: torch.dtype = torch.float32, - device: str = "cpu", - seed: int | np.random.Generator | None = None, - ) -> None: - """ - Initialize the PyTorch-compatible Box space. - - Args: - low: Lower bounds of the space. - high: Upper bounds of the space. - shape: Shape of the space. If None, inferred from low and high. - np_dtype: NumPy data type for internal storage. - torch_dtype: PyTorch data type for tensor conversion. - device: PyTorch device for returned tensors. - seed: Random seed for sampling. - """ - super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) - self.torch_dtype = torch_dtype - self.device = device - - def sample(self) -> torch.Tensor: - """ - Sample a random point from the space. - - Returns: - A PyTorch tensor within the space bounds. - """ - arr = super().sample() - return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) - - def contains(self, x: torch.Tensor) -> bool: - """ - Check if a tensor is within the space bounds. - - Args: - x: The PyTorch tensor to check. - - Returns: - Boolean indicating whether the tensor is within bounds. - """ - # Move to CPU/numpy and cast to the internal dtype - arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) - return super().contains(arr) - - def seed(self, seed: int | np.random.Generator | None = None): - """ - Set the random seed for sampling. - - Args: - seed: The random seed to use. - - Returns: - List containing the seed. - """ - super().seed(seed) - return [seed] - - def __repr__(self) -> str: - """ - Return a string representation of the space. - - Returns: - Formatted string with space details. - """ - return ( - f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " - f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" - ) - - -class TorchActionWrapper(gym.Wrapper): - """ - Wrapper that changes the action space to use PyTorch tensors. - - This wrapper modifies the action space to return PyTorch tensors when sampled - and handles converting PyTorch actions to NumPy when stepping the environment. - """ - - def __init__(self, env: gym.Env, device: str): - """ - Initialize the PyTorch action space wrapper. - - Args: - env: The environment to wrap. - device: The PyTorch device to use for tensor operations. - """ - super().__init__(env) - self.action_space = TorchBox( - low=env.action_space.low, - high=env.action_space.high, - shape=env.action_space.shape, - torch_dtype=torch.float32, - device=torch.device("cpu"), - ) - - def step(self, action: torch.Tensor): - """ - Step the environment with a PyTorch tensor action. - - This method handles conversion from PyTorch tensors to NumPy arrays - for compatibility with the underlying environment. - - Args: - action: PyTorch tensor action to take. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - if action.dim() == 2: - action = action.squeeze(0) - action = action.detach().cpu().numpy() - return self.env.step(action) - - class RobotEnv(gym.Env): - """ - Gym-compatible environment for evaluating robotic control policies with integrated human intervention. - - This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) - and absolute joint position commands and automatically configures its observation and action spaces based on the robot's - sensors and configuration. - """ + """Gym environment for robotic control with human intervention support.""" def __init__( self, robot, use_gripper: bool = False, display_cameras: bool = False, - ): - """ - Initialize the RobotEnv environment. - - The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup - supports both relative (delta) adjustments and absolute joint positions for controlling the robot. + reset_pose: list[float] | None = None, + reset_time_s: float = 5.0, + ) -> None: + """Initialize robot environment with configuration options. Args: - robot: The robot interface object used to connect and interact with the physical robot. - display_cameras: If True, the robot's camera feeds will be displayed during execution. + robot: Robot interface for hardware communication. + use_gripper: Whether to include gripper in action space. + display_cameras: Whether to show camera feeds during execution. + reset_pose: Joint positions for environment reset. + reset_time_s: Time to wait during reset. """ super().__init__() @@ -255,52 +152,50 @@ class RobotEnv(gym.Env): self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] self._image_keys = self.robot.cameras.keys() - self.current_observation = None + self.reset_pose = reset_pose + self.reset_time_s = reset_time_s self.use_gripper = use_gripper + self._joint_names = list(self.robot.bus.motors.keys()) + self._raw_joint_positions = None + self._setup_spaces() - def _get_observation(self) -> dict[str, np.ndarray]: - """Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" + def _get_observation(self) -> dict[str, Any]: + """Get current robot observation including joint positions and camera images.""" obs_dict = self.robot.get_observation() - joint_positions = np.array([obs_dict[name] for name in self._joint_names]) + raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names} + joint_positions = np.array([raw_joint_joint_position[f"{name}.pos"] for name in self._joint_names]) images = {key: obs_dict[key] for key in self._image_keys} - self.current_observation = {"agent_pos": joint_positions, "pixels": images} - def _setup_spaces(self): - """ - Dynamically configure the observation and action spaces based on the robot's capabilities. + return {"agent_pos": joint_positions, "pixels": images, **raw_joint_joint_position} - Observation Space: - - For keys with "image": A Box space with pixel values ranging from 0 to 255. - - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. - - Action Space: - - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) - or absolute, based on the configuration. - """ - self._get_observation() + def _setup_spaces(self) -> None: + """Configure observation and action spaces based on robot capabilities.""" + current_observation = self._get_observation() observation_spaces = {} # Define observation spaces for images and other states. - if "pixels" in self.current_observation: + if current_observation is not None and "pixels" in current_observation: prefix = "observation.images" observation_spaces = { f"{prefix}.{key}": gym.spaces.Box( - low=0, high=255, shape=self.current_observation["pixels"][key].shape, dtype=np.uint8 + low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 ) - for key in self.current_observation["pixels"] + for key in current_observation["pixels"] } - observation_spaces["observation.state"] = gym.spaces.Box( - low=0, - high=10, - shape=self.current_observation["agent_pos"].shape, - dtype=np.float32, - ) + if current_observation is not None: + agent_pos = current_observation["agent_pos"] + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=agent_pos.shape, + dtype=np.float32, + ) self.observation_space = gym.spaces.Dict(observation_spaces) @@ -322,57 +217,46 @@ class RobotEnv(gym.Env): dtype=np.float32, ) - def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]: - """ - Reset the environment to its initial state. - This method resets the step counter and clears any episodic data. + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Reset environment to initial state. Args: - seed: A seed for random number generation to ensure reproducibility. - options: Additional options to influence the reset behavior. + seed: Random seed for reproducibility. + options: Additional reset options. Returns: - A tuple containing: - - observation (dict): The initial sensor observation. - - info (dict): A dictionary with supplementary information, including the key "is_intervention". + Tuple of (observation, info) dictionaries. """ - super().reset(seed=seed, options=options) + # Reset the robot + # self.robot.reset() + start_time = time.perf_counter() + if self.reset_pose is not None: + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.robot, np.array(self.reset_pose)) + log_say("Reset the environment done.", play_sounds=True) - self.robot.reset() + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + + super().reset(seed=seed, options=options) # Reset episode tracking variables. self.current_step = 0 self.episode_data = None - self.current_observation = None - self._get_observation() - return self.current_observation, {"is_intervention": False} + obs = self._get_observation() + self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} + return obs, {TeleopEvents.IS_INTERVENTION: False} def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: - """ - Execute a single step within the environment using the specified action. + """Execute one environment step with given action.""" + joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())} - The provided action is processed and sent to the robot as joint position commands - that may be either absolute values or deltas based on the environment configuration. + self.robot.send_action(joint_targets_dict) - Args: - action: The commanded joint positions as a numpy array or torch tensor. + obs = self._get_observation() - Returns: - A tuple containing: - - observation (dict): The new sensor observation after taking the step. - - reward (float): The step reward (default is 0.0 within this wrapper). - - terminated (bool): True if the episode has reached a terminal state. - - truncated (bool): True if the episode was truncated (e.g., time constraints). - - info (dict): Additional debugging information including intervention status. - """ - action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]} - - # 1.0 action corresponds to no-op action - action_dict["gripper"] = action[3] if self.use_gripper else 1.0 - - self.robot.send_action(action_dict) - - self._get_observation() + self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} if self.display_cameras: self.render() @@ -384,1880 +268,501 @@ class RobotEnv(gym.Env): truncated = False return ( - self.current_observation, + obs, reward, terminated, truncated, - {"is_intervention": False}, + {TeleopEvents.IS_INTERVENTION: False}, ) - def render(self): - """ - Render the current state of the environment by displaying the robot's camera feeds. - """ + def render(self) -> None: + """Display robot camera feeds.""" import cv2 - image_keys = [key for key in self.current_observation if "image" in key] + current_observation = self._get_observation() + if current_observation is not None: + image_keys = [key for key in current_observation if "image" in key] - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(self.current_observation[key].numpy(), cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(current_observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) - def close(self): - """ - Close the environment and clean up resources by disconnecting the robot. - - If the robot is currently connected, this method properly terminates the connection to ensure that all - associated resources are released. - """ + def close(self) -> None: + """Close environment and disconnect robot.""" if self.robot.is_connected: self.robot.disconnect() + def get_raw_joint_positions(self) -> dict[str, float]: + """Get raw joint positions.""" + return self._raw_joint_positions -class AddJointVelocityToObservation(gym.ObservationWrapper): - """ - Wrapper that adds joint velocity information to the observation. - This wrapper computes joint velocities by tracking changes in joint positions over time, - and extends the observation space to include these velocities. - """ - - def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): - """ - Initialize the joint velocity wrapper. - - Args: - env: The environment to wrap. - joint_velocity_limits: Maximum expected joint velocity for space bounds. - fps: Frames per second used to calculate velocity (position delta / time). - num_dof: Number of degrees of freedom (joints) in the robot. - """ - super().__init__(env) - - # Extend observation space to include joint velocities - old_low = self.observation_space["observation.state"].low - old_high = self.observation_space["observation.state"].high - old_shape = self.observation_space["observation.state"].shape - - self.last_joint_positions = np.zeros(num_dof) - - new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits]) - new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits]) - - new_shape = (old_shape[0] + num_dof,) - - self.observation_space["observation.state"] = gym.spaces.Box( - low=new_low, - high=new_high, - shape=new_shape, - dtype=np.float32, - ) - - self.dt = 1.0 / fps - - def observation(self, observation): - """ - Add joint velocity information to the observation. - - Args: - observation: The original observation from the environment. - - Returns: - The modified observation with joint velocities. - """ - joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt - self.last_joint_positions = observation["agent_pos"] - observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1) - return observation - - -class AddCurrentToObservation(gym.ObservationWrapper): - """ - Wrapper that adds motor current information to the observation. - - This wrapper extends the observation space to include the current values - from each motor, providing information about the forces being applied. - """ - - def __init__(self, env, max_current=500, num_dof=6): - """ - Initialize the current observation wrapper. - - Args: - env: The environment to wrap. - max_current: Maximum expected current for space bounds. - num_dof: Number of degrees of freedom (joints) in the robot. - """ - super().__init__(env) - - # Extend observation space to include joint velocities - old_low = self.observation_space["observation.state"].low - old_high = self.observation_space["observation.state"].high - old_shape = self.observation_space["observation.state"].shape - - new_low = np.concatenate([old_low, np.zeros(num_dof)]) - new_high = np.concatenate([old_high, np.ones(num_dof) * max_current]) - - new_shape = (old_shape[0] + num_dof,) - - self.observation_space["observation.state"] = gym.spaces.Box( - low=new_low, - high=new_high, - shape=new_shape, - dtype=np.float32, - ) - - def observation(self, observation): - """ - Add current information to the observation. - - Args: - observation: The original observation from the environment. - - Returns: - The modified observation with current values. - """ - present_current_dict = self.env.unwrapped.robot.bus.sync_read("Present_Current") - present_current_observation = np.array( - [present_current_dict[name] for name in self.env.unwrapped.robot.bus.motors] - ) - observation["agent_pos"] = np.concatenate( - [observation["agent_pos"], present_current_observation], axis=-1 - ) - return observation - - -class RewardWrapper(gym.Wrapper): - def __init__(self, env, reward_classifier, device="cuda"): - """ - Wrapper to add reward prediction to the environment using a trained classifier. - - Args: - env: The environment to wrap. - reward_classifier: The reward classifier model. - device: The device to run the model on. - """ - self.env = env - - self.device = device - - self.reward_classifier = torch.compile(reward_classifier) - self.reward_classifier.to(self.device) - - def step(self, action): - """ - Execute a step and compute the reward using the classifier. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - observation, _, terminated, truncated, info = self.env.step(action) - - images = {} - for key in observation: - if "image" in key: - images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) - if images[key].dim() == 3: - images[key] = images[key].unsqueeze(0) - - start_time = time.perf_counter() - with torch.inference_mode(): - success = ( - self.reward_classifier.predict_reward(images, threshold=0.7) - if self.reward_classifier is not None - else 0.0 - ) - info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) - - reward = 0.0 - if success == 1.0: - terminated = True - reward = 1.0 - - return observation, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - return self.env.reset(seed=seed, options=options) - - -class TimeLimitWrapper(gym.Wrapper): - """ - Wrapper that adds a time limit to episodes and tracks execution time. - - This wrapper terminates episodes after a specified time has elapsed, providing - better control over episode length. - """ - - def __init__(self, env, control_time_s, fps): - """ - Initialize the time limit wrapper. - - Args: - env: The environment to wrap. - control_time_s: Maximum episode duration in seconds. - fps: Frames per second for calculating the maximum number of steps. - """ - self.env = env - self.control_time_s = control_time_s - self.fps = fps - - self.last_timestamp = 0.0 - self.episode_time_in_s = 0.0 - - self.max_episode_steps = int(self.control_time_s * self.fps) - - self.current_step = 0 - - def step(self, action): - """ - Step the environment and track time elapsed. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - obs, reward, terminated, truncated, info = self.env.step(action) - time_since_last_step = time.perf_counter() - self.last_timestamp - self.episode_time_in_s += time_since_last_step - self.last_timestamp = time.perf_counter() - self.current_step += 1 - # check if last timestep took more time than the expected fps - if 1.0 / time_since_last_step < self.fps: - logging.debug(f"Current timestep exceeded expected fps {self.fps}") - - if self.current_step >= self.max_episode_steps: - terminated = True - return obs, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment and time tracking. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - self.episode_time_in_s = 0.0 - self.last_timestamp = time.perf_counter() - self.current_step = 0 - return self.env.reset(seed=seed, options=options) - - -class ImageCropResizeWrapper(gym.Wrapper): - """ - Wrapper that crops and resizes image observations. - - This wrapper processes image observations to focus on relevant regions by - cropping and then resizing to a standard size. - """ - - def __init__( - self, - env, - crop_params_dict: dict[str, Annotated[tuple[int], 4]], - resize_size=None, - ): - """ - Initialize the image crop and resize wrapper. - - Args: - env: The environment to wrap. - crop_params_dict: Dictionary mapping image observation keys to crop parameters - (top, left, height, width). - resize_size: Target size for resized images (height, width). Defaults to (128, 128). - """ - super().__init__(env) - self.env = env - self.crop_params_dict = crop_params_dict - print(f"obs_keys , {self.env.observation_space}") - print(f"crop params dict {crop_params_dict.keys()}") - for key_crop in crop_params_dict: - if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 - raise ValueError(f"Key {key_crop} not in observation space") - for key in crop_params_dict: - new_shape = (3, resize_size[0], resize_size[1]) - self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) - - self.resize_size = resize_size - if self.resize_size is None: - self.resize_size = (128, 128) - - def step(self, action): - """ - Step the environment and process image observations. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info) with processed images. - """ - obs, reward, terminated, truncated, info = self.env.step(action) - for k in self.crop_params_dict: - device = obs[k].device - if obs[k].dim() >= 3: - # Reshape to combine height and width dimensions for easier calculation - batch_size = obs[k].size(0) - channels = obs[k].size(1) - flattened_spatial_dims = obs[k].view(batch_size, channels, -1) - - # Calculate standard deviation across spatial dimensions (H, W) - # If any channel has std=0, all pixels in that channel have the same value - # This is helpful if one camera mistakenly covered or the image is black - std_per_channel = torch.std(flattened_spatial_dims, dim=2) - if (std_per_channel <= 0.02).any(): - logging.warning( - f"Potential hardware issue detected: All pixels have the same value in observation {k}" - ) - - if device == torch.device("mps:0"): - obs[k] = obs[k].cpu() - - obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) - obs[k] = F.resize(obs[k], self.resize_size) - # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] - obs[k] = obs[k].clamp(0.0, 1.0) - obs[k] = obs[k].to(device) - - return obs, reward, terminated, truncated, info - - def reset(self, seed=None, options=None): - """ - Reset the environment and process image observations. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - Tuple of (observation, info) with processed images. - """ - obs, info = self.env.reset(seed=seed, options=options) - for k in self.crop_params_dict: - device = obs[k].device - if device == torch.device("mps:0"): - obs[k] = obs[k].cpu() - obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) - obs[k] = F.resize(obs[k], self.resize_size) - obs[k] = obs[k].clamp(0.0, 1.0) - obs[k] = obs[k].to(device) - return obs, info - - -class ConvertToLeRobotObservation(gym.ObservationWrapper): - """ - Wrapper that converts standard observations to LeRobot format. - - This wrapper processes observations to match the expected format for LeRobot, - including normalizing image values and moving tensors to the specified device. - """ - - def __init__(self, env, device: str = "cpu"): - """ - Initialize the LeRobot observation converter. - - Args: - env: The environment to wrap. - device: Target device for the observation tensors. - """ - super().__init__(env) - - self.device = torch.device(device) - - def observation(self, observation): - """ - Convert observations to LeRobot format. - - Args: - observation: The original observation from the environment. - - Returns: - The processed observation with normalized images and proper tensor formats. - """ - observation = preprocess_observation(observation) - observation = { - key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") - for key in observation - } - return observation - - -class ResetWrapper(gym.Wrapper): - """ - Wrapper that handles environment reset procedures. - - This wrapper provides additional functionality during environment reset, - including the option to reset to a fixed pose or allow manual reset. - """ - - def __init__( - self, - env: RobotEnv, - reset_pose: np.ndarray | None = None, - reset_time_s: float = 5, - ): - """ - Initialize the reset wrapper. - - Args: - env: The environment to wrap. - reset_pose: Fixed joint positions to reset to. If None, manual reset is used. - reset_time_s: Time in seconds to wait after reset or allowed for manual reset. - """ - super().__init__(env) - self.reset_time_s = reset_time_s - self.reset_pose = reset_pose - self.robot = self.unwrapped.robot - - def reset(self, *, seed=None, options=None): - """ - Reset the environment with either fixed or manual reset procedure. - - If reset_pose is provided, the robot will move to that position. - Otherwise, manual teleoperation control is allowed for reset_time_s seconds. - - Args: - seed: Random seed for reproducibility. - options: Additional reset options. - - Returns: - The initial observation and info from the wrapped environment. - """ - start_time = time.perf_counter() - if self.reset_pose is not None: - log_say("Reset the environment.", play_sounds=True) - reset_follower_position(self.unwrapped.robot, self.reset_pose) - log_say("Reset the environment done.", play_sounds=True) - - if hasattr(self.env, "robot_leader"): - self.env.robot_leader.bus.sync_write("Torque_Enable", 1) - log_say("Reset the leader robot.", play_sounds=True) - reset_follower_position(self.env.robot_leader, self.reset_pose) - log_say("Reset the leader robot done.", play_sounds=True) - else: - log_say( - f"Manually reset the environment for {self.reset_time_s} seconds.", - play_sounds=True, - ) - start_time = time.perf_counter() - while time.perf_counter() - start_time < self.reset_time_s: - action = self.env.robot_leader.get_action() - self.unwrapped.robot.send_action(action) - - log_say("Manual reset of the environment done.", play_sounds=True) - - busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) - - return super().reset(seed=seed, options=options) - - -class BatchCompatibleWrapper(gym.ObservationWrapper): - """ - Wrapper that ensures observations are compatible with batch processing. - - This wrapper adds a batch dimension to observations that don't already have one, - making them compatible with models that expect batched inputs. - """ - - def __init__(self, env): - """ - Initialize the batch compatibility wrapper. - - Args: - env: The environment to wrap. - """ - super().__init__(env) - - def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: - """ - Add batch dimensions to observations if needed. - - Args: - observation: Dictionary of observation tensors. - - Returns: - Dictionary of observation tensors with batch dimensions. - """ - for key in observation: - if "image" in key and observation[key].dim() == 3: - observation[key] = observation[key].unsqueeze(0) - if "state" in key and observation[key].dim() == 1: - observation[key] = observation[key].unsqueeze(0) - if "velocity" in key and observation[key].dim() == 1: - observation[key] = observation[key].unsqueeze(0) - return observation - - -class GripperPenaltyWrapper(gym.RewardWrapper): - """ - Wrapper that adds penalties for inefficient gripper commands. - - This wrapper modifies rewards to discourage excessive gripper movement - or commands that attempt to move the gripper beyond its physical limits. - """ - - def __init__(self, env, penalty: float = -0.1): - """ - Initialize the gripper penalty wrapper. - - Args: - env: The environment to wrap. - penalty: Negative reward value to apply for inefficient gripper actions. - """ - super().__init__(env) - self.penalty = penalty - self.last_gripper_state = None - - def reward(self, reward, action): - """ - Apply penalties to reward based on gripper actions. - - Args: - reward: The original reward from the environment. - action: The action that was taken. - - Returns: - Modified reward with penalty applied if necessary. - """ - gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos - - action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND - - gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( - gripper_state_normalized > 0.75 and action_normalized < -0.5 - ) - - return reward + self.penalty * int(gripper_penalty_bool) - - def step(self, action): - """ - Step the environment and apply gripper penalties. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info) with penalty applied. - """ - self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] - - gripper_action = action[-1] - obs, reward, terminated, truncated, info = self.env.step(action) - gripper_penalty = self.reward(reward, gripper_action) - - info["discrete_penalty"] = gripper_penalty - - return obs, reward, terminated, truncated, info - - def reset(self, **kwargs): - """ - Reset the environment and penalty tracking. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info with gripper penalty initialized. - """ - self.last_gripper_state = None - obs, info = super().reset(**kwargs) - info["gripper_penalty"] = 0.0 - return obs, info - - -class GripperActionWrapper(gym.ActionWrapper): - """ - Wrapper that processes gripper control commands. - - This wrapper quantizes and processes gripper commands, adding a sleep time between - consecutive gripper actions to prevent rapid toggling. - """ - - def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): - """ - Initialize the gripper action wrapper. - - Args: - env: The environment to wrap. - quantization_threshold: Threshold below which gripper commands are quantized to zero. - gripper_sleep: Minimum time in seconds between consecutive gripper commands. - """ - super().__init__(env) - self.quantization_threshold = quantization_threshold - self.gripper_sleep = gripper_sleep - self.last_gripper_action_time = 0.0 - self.last_gripper_action = None - - def action(self, action): - """ - Process gripper commands in the action. - - Args: - action: The original action from the agent. - - Returns: - Modified action with processed gripper command. - """ - if self.gripper_sleep > 0.0: - if ( - self.last_gripper_action is not None - and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep - ): - action[-1] = self.last_gripper_action - else: - self.last_gripper_action_time = time.perf_counter() - self.last_gripper_action = action[-1] - - gripper_command = action[-1] - # Gripper actions are between 0, 2 - # we want to quantize them to -1, 0 or 1 - gripper_command = gripper_command - 1.0 - - if self.quantization_threshold is not None: - # Quantize gripper command to -1, 0 or 1 - gripper_command = ( - np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 - ) - gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos - - gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] - - gripper_action_value = np.clip( - gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos - ) - action[-1] = gripper_action_value.item() - return action - - def reset(self, **kwargs): - """ - Reset the gripper action tracking. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - obs, info = super().reset(**kwargs) - self.last_gripper_action_time = 0.0 - self.last_gripper_action = None - return obs, info - - -class EEObservationWrapper(gym.ObservationWrapper): - """ - Wrapper that adds end-effector pose information to observations. - - This wrapper computes the end-effector pose using forward kinematics - and adds it to the observation space. - """ - - def __init__(self, env, ee_pose_limits): - """ - Initialize the end-effector observation wrapper. - - Args: - env: The environment to wrap. - ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. - """ - super().__init__(env) - - # Extend observation space to include end effector pose - prev_space = self.observation_space["observation.state"] - - self.observation_space["observation.state"] = gym.spaces.Box( - low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), - high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), - shape=(prev_space.shape[0] + 3,), - dtype=np.float32, - ) - - self.kinematics = RobotKinematics( - urdf_path=env.unwrapped.robot.config.urdf_path, - target_frame_name=env.unwrapped.robot.config.target_frame_name, - ) - - def observation(self, observation): - """ - Add end-effector pose to the observation. - - Args: - observation: Original observation from the environment. - - Returns: - Enhanced observation with end-effector pose information. - """ - current_joint_pos = self.unwrapped.current_observation["agent_pos"] - - current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos)[:3, 3] - observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1) - return observation - - -########################################################### -# Wrappers related to human intervention and input devices -########################################################### - - -class BaseLeaderControlWrapper(gym.Wrapper): - """ - Base class for leader-follower robot control wrappers. - - This wrapper enables human intervention through a leader-follower robot setup, - where the human can control a leader robot to guide the follower robot's movements. - """ - - def __init__( - self, - env, - teleop_device, - end_effector_step_sizes, - use_geared_leader_arm: bool = False, - use_gripper=False, - ): - """ - Initialize the base leader control wrapper. - - Args: - env: The environment to wrap. - teleop_device: The teleoperation device. - use_geared_leader_arm: Whether to use a geared leader arm setup. - use_gripper: Whether to include gripper control. - """ - super().__init__(env) - self.robot_leader = teleop_device - self.robot_follower = env.unwrapped.robot - self.use_geared_leader_arm = use_geared_leader_arm - self.use_gripper: bool = use_gripper - self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values())) - - # Set up keyboard event tracking - self._init_keyboard_events() - self.event_lock = Lock() # Thread-safe access to events - - # Initialize robot control - self.kinematics = RobotKinematics( - urdf_path=env.unwrapped.robot.config.urdf_path, - target_frame_name=env.unwrapped.robot.config.target_frame_name, - ) - self.leader_torque_enabled = True - self.prev_leader_gripper = None - - # Configure leader arm - # NOTE: Lower the gains of leader arm for automatic take-over - # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot - # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled - # Default value for P_coeff is 32 - self.robot_leader.bus.sync_write("Torque_Enable", 1) - for motor in self.robot_leader.bus.motors: - self.robot_leader.bus.write("P_Coefficient", motor, 16) - self.robot_leader.bus.write("I_Coefficient", motor, 0) - self.robot_leader.bus.write("D_Coefficient", motor, 16) - - self.leader_tracking_error_queue = deque(maxlen=4) - self._init_keyboard_listener() - - def _init_keyboard_events(self): - """ - Initialize the keyboard events dictionary. - - This method sets up tracking for keyboard events used for intervention control. - It should be overridden in subclasses to add additional events. - """ - self.keyboard_events = { - "episode_success": False, - "episode_end": False, - "rerecord_episode": False, - } - - def _handle_key_press(self, key, keyboard_device): - """ - Handle key press events. - - Args: - key: The key that was pressed. - keyboard: The keyboard module with key definitions. - - This method should be overridden in subclasses for additional key handling. - """ - try: - if key == keyboard_device.Key.esc: - self.keyboard_events["episode_end"] = True - return - if key == keyboard_device.Key.left: - self.keyboard_events["rerecord_episode"] = True - return - if hasattr(key, "char") and key.char == "s": - logging.info("Key 's' pressed. Episode success triggered.") - self.keyboard_events["episode_success"] = True - return - except Exception as e: - logging.error(f"Error handling key press: {e}") - - def _init_keyboard_listener(self): - """ - Initialize the keyboard listener for intervention control. - - This method sets up keyboard event handling if not in headless mode. - """ - from pynput import keyboard as keyboard_device - - def on_press(key): - with self.event_lock: - self._handle_key_press(key, keyboard_device) - - self.listener = keyboard_device.Listener(on_press=on_press) - self.listener.start() - - def _check_intervention(self): - """ - Check if human intervention is needed. - - Returns: - Boolean indicating whether intervention is needed. - - This method should be overridden in subclasses with specific intervention logic. - """ - return False - - def _handle_intervention(self, action): - """ - Process actions during intervention mode. - - Args: - action: The original action from the agent. - - Returns: - Tuple of (modified_action, intervention_action). - """ - if self.leader_torque_enabled: - self.robot_leader.bus.sync_write("Torque_Enable", 0) - self.leader_torque_enabled = False - - leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") - follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") - - leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict]) - follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict]) - - self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1])) - - # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation - leader_ee = self.kinematics.forward_kinematics(leader_pos)[:3, 3] - follower_ee = self.kinematics.forward_kinematics(follower_pos)[:3, 3] - - action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes) - # Normalize the action to the range [-1, 1] - action = action / self.end_effector_step_sizes - - if self.use_gripper: - if self.prev_leader_gripper is None: - self.prev_leader_gripper = np.clip( - leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos - ) - - # Get gripper action delta based on leader pose - leader_gripper = leader_pos[-1] - gripper_delta = leader_gripper - self.prev_leader_gripper - - # Normalize by max angle and quantize to {0,1,2} - normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos - if normalized_delta >= 0.3: - gripper_action = 2 - elif normalized_delta <= 0.1: - gripper_action = 0 - else: - gripper_action = 1 - - action = np.append(action, gripper_action) - - return action - - def _handle_leader_teleoperation(self): - """ - Handle leader teleoperation in non-intervention mode. - - This method synchronizes the leader robot position with the follower. - """ - - prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") - prev_leader_pos = np.array( - [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32 - ) - - if not self.leader_torque_enabled: - self.robot_leader.bus.sync_write("Torque_Enable", 1) - self.leader_torque_enabled = True - - follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") - follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) - - goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)} - self.robot_leader.bus.sync_write("Goal_Position", goal_pos) - - self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1])) - - def step(self, action): - """ - Execute a step with possible human intervention. - - Args: - action: The action to take in the environment. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - is_intervention = self._check_intervention() - - # NOTE: - if is_intervention: - action = self._handle_intervention(action) - else: - self._handle_leader_teleoperation() - - # NOTE: - obs, reward, terminated, truncated, info = self.env.step(action) - - if isinstance(action, np.ndarray): - action = torch.from_numpy(action) - - # Add intervention info - info["is_intervention"] = is_intervention - info["action_intervention"] = action - - self.prev_leader_gripper = np.clip( - self.robot_leader.bus.sync_read("Present_Position")["gripper"], - 0, - self.robot_follower.config.max_gripper_pos, - ) - - # Check for success or manual termination - success = self.keyboard_events["episode_success"] - terminated = terminated or self.keyboard_events["episode_end"] or success - - if success: - reward = 1.0 - logging.info("Episode ended successfully with reward 1.0") - - return obs, reward, terminated, truncated, info - - def reset(self, **kwargs): - """ - Reset the environment and intervention state. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - self.keyboard_events = dict.fromkeys(self.keyboard_events, False) - self.leader_tracking_error_queue.clear() - return super().reset(**kwargs) - - def close(self): - """ - Clean up resources, including stopping keyboard listener. - - Returns: - Result of closing the wrapped environment. - """ - if hasattr(self, "listener") and self.listener is not None: - self.listener.stop() - return self.env.close() - - -class GearedLeaderControlWrapper(BaseLeaderControlWrapper): - """ - Wrapper that enables manual intervention via keyboard. - - This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling - of human intervention mode with keyboard controls. - """ - - def _init_keyboard_events(self): - """ - Initialize keyboard events including human intervention flag. - - Extends the base class dictionary with an additional flag for tracking - intervention state toggled by keyboard. - """ - super()._init_keyboard_events() - self.keyboard_events["human_intervention_step"] = False - - def _handle_key_press(self, key, keyboard_device): - """ - Handle key presses including space for intervention toggle. - - Args: - key: The key that was pressed. - keyboard: The keyboard module with key definitions. - - Extends the base handler to respond to space key for toggling intervention. - """ - super()._handle_key_press(key, keyboard_device) - if key == keyboard_device.Key.space: - if not self.keyboard_events["human_intervention_step"]: - logging.info( - "Space key pressed. Human intervention required.\n" - "Place the leader in similar pose to the follower and press space again." - ) - self.keyboard_events["human_intervention_step"] = True - log_say("Human intervention step.", play_sounds=True) - else: - self.keyboard_events["human_intervention_step"] = False - logging.info("Space key pressed for a second time.\nContinuing with policy actions.") - log_say("Continuing with policy actions.", play_sounds=True) - - def _check_intervention(self): - """ - Check if human intervention is active based on keyboard toggle. - - Returns: - Boolean indicating whether intervention mode is active. - """ - return self.keyboard_events["human_intervention_step"] - - -class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): - """ - Wrapper with automatic intervention based on error thresholds. - - This wrapper monitors the error between leader and follower positions - and automatically triggers intervention when error exceeds thresholds. - """ - - def __init__( - self, - env, - teleop_device, - end_effector_step_sizes, - use_gripper=False, - intervention_threshold=10.0, - release_threshold=1e-2, - ): - """ - Initialize the automatic intervention wrapper. - - Args: - env: The environment to wrap. - teleop_device: The teleoperation device. - use_gripper: Whether to include gripper control. - intervention_threshold: Error threshold to trigger intervention. - release_threshold: Error threshold to release intervention. - queue_size: Number of error measurements to track for smoothing. - """ - super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper) - - # Error tracking parameters - self.intervention_threshold = intervention_threshold # Threshold to trigger intervention - self.release_threshold = release_threshold # Threshold to release intervention - self.is_intervention_active = False - self.start_time = time.perf_counter() - - def _check_intervention(self): - """ - Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space. - - This method monitors the rate of change of leader-follower error in end_effector space - and automatically triggers intervention when the rate of change exceeds - the intervention threshold, releasing when it falls below the release threshold. - - Returns: - Boolean indicating whether intervention should be active. - """ - - # Condition for starting the intervention - # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over - if ( - not self.is_intervention_active - and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen - and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold - ): - self.is_intervention_active = True - self.leader_tracking_error_queue.clear() - log_say("Intervention started", play_sounds=True) - return True - - # Track the error over time in leader_tracking_error_queue - # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over - if ( - self.is_intervention_active - and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen - and np.var(self.leader_tracking_error_queue) < self.release_threshold - ): - self.is_intervention_active = False - self.leader_tracking_error_queue.clear() - log_say("Intervention ended", play_sounds=True) - return False - - # If not change has happened that merits a change in the intervention state, return the current state - return self.is_intervention_active - - def reset(self, **kwargs): - """ - Reset error tracking on environment reset. - - Args: - **kwargs: Keyword arguments passed to the wrapped environment's reset. - - Returns: - The initial observation and info. - """ - self.is_intervention_active = False - return super().reset(**kwargs) - - -class GamepadControlWrapper(gym.Wrapper): - """ - Wrapper that allows controlling a gym environment with a gamepad. - - This wrapper intercepts the step method and allows human input via gamepad - to override the agent's actions when desired. - """ - - def __init__( - self, - env, - teleop_device, # Accepts an instantiated teleoperator - use_gripper=False, # This should align with teleop_device's config - auto_reset=False, - ): - """ - Initialize the gamepad controller wrapper. - - Args: - env: The environment to wrap. - teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). - use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). - auto_reset: Whether to auto reset the environment when episode ends. - """ - super().__init__(env) - - self.teleop_device = teleop_device - # Ensure the teleop_device is connected if it has a connect method - if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected: - self.teleop_device.connect() - - # self.controller attribute is removed - - self.auto_reset = auto_reset - # use_gripper from args should ideally match teleop_device.config.use_gripper - # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config - self.use_gripper = use_gripper - - logging.info("Gamepad control wrapper initialized with provided teleop_device.") - print( - "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):" - ) - print(" Left analog stick: Move in X-Y plane") - print(" Right analog stick: Move in Z axis (up/down)") - print(" X/Square button: End episode (FAILURE)") - print(" Y/Triangle button: End episode (SUCCESS)") - print(" B/Circle button: Exit program") - - def get_teleop_commands( - self, - ) -> tuple[bool, np.ndarray, bool, bool, bool]: - """ - Get the current action from the gamepad if any input is active. - - Returns: - Tuple containing: - - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene()) - - action: The action derived from gamepad input (from teleop_device.get_action()) - - terminate_episode: Whether episode termination was requested - - success: Whether episode success was signaled - - rerecord_episode: Whether episode rerecording was requested - """ - if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None: - raise AttributeError( - "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper." - ) - - # Get status flags from the underlying gamepad controller within the teleop_device - self.teleop_device.gamepad.update() # Ensure gamepad state is fresh - intervention_is_active = self.teleop_device.gamepad.should_intervene() - episode_end_status = self.teleop_device.gamepad.get_episode_end_status() - - terminate_episode = episode_end_status is not None - success = episode_end_status == "success" - rerecord_episode = episode_end_status == "rerecord_episode" - - # Get the action dictionary from the teleop_device - action_dict = self.teleop_device.get_action() - - # Convert action_dict to numpy array based on expected structure - # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) - action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] - if self.use_gripper: - # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) - # This needs to be consistent with what EEActionWrapper expects if it's used downstream - # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) - # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. - gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present - action_list.append(float(gripper_val)) - - gamepad_action_np = np.array(action_list, dtype=np.float32) - - return ( - intervention_is_active, - gamepad_action_np, - terminate_episode, - success, - rerecord_episode, - ) - - def step(self, action): - """ - Step the environment, using gamepad input to override actions when active. - - Args: - action: Original action from agent. - - Returns: - Tuple of (observation, reward, terminated, truncated, info). - """ - # Get gamepad state and action - ( - is_intervention, - gamepad_action, - terminate_episode, - success, - rerecord_episode, - ) = self.get_teleop_commands() - - # Update episode ending state if requested - if terminate_episode: - logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") - - # Only override the action if gamepad is active - action = gamepad_action if is_intervention else action - - # Step the environment - obs, reward, terminated, truncated, info = self.env.step(action) - - # Add episode ending if requested via gamepad - terminated = terminated or truncated or terminate_episode - - if success: - reward = 1.0 - logging.info("Episode ended successfully with reward 1.0") - - if isinstance(action, np.ndarray): - action = torch.from_numpy(action) - - info["is_intervention"] = is_intervention - # The original `BaseLeaderControlWrapper` puts `action_intervention` in info. - # For Gamepad, if intervention, `gamepad_action` is the intervention. - # If not intervention, policy's action is `action`. - # For consistency, let's store the *human's* action if intervention occurred. - info["action_intervention"] = action - - info["rerecord_episode"] = rerecord_episode - - # If episode ended, reset the state - if terminated or truncated: - # Add success/failure information to info dict - info["next.success"] = success - - # Auto reset if configured - if self.auto_reset: - obs, reset_info = self.reset() - info.update(reset_info) - - return obs, reward, terminated, truncated, info - - def close(self): - """ - Clean up resources when environment closes. - - Returns: - Result of closing the wrapped environment. - """ - if hasattr(self.teleop_device, "disconnect"): - self.teleop_device.disconnect() - - # Call the parent close method - return self.env.close() - - -class KeyboardControlWrapper(GamepadControlWrapper): - """ - Wrapper that allows controlling a gym environment with a keyboard. - - This wrapper intercepts the step method and allows human input via keyboard - to override the agent's actions when desired. - - Inherits from GamepadControlWrapper to avoid code duplication. - """ - - def __init__( - self, - env, - teleop_device, # Accepts an instantiated teleoperator - use_gripper=False, # This should align with teleop_device's config - auto_reset=False, - ): - """ - Initialize the gamepad controller wrapper. - - Args: - env: The environment to wrap. - teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). - use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). - auto_reset: Whether to auto reset the environment when episode ends. - """ - super().__init__(env, teleop_device, use_gripper, auto_reset) - - self.is_intervention_active = False - - logging.info("Keyboard control wrapper initialized with provided teleop_device.") - print("Keyboard controls:") - print(" Arrow keys: Move in X-Y plane") - print(" Shift and Shift_R: Move in Z axis") - print(" Right Ctrl and Left Ctrl: Open and close gripper") - print(" f: End episode with FAILURE") - print(" s: End episode with SUCCESS") - print(" r: End episode with RERECORD") - print(" i: Start/Stop Intervention") - - def get_teleop_commands( - self, - ) -> tuple[bool, np.ndarray, bool, bool, bool]: - action_dict = self.teleop_device.get_action() - episode_end_status = None - - # Unroll the misc_keys_queue to check for events related to intervention, episode success, etc. - while not self.teleop_device.misc_keys_queue.empty(): - key = self.teleop_device.misc_keys_queue.get() - if key == "i": - self.is_intervention_active = not self.is_intervention_active - elif key == "f": - episode_end_status = "failure" - elif key == "s": - episode_end_status = "success" - elif key == "r": - episode_end_status = "rerecord_episode" - - terminate_episode = episode_end_status is not None - success = episode_end_status == "success" - rerecord_episode = episode_end_status == "rerecord_episode" - - # Convert action_dict to numpy array based on expected structure - # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) - action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] - if self.use_gripper: - # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) - # This needs to be consistent with what EEActionWrapper expects if it's used downstream - # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) - # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. - gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present - action_list.append(float(gripper_val)) - - gamepad_action_np = np.array(action_list, dtype=np.float32) - - return ( - self.is_intervention_active, - gamepad_action_np, - terminate_episode, - success, - rerecord_episode, - ) - - -class GymHilDeviceWrapper(gym.Wrapper): - def __init__(self, env, device="cpu"): - super().__init__(env) - self.device = device - - def step(self, action): - obs, reward, terminated, truncated, info = self.env.step(action) - for k in obs: - obs[k] = obs[k].to(self.device) - if "action_intervention" in info: - # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device - info["action_intervention"] = info["action_intervention"].astype(np.float32) - info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) - return obs, reward, terminated, truncated, info - - def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): - obs, info = self.env.reset(seed=seed, options=options) - for k in obs: - obs[k] = obs[k].to(self.device) - if "action_intervention" in info: - # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device - info["action_intervention"] = info["action_intervention"].astype(np.float32) - info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) - return obs, info - - -class GymHilObservationProcessorWrapper(gym.ObservationWrapper): - def __init__(self, env: gym.Env): - super().__init__(env) - prev_space = self.observation_space - new_space = {} - - for key in prev_space: - if "pixels" in key: - for k in prev_space["pixels"]: - new_space[f"observation.images.{k}"] = gym.spaces.Box( - 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8 - ) - - if key == "agent_pos": - new_space["observation.state"] = prev_space["agent_pos"] - - self.observation_space = gym.spaces.Dict(new_space) - - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: - return preprocess_observation(observation) - - -########################################################### -# Factory functions -########################################################### - - -def make_robot_env(cfg: EnvConfig) -> gym.Env: - """ - Factory function to create a robot environment. - - This function builds a robot environment with all necessary wrappers - based on the provided configuration. +def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]: + """Create robot environment from configuration. Args: - cfg: Configuration object containing environment parameters. + cfg: Environment configuration. Returns: - A gym environment with all necessary wrappers applied. + Tuple of (gym environment, teleoperator device). """ - if cfg.type == "hil": + # Check if this is a GymHIL simulation environment + if cfg.name == "gym_hil": + assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop" import gym_hil # noqa: F401 - # TODO (azouitine) + # Extract gripper settings with defaults + use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True + gripper_penalty = cfg.processor.gripper.gripper_penalty if cfg.processor.gripper is not None else 0.0 + env = gym.make( f"gym_hil/{cfg.task}", image_obs=True, render_mode="human", - use_gripper=cfg.wrapper.use_gripper, - gripper_penalty=cfg.wrapper.gripper_penalty, - ) - env = GymHilObservationProcessorWrapper(env=env) - env = GymHilDeviceWrapper(env=env, device=cfg.device) - env = BatchCompatibleWrapper(env=env) - env = TorchActionWrapper(env=env, device=cfg.device) - return env - - if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"): - raise ValueError( - "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop." + use_gripper=use_gripper, + gripper_penalty=gripper_penalty, ) - if cfg.robot is None: - raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.") + return env, None + + # Real robot environment + assert cfg.robot is not None, "Robot config must be provided for real robot environment" + assert cfg.teleop is not None, "Teleop config must be provided for real robot environment" + robot = make_robot_from_config(cfg.robot) teleop_device = make_teleoperator_from_config(cfg.teleop) teleop_device.connect() - # Create base environment + # Create base environment with safe defaults + use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True + display_cameras = ( + cfg.processor.observation.display_cameras if cfg.processor.observation is not None else False + ) + reset_pose = cfg.processor.reset.fixed_reset_joint_positions if cfg.processor.reset is not None else None + env = RobotEnv( robot=robot, - use_gripper=cfg.wrapper.use_gripper, - display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False, + use_gripper=use_gripper, + display_cameras=display_cameras, + reset_pose=reset_pose, ) - # Add observation and image processing - if cfg.wrapper: - if cfg.wrapper.add_joint_velocity_to_observation: - env = AddJointVelocityToObservation(env=env, fps=cfg.fps) - if cfg.wrapper.add_current_to_observation: - env = AddCurrentToObservation(env=env) - if cfg.wrapper.add_ee_pose_to_observation: - env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds) - - env = ConvertToLeRobotObservation(env=env, device=cfg.device) - - if cfg.wrapper and cfg.wrapper.crop_params_dict is not None: - env = ImageCropResizeWrapper( - env=env, - crop_params_dict=cfg.wrapper.crop_params_dict, - resize_size=cfg.wrapper.resize_size, - ) - - # Add reward computation and control wrappers - reward_classifier = init_reward_classifier(cfg) - if reward_classifier is not None: - env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) - - env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) - if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None: - env = GripperPenaltyWrapper( - env=env, - penalty=cfg.wrapper.gripper_penalty, - ) - - # Control mode specific wrappers - control_mode = cfg.wrapper.control_mode - if control_mode == "gamepad": - assert isinstance(teleop_device, GamepadTeleop), ( - "teleop_device must be an instance of GamepadTeleop for gamepad control mode" - ) - env = GamepadControlWrapper( - env=env, - teleop_device=teleop_device, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "keyboard_ee": - assert isinstance(teleop_device, KeyboardEndEffectorTeleop), ( - "teleop_device must be an instance of KeyboardEndEffectorTeleop for keyboard control mode" - ) - env = KeyboardControlWrapper( - env=env, - teleop_device=teleop_device, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "leader": - env = GearedLeaderControlWrapper( - env=env, - teleop_device=teleop_device, - end_effector_step_sizes=cfg.robot.end_effector_step_sizes, - use_gripper=cfg.wrapper.use_gripper, - ) - elif control_mode == "leader_automatic": - env = GearedLeaderAutomaticControlWrapper( - env=env, - teleop_device=teleop_device, - end_effector_step_sizes=cfg.robot.end_effector_step_sizes, - use_gripper=cfg.wrapper.use_gripper, - ) - else: - raise ValueError(f"Invalid control mode: {control_mode}") - - env = ResetWrapper( - env=env, - reset_pose=cfg.wrapper.fixed_reset_joint_positions, - reset_time_s=cfg.wrapper.reset_time_s, - ) - - env = BatchCompatibleWrapper(env=env) - env = TorchActionWrapper(env=env, device=cfg.device) - - return env + return env, teleop_device -def init_reward_classifier(cfg): - """ - Load a reward classifier policy from a pretrained path if configured. +def make_processors( + env: gym.Env, teleop_device: Teleoperator | None, cfg: HILSerlRobotEnvConfig, device: str = "cpu" +) -> tuple[ + DataProcessorPipeline[EnvTransition, EnvTransition], DataProcessorPipeline[EnvTransition, EnvTransition] +]: + """Create environment and action processors. Args: - cfg: The environment configuration containing classifier paths. + env: Robot environment instance. + teleop_device: Teleoperator device for intervention. + cfg: Processor configuration. + device: Target device for computations. Returns: - The loaded classifier model or None if not configured. + Tuple of (environment processor, action processor). """ - if cfg.reward_classifier_pretrained_path is None: - return None - - from lerobot.policies.sac.reward_model.modeling_classifier import Classifier - - # Get device from config or default to CUDA - device = getattr(cfg, "device", "cpu") - - # Load the classifier directly using from_pretrained - classifier = Classifier.from_pretrained( - pretrained_name_or_path=cfg.reward_classifier_pretrained_path, + terminate_on_success = ( + cfg.processor.reset.terminate_on_success if cfg.processor.reset is not None else True ) - # Ensure model is on the correct device - classifier.to(device) - classifier.eval() # Set to evaluation mode + if cfg.name == "gym_hil": + action_pipeline_steps = [ + InterventionActionProcessorStep(terminate_on_success=terminate_on_success), + Torch2NumpyActionProcessorStep(), + ] - return classifier + env_pipeline_steps = [ + Numpy2TorchActionProcessorStep(), + VanillaObservationProcessorStep(), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device=device), + ] + + return DataProcessorPipeline( + steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ), DataProcessorPipeline( + steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ) + + # Full processor pipeline for real robot environment + # Get robot and motor information for kinematics + motor_names = list(env.robot.bus.motors.keys()) + + # Set up kinematics solver if inverse kinematics is configured + kinematics_solver = None + if cfg.processor.inverse_kinematics is not None: + kinematics_solver = RobotKinematics( + urdf_path=cfg.processor.inverse_kinematics.urdf_path, + target_frame_name=cfg.processor.inverse_kinematics.target_frame_name, + joint_names=motor_names, + ) + + env_pipeline_steps = [VanillaObservationProcessorStep()] + + if cfg.processor.observation is not None: + if cfg.processor.observation.add_joint_velocity_to_observation: + env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps)) + if cfg.processor.observation.add_current_to_observation: + env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot)) + + if kinematics_solver is not None: + env_pipeline_steps.append( + ForwardKinematicsJointsToEEObservation( + kinematics=kinematics_solver, + motor_names=motor_names, + ) + ) + + if cfg.processor.image_preprocessing is not None: + env_pipeline_steps.append( + ImageCropResizeProcessorStep( + crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict, + resize_size=cfg.processor.image_preprocessing.resize_size, + ) + ) + + # Add time limit processor if reset config exists + if cfg.processor.reset is not None: + env_pipeline_steps.append( + TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps)) + ) + + # Add gripper penalty processor if gripper config exists and enabled + if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper: + env_pipeline_steps.append( + GripperPenaltyProcessorStep( + penalty=cfg.processor.gripper.gripper_penalty, + max_gripper_pos=cfg.processor.max_gripper_pos, + ) + ) + + if ( + cfg.processor.reward_classifier is not None + and cfg.processor.reward_classifier.pretrained_path is not None + ): + env_pipeline_steps.append( + RewardClassifierProcessorStep( + pretrained_path=cfg.processor.reward_classifier.pretrained_path, + device=device, + success_threshold=cfg.processor.reward_classifier.success_threshold, + success_reward=cfg.processor.reward_classifier.success_reward, + terminate_on_success=terminate_on_success, + ) + ) + + env_pipeline_steps.append(AddBatchDimensionProcessorStep()) + env_pipeline_steps.append(DeviceProcessorStep(device=device)) + + action_pipeline_steps = [ + AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device), + AddTeleopEventsAsInfoStep(teleop_device=teleop_device), + InterventionActionProcessorStep( + use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False, + terminate_on_success=terminate_on_success, + ), + ] + + # Replace InverseKinematicsProcessor with new kinematic processors + if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None: + # Add EE bounds and safety processor + inverse_kinematics_steps = [ + MapTensorToDeltaActionDictStep( + use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False + ), + MapDeltaActionToRobotActionStep(), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes, + motor_names=motor_names, + use_latched_reference=False, + use_ik_solution=True, + ), + EEBoundsAndSafety( + end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds, + ), + GripperVelocityToJoint( + clip_max=cfg.processor.max_gripper_pos, + speed_factor=1.0, + discrete_gripper=True, + ), + InverseKinematicsRLStep( + kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False + ), + ] + action_pipeline_steps.extend(inverse_kinematics_steps) + action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names)) + + return DataProcessorPipeline( + steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ), DataProcessorPipeline( + steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition + ) -########################################################### -# Record and replay functions -########################################################### - - -def record_dataset(env, policy, cfg): +def step_env_and_process_transition( + env: gym.Env, + transition: EnvTransition, + action: torch.Tensor, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], +) -> EnvTransition: """ - Record a dataset of robot interactions using either a policy or teleop. - - This function runs episodes in the environment and records the observations, - actions, and results for dataset creation. + Execute one step with processor pipeline. Args: - env: The environment to record from. - policy: Optional policy to generate actions (if None, uses teleop). - cfg: Configuration object containing recording parameters like: - - repo_id: Repository ID for dataset storage - - dataset_root: Local root directory for dataset - - num_episodes: Number of episodes to record - - fps: Frames per second for recording - - push_to_hub: Whether to push dataset to Hugging Face Hub - - task: Name/description of the task being recorded - - number_of_steps_after_success: Number of additional steps to continue recording after - a success (reward=1) is detected. This helps collect - more positive examples for reward classifier training. + env: The robot environment + transition: Current transition state + action: Action to execute + env_processor: Environment processor + action_processor: Action processor + + Returns: + Processed transition with updated state. """ - from lerobot.datasets.lerobot_dataset import LeRobotDataset - # Setup initial action (zero action if using teleop) - action = env.action_space.sample() * 0.0 - - action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] - if cfg.wrapper.use_gripper: - action_names.append("gripper_delta") - - # Configure dataset features based on environment spaces - features = { - "observation.state": { - "dtype": "float32", - "shape": env.observation_space["observation.state"].shape, - "names": None, - }, - "action": { - "dtype": "float32", - "shape": (len(action_names),), - "names": action_names, - }, - "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, - "next.done": {"dtype": "bool", "shape": (1,), "names": None}, - "complementary_info.discrete_penalty": { - "dtype": "float32", - "shape": (1,), - "names": ["discrete_penalty"], - }, - } - - # Add image features - for key in env.observation_space: - if "image" in key: - features[key] = { - "dtype": "video", - "shape": env.observation_space[key].shape, - "names": ["channels", "height", "width"], - } - - # Create dataset - dataset = LeRobotDataset.create( - cfg.repo_id, - cfg.fps, - root=cfg.dataset_root, - use_videos=True, - image_writer_threads=4, - image_writer_processes=0, - features=features, + # Create action transition + transition[TransitionKey.ACTION] = action + transition[TransitionKey.OBSERVATION] = ( + env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {} ) + processed_action_transition = action_processor(transition) + processed_action = processed_action_transition[TransitionKey.ACTION] - # Record episodes - episode_index = 0 - recorded_action = None - while episode_index < cfg.num_episodes: - obs, _ = env.reset() - start_episode_t = time.perf_counter() - log_say(f"Recording episode {episode_index}", play_sounds=True) + obs, reward, terminated, truncated, info = env.step(processed_action) - # Track success state collection - success_detected = False - success_steps_collected = 0 + reward = reward + processed_action_transition[TransitionKey.REWARD] + terminated = terminated or processed_action_transition[TransitionKey.DONE] + truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED] + complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy() + new_info = processed_action_transition[TransitionKey.INFO].copy() + new_info.update(info) - # Run episode steps - while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: - start_loop_t = time.perf_counter() + new_transition = create_transition( + observation=obs, + action=processed_action, + reward=reward, + done=terminated, + truncated=truncated, + info=new_info, + complementary_data=complementary_data, + ) + new_transition = env_processor(new_transition) - # Get action from policy if available - if cfg.pretrained_policy_name_or_path is not None: - action = policy.select_action(obs) + return new_transition - # Step environment - obs, reward, terminated, truncated, info = env.step(action) - # Check if episode needs to be rerecorded - if info.get("rerecord_episode", False): - break +def control_loop( + env: gym.Env, + env_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + action_processor: DataProcessorPipeline[EnvTransition, EnvTransition], + teleop_device: Teleoperator, + cfg: GymManipulatorConfig, +) -> None: + """Main control loop for robot environment interaction. + if cfg.mode == "record": then a dataset will be created and recorded - # For teleop, get action from intervention - recorded_action = { - "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + Args: + env: The robot environment + env_processor: Environment processor + action_processor: Action processor + teleop_device: Teleoperator device + cfg: gym_manipulator configuration + """ + dt = 1.0 / cfg.env.fps + + print(f"Starting control loop at {cfg.env.fps} FPS") + print("Controls:") + print("- Use gamepad/teleop device for intervention") + print("- When not intervening, robot will stay still") + print("- Press Ctrl+C to exit") + + # Reset environment and processors + obs, info = env.reset() + complementary_data = ( + {"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {} + ) + env_processor.reset() + action_processor.reset() + + # Process initial observation + transition = create_transition(observation=obs, info=info, complementary_data=complementary_data) + transition = env_processor(data=transition) + + # Determine if gripper is used + use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True + + dataset = None + if cfg.mode == "record": + action_features = teleop_device.action_features + features = { + "action": action_features, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + } + if use_gripper: + features["complementary_info.discrete_penalty"] = { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], } - # Process observation for dataset - obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + for key, value in transition[TransitionKey.OBSERVATION].items(): + if key == "observation.state": + features[key] = { + "dtype": "float32", + "shape": value.squeeze(0).shape, + "names": None, + } + if "image" in key: + features[key] = { + "dtype": "video", + "shape": value.squeeze(0).shape, + "names": ["channels", "height", "width"], + } - # Check if we've just detected success - if reward == 1.0 and not success_detected: - success_detected = True - logging.info("Success detected! Collecting additional success states.") + # Create dataset + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.env.fps, + root=cfg.dataset.root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) - # Add frame to dataset - continue marking as success even during extra collection steps - frame = {**obs_processed, **recorded_action} + episode_idx = 0 + episode_step = 0 + episode_start_time = time.perf_counter() - # If we're in the success collection phase, keep marking rewards as 1.0 - if success_detected: - frame["next.reward"] = np.array([1.0], dtype=np.float32) - else: - frame["next.reward"] = np.array([reward], dtype=np.float32) + while episode_idx < cfg.dataset.num_episodes_to_record: + step_start_time = time.perf_counter() - # Only mark as done if we're truly done (reached end or collected enough success states) - really_done = terminated or truncated - if success_detected: - success_steps_collected += 1 - really_done = success_steps_collected >= cfg.number_of_steps_after_success + # Create a neutral action (no movement) + neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) + if use_gripper: + neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay - frame["next.done"] = np.array([really_done], dtype=bool) - frame["complementary_info.discrete_penalty"] = torch.tensor( - [info.get("discrete_penalty", 0.0)], dtype=torch.float32 + # Use the new step function + transition = step_env_and_process_transition( + env=env, + transition=transition, + action=neutral_action, + env_processor=env_processor, + action_processor=action_processor, + ) + terminated = transition.get(TransitionKey.DONE, False) + truncated = transition.get(TransitionKey.TRUNCATED, False) + + if cfg.mode == "record": + observations = { + k: v.squeeze(0).cpu() + for k, v in transition[TransitionKey.OBSERVATION].items() + if isinstance(v, torch.Tensor) + } + # Use teleop_action if available, otherwise use the action from the transition + action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get( + "teleop_action", transition[TransitionKey.ACTION] ) - frame["task"] = cfg.task - dataset.add_frame(frame) + frame = { + **observations, + "action": action_to_record.cpu(), + "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + "next.done": np.array([terminated or truncated], dtype=bool), + } + if use_gripper: + discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) + frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32) - # Maintain consistent timing - if cfg.fps: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / cfg.fps - dt_s) + if dataset is not None: + frame["task"] = cfg.dataset.task + dataset.add_frame(frame) - # Check if we should end the episode - if (terminated or truncated) and not success_detected: - # Regular termination without success - break - elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success: - # We've collected enough success states - logging.info(f"Collected {success_steps_collected} additional success states") - break + episode_step += 1 - # Handle episode recording - if info.get("rerecord_episode", False): - dataset.clear_episode_buffer() - logging.info(f"Re-recording episode {episode_index}") - continue + # Handle episode termination + if terminated or truncated: + episode_time = time.perf_counter() - episode_start_time + logging.info( + f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}" + ) + episode_step = 0 + episode_idx += 1 - dataset.save_episode() - episode_index += 1 + if dataset is not None: + if transition[TransitionKey.INFO].get("rerecord_episode", False): + logging.info(f"Re-recording episode {episode_idx}") + dataset.clear_episode_buffer() + episode_idx -= 1 + else: + logging.info(f"Saving episode {episode_idx}") + dataset.save_episode() - # Finalize dataset - # dataset.consolidate(run_compute_stats=True) - if cfg.push_to_hub: + # Reset for new episode + obs, info = env.reset() + env_processor.reset() + action_processor.reset() + + transition = create_transition(observation=obs, info=info) + transition = env_processor(transition) + + # Maintain fps timing + busy_wait(dt - (time.perf_counter() - step_start_time)) + + if dataset is not None and cfg.dataset.push_to_hub: + logging.info("Pushing dataset to hub") dataset.push_to_hub() -def replay_episode(env, cfg): - """ - Replay a recorded episode in the environment. +def replay_trajectory( + env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig +) -> None: + """Replay recorded trajectory on robot environment.""" + assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay" - This function loads actions from a previously recorded episode - and executes them in the environment. + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=[cfg.dataset.replay_episode], + download_videos=False, + ) + episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) + actions = episode_frames.select_columns("action") - Args: - env: The environment to replay in. - cfg: Configuration object containing replay parameters: - - repo_id: Repository ID for dataset - - dataset_root: Local root directory for dataset - - episode: Episode ID to replay - """ - from lerobot.datasets.lerobot_dataset import LeRobotDataset + _, info = env.reset() - dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) - env.reset() - - actions = dataset.hf_dataset.select_columns("action") - - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() - - action = actions[idx]["action"] - env.step(action) - - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / 10 - dt_s) + for action_data in actions: + start_time = time.perf_counter() + transition = create_transition( + observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}, + action=action_data["action"], + ) + transition = action_processor(transition) + env.step(transition[TransitionKey.ACTION]) + busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time)) @parser.wrap() -def main(cfg: EnvConfig): - """Main entry point for the robot environment script. +def main(cfg: GymManipulatorConfig) -> None: + """Main entry point for gym manipulator script.""" + env, teleop_device = make_robot_env(cfg.env) + env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device) - This function runs the robot environment in one of several modes - based on the provided configuration. - - Args: - cfg: Configuration object defining the run parameters, - including mode (record, replay, random) and other settings. - """ - env = make_robot_env(cfg) - - if cfg.mode == "record": - policy = None - if cfg.pretrained_policy_name_or_path is not None: - from lerobot.policies.sac.modeling_sac import SACPolicy - - policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) - policy.to(cfg.device) - policy.eval() - - record_dataset( - env, - policy=policy, - cfg=cfg, - ) - exit() + print("Environment observation space:", env.observation_space) + print("Environment action space:", env.action_space) + print("Environment processor:", env_processor) + print("Action processor:", action_processor) if cfg.mode == "replay": - replay_episode( - env, - cfg=cfg, - ) + replay_trajectory(env, action_processor, cfg) exit() - env.reset() - - # Initialize the smoothed action as a random sample. - smoothed_action = env.action_space.sample() * 0.0 - - # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. - # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. - alpha = 1.0 - - num_episode = 0 - successes = [] - while num_episode < 10: - start_loop_s = time.perf_counter() - # Sample a new random action from the robot's action space. - new_random_action = env.action_space.sample() - # Update the smoothed action using an exponential moving average. - smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action - - # Execute the step: wrap the NumPy action in a torch tensor. - obs, reward, terminated, truncated, info = env.step(smoothed_action) - if terminated or truncated: - successes.append(reward) - env.reset() - num_episode += 1 - - dt_s = time.perf_counter() - start_loop_s - busy_wait(1 / cfg.fps - dt_s) - - logging.info(f"Success after 20 steps {successes}") - logging.info(f"success rate {sum(successes) / len(successes)}") + control_loop(env, env_processor, action_processor, teleop_device, cfg) if __name__ == "__main__": diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/scripts/rl/learner.py index f9f3901ce..5d9953827 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/scripts/rl/learner.py @@ -75,6 +75,7 @@ from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so100_follower # noqa: F401 from lerobot.scripts.rl import learner_service from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( MAX_MESSAGE_SIZE, @@ -102,11 +103,6 @@ from lerobot.utils.wandb_utils import WandBLogger LOG_PREFIX = "[LEARNER]" -################################################# -# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # -################################################# - - @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): if not use_threads(cfg): @@ -249,9 +245,7 @@ def start_learner_threads( logging.info("[LEARNER] queues closed") -################################################# -# Core algorithm functions # -################################################# +# Core algorithm functions def add_actor_information_and_train( @@ -819,9 +813,7 @@ def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.M return optimizers, lr_scheduler -################################################# -# Training setup functions # -################################################# +# Training setup functions def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: @@ -1022,9 +1014,7 @@ def initialize_offline_replay_buffer( return offline_replay_buffer -################################################# -# Utilities/Helpers functions # -################################################# +# Utilities/Helpers functions def get_observation_features( @@ -1048,10 +1038,8 @@ def get_observation_features( return None, None with torch.no_grad(): - observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True) - next_observation_features = policy.actor.encoder.get_cached_image_features( - next_observations, normalize=True - ) + observation_features = policy.actor.encoder.get_cached_image_features(observations) + next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations) return observation_features, next_observation_features @@ -1176,7 +1164,7 @@ def process_transitions( # Add to offline buffer if it's an intervention if dataset_repo_id is not None and transition.get("complementary_info", {}).get( - "is_intervention" + TeleopEvents.IS_INTERVENTION ): offline_replay_buffer.add(**transition) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 398bea90e..485fc9275 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.scripts.eval import eval_policy @@ -64,6 +64,28 @@ def update_policy( use_amp: bool = False, lock=None, ) -> tuple[MetricsTracker, dict]: + """ + Performs a single training step to update the policy's weights. + + This function executes the forward and backward passes, clips gradients, and steps the optimizer and + learning rate scheduler. It also handles mixed-precision training via a GradScaler. + + Args: + train_metrics: A MetricsTracker instance to record training statistics. + policy: The policy model to be trained. + batch: A batch of training data. + optimizer: The optimizer used to update the policy's parameters. + grad_clip_norm: The maximum norm for gradient clipping. + grad_scaler: The GradScaler for automatic mixed-precision training. + lr_scheduler: An optional learning rate scheduler. + use_amp: A boolean indicating whether to use automatic mixed precision. + lock: An optional lock for thread-safe optimizer updates. + + Returns: + A tuple containing: + - The updated MetricsTracker with new statistics for this step. + - A dictionary of outputs from the policy's forward pass, for logging purposes. + """ start_time = time.perf_counter() device = get_device_from_parameters(policy) policy.train() @@ -107,6 +129,20 @@ def update_policy( @parser.wrap() def train(cfg: TrainPipelineConfig): + """ + Main function to train a policy. + + This function orchestrates the entire training pipeline, including: + - Setting up logging, seeding, and device configuration. + - Creating the dataset, evaluation environment (if applicable), policy, and optimizer. + - Handling resumption from a checkpoint. + - Running the main training loop, which involves fetching data batches and calling `update_policy`. + - Periodically logging metrics, saving model checkpoints, and evaluating the policy. + - Pushing the final trained model to the Hugging Face Hub if configured. + + Args: + cfg: A `TrainPipelineConfig` object containing all training configurations. + """ cfg.validate() logging.info(pformat(cfg.to_dict())) @@ -141,6 +177,16 @@ def train(cfg: TrainPipelineConfig): ds_meta=dataset.meta, ) + # Create processors - only provide dataset_stats if not resuming from saved processors + processor_kwargs = {} + if not (cfg.resume and cfg.policy.pretrained_path): + # Only provide dataset_stats when not resuming from saved processor state + processor_kwargs["dataset_stats"] = dataset.meta.stats + + preprocessor, postprocessor = make_pre_post_processors( + policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs + ) + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp) @@ -205,15 +251,9 @@ def train(cfg: TrainPipelineConfig): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) + batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time - for key in batch: - if isinstance(batch[key], torch.Tensor): - if batch[key].dtype != torch.bool: - batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key] - - batch[key] = batch[key].to(device, non_blocking=device.type == "cuda") - train_tracker, output_dict = update_policy( train_tracker, policy, @@ -245,7 +285,9 @@ def train(cfg: TrainPipelineConfig): if cfg.save_checkpoint and is_saving_step: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) + save_checkpoint( + checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor + ) update_last_checkpoint(checkpoint_dir) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) @@ -258,9 +300,11 @@ def train(cfg: TrainPipelineConfig): torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, + env=eval_env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=cfg.eval.n_episodes, videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", max_episodes_rendered=4, start_seed=cfg.seed, @@ -289,6 +333,8 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) def main(): diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index e7be6967b..62c243e95 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -56,11 +56,17 @@ import time from dataclasses import asdict, dataclass from pprint import pformat -import draccus import rerun as rr from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.configs import parser +from lerobot.processor import ( + RobotAction, + RobotObservation, + RobotProcessorPipeline, + make_default_processors, +) from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -100,36 +106,81 @@ class TeleoperateConfig: def teleop_loop( - teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None + teleop: Teleoperator, + robot: Robot, + fps: int, + teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction], + robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction], + robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], + display_data: bool = False, + duration: float | None = None, ): + """ + This function continuously reads actions from a teleoperation device, processes them through optional + pipelines, sends them to a robot, and optionally displays the robot's state. The loop runs at a + specified frequency until a set duration is reached or it is manually interrupted. + + Args: + teleop: The teleoperator device instance providing control actions. + robot: The robot instance being controlled. + fps: The target frequency for the control loop in frames per second. + display_data: If True, fetches robot observations and displays them in the console and Rerun. + duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. + teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. + robot_action_processor: An optional pipeline to process actions before they are sent to the robot. + robot_observation_processor: An optional pipeline to process raw observations from the robot. + """ + display_len = max(len(key) for key in robot.action_features) start = time.perf_counter() + while True: loop_start = time.perf_counter() - action = teleop.get_action() - if display_data: - observation = robot.get_observation() - log_rerun_data(observation, action) - robot.send_action(action) + # Get robot observation + # Not really needed for now other than for visualization + # teleop_action_processor can take None as an observation + # given that it is the identity processor as default + obs = robot.get_observation() + + # Get teleop action + raw_action = teleop.get_action() + + # Process teleop action through pipeline + teleop_action = teleop_action_processor((raw_action, obs)) + + # Process action for robot through pipeline + robot_action_to_send = robot_action_processor((teleop_action, obs)) + + # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) + _ = robot.send_action(robot_action_to_send) + + if display_data: + # Process robot observation through pipeline + obs_transition = robot_observation_processor(obs) + + log_rerun_data( + observation=obs_transition, + action=teleop_action, + ) + + print("\n" + "-" * (display_len + 10)) + print(f"{'NAME':<{display_len}} | {'NORM':>7}") + # Display the final robot action that was sent + for motor, value in robot_action_to_send.items(): + print(f"{motor:<{display_len}} | {value:>7.2f}") + move_cursor_up(len(robot_action_to_send) + 5) + dt_s = time.perf_counter() - loop_start busy_wait(1 / fps - dt_s) - loop_s = time.perf_counter() - loop_start - - print("\n" + "-" * (display_len + 10)) - print(f"{'NAME':<{display_len}} | {'NORM':>7}") - for motor, value in action.items(): - print(f"{motor:<{display_len}} | {value:>7.2f}") print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") if duration is not None and time.perf_counter() - start >= duration: return - move_cursor_up(len(action) + 5) - -@draccus.wrap() +@parser.wrap() def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) @@ -138,12 +189,22 @@ def teleoperate(cfg: TeleoperateConfig): teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) + teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() teleop.connect() robot.connect() try: - teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s) + teleop_loop( + teleop=teleop, + robot=robot, + fps=cfg.fps, + display_data=cfg.display_data, + duration=cfg.teleop_time_s, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) except KeyboardInterrupt: pass finally: diff --git a/src/lerobot/teleoperators/__init__.py b/src/lerobot/teleoperators/__init__.py index 56f48af7e..ee508dddb 100644 --- a/src/lerobot/teleoperators/__init__.py +++ b/src/lerobot/teleoperators/__init__.py @@ -16,4 +16,4 @@ from .config import TeleoperatorConfig from .teleoperator import Teleoperator -from .utils import make_teleoperator_from_config +from .utils import TeleopEvents, make_teleoperator_from_config diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py index 7ebed6b31..d994dadd1 100644 --- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py +++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py @@ -16,6 +16,8 @@ import logging +from ..utils import TeleopEvents + class InputController: """Base class for input controllers that generate motion deltas.""" @@ -134,10 +136,10 @@ class KeyboardController(InputController): return False elif key == keyboard.Key.enter: self.key_states["success"] = True - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS elif key == keyboard.Key.backspace: self.key_states["failure"] = True - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE except AttributeError: pass @@ -255,13 +257,13 @@ class GamepadController(InputController): for event in pygame.event.get(): if event.type == pygame.JOYBUTTONDOWN: if event.button == 3: - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS # A button (1) for failure elif event.button == 1: - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE # X button (0) for rerecord elif event.button == 0: - self.episode_end_status = "rerecord_episode" + self.episode_end_status = TeleopEvents.RERECORD_EPISODE # RB button (6) for closing gripper elif event.button == 6: @@ -295,8 +297,8 @@ class GamepadController(InputController): try: # Read joystick axes # Left stick X and Y (typically axes 0 and 1) - x_input = self.joystick.get_axis(0) # Left/Right - y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + y_input = self.joystick.get_axis(0) # Up/Down (often inverted) + x_input = self.joystick.get_axis(1) # Left/Right # Right stick Y (typically axis 3 or 4) z_input = self.joystick.get_axis(3) # Up/Down for Z @@ -308,7 +310,7 @@ class GamepadController(InputController): # Calculate deltas (note: may need to invert axes depending on controller) delta_x = -x_input * self.x_step_size # Forward/backward - delta_y = y_input * self.y_step_size # Left/right + delta_y = -y_input * self.y_step_size # Left/right delta_z = -z_input * self.z_step_size # Up/down return delta_x, delta_y, delta_z @@ -451,11 +453,11 @@ class GamepadControllerHID(InputController): # Check if X/Square button (bit 5) is pressed for failure # Check if A/Cross button (bit 4) is pressed for rerecording if buttons & 1 << 7: - self.episode_end_status = "success" + self.episode_end_status = TeleopEvents.SUCCESS elif buttons & 1 << 5: - self.episode_end_status = "failure" + self.episode_end_status = TeleopEvents.FAILURE elif buttons & 1 << 4: - self.episode_end_status = "rerecord_episode" + self.episode_end_status = TeleopEvents.RERECORD_EPISODE else: self.episode_end_status = None diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 98a0647e2..c7072f4a7 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -21,6 +21,7 @@ from typing import Any import numpy as np from ..teleoperator import Teleoperator +from ..utils import TeleopEvents from .configuration_gamepad import GamepadTeleopConfig @@ -107,6 +108,48 @@ class GamepadTeleop(Teleoperator): return action_dict + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the gamepad such as intervention status, + episode termination, success indicators, etc. + + Returns: + Dictionary containing: + - is_intervention: bool - Whether human is currently intervening + - terminate_episode: bool - Whether to terminate the current episode + - success: bool - Whether the episode was successful + - rerecord_episode: bool - Whether to rerecord the episode + """ + if self.gamepad is None: + return { + TeleopEvents.IS_INTERVENTION: False, + TeleopEvents.TERMINATE_EPISODE: False, + TeleopEvents.SUCCESS: False, + TeleopEvents.RERECORD_EPISODE: False, + } + + # Update gamepad state to get fresh inputs + self.gamepad.update() + + # Check if intervention is active + is_intervention = self.gamepad.should_intervene() + + # Get episode end status + episode_end_status = self.gamepad.get_episode_end_status() + terminate_episode = episode_end_status in [ + TeleopEvents.RERECORD_EPISODE, + TeleopEvents.FAILURE, + ] + success = episode_end_status == TeleopEvents.SUCCESS + rerecord_episode = episode_end_status == TeleopEvents.RERECORD_EPISODE + + return { + TeleopEvents.IS_INTERVENTION: is_intervention, + TeleopEvents.TERMINATE_EPISODE: terminate_episode, + TeleopEvents.SUCCESS: success, + TeleopEvents.RERECORD_EPISODE: rerecord_episode, + } + def disconnect(self) -> None: """Disconnect from the gamepad.""" if self.gamepad is not None: diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index d034982f1..7f489b25a 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -24,6 +24,7 @@ from typing import Any from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator +from ..utils import TeleopEvents from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig PYNPUT_AVAILABLE = True @@ -176,16 +177,6 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } - def _on_press(self, key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, True)) - - def _on_release(self, key): - if hasattr(key, "char"): - key = key.char - self.event_queue.put((key, False)) - def get_action(self) -> dict[str, Any]: if not self.is_connected: raise DeviceNotConnectedError( @@ -235,3 +226,66 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): action_dict["gripper"] = gripper_action return action_dict + + def get_teleop_events(self) -> dict[str, Any]: + """ + Get extra control events from the keyboard such as intervention status, + episode termination, success indicators, etc. + + Keyboard mappings: + - Any movement keys pressed = intervention active + - 's' key = success (terminate episode successfully) + - 'r' key = rerecord episode (terminate and rerecord) + - 'q' key = quit episode (terminate without success) + + Returns: + Dictionary containing: + - is_intervention: bool - Whether human is currently intervening + - terminate_episode: bool - Whether to terminate the current episode + - success: bool - Whether the episode was successful + - rerecord_episode: bool - Whether to rerecord the episode + """ + if not self.is_connected: + return { + TeleopEvents.IS_INTERVENTION: False, + TeleopEvents.TERMINATE_EPISODE: False, + TeleopEvents.SUCCESS: False, + TeleopEvents.RERECORD_EPISODE: False, + } + + # Check if any movement keys are currently pressed (indicates intervention) + movement_keys = [ + keyboard.Key.up, + keyboard.Key.down, + keyboard.Key.left, + keyboard.Key.right, + keyboard.Key.shift, + keyboard.Key.shift_r, + keyboard.Key.ctrl_r, + keyboard.Key.ctrl_l, + ] + is_intervention = any(self.current_pressed.get(key, False) for key in movement_keys) + + # Check for episode control commands from misc_keys_queue + terminate_episode = False + success = False + rerecord_episode = False + + # Process any pending misc keys + while not self.misc_keys_queue.empty(): + key = self.misc_keys_queue.get_nowait() + if key == "s": + success = True + elif key == "r": + terminate_episode = True + rerecord_episode = True + elif key == "q": + terminate_episode = True + success = False + + return { + TeleopEvents.IS_INTERVENTION: is_intervention, + TeleopEvents.TERMINATE_EPISODE: terminate_episode, + TeleopEvents.SUCCESS: success, + TeleopEvents.RERECORD_EPISODE: rerecord_episode, + } diff --git a/src/lerobot/teleoperators/phone/__init__.py b/src/lerobot/teleoperators/phone/__init__.py new file mode 100644 index 000000000..2b28c1f97 --- /dev/null +++ b/src/lerobot/teleoperators/phone/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_phone import PhoneConfig +from .teleop_phone import Phone diff --git a/src/lerobot/teleoperators/phone/config_phone.py b/src/lerobot/teleoperators/phone/config_phone.py new file mode 100644 index 000000000..380d5f5ff --- /dev/null +++ b/src/lerobot/teleoperators/phone/config_phone.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from enum import Enum + +import numpy as np + +from ..config import TeleoperatorConfig + + +class PhoneOS(Enum): + ANDROID = "android" + IOS = "ios" + + +@TeleoperatorConfig.register_subclass("phone") +@dataclass +class PhoneConfig(TeleoperatorConfig): + phone_os: PhoneOS = PhoneOS.IOS + camera_offset = np.array( + [0.0, -0.02, 0.04] + ) # iPhone 14 Pro camera is 2cm off center and 4cm above center diff --git a/src/lerobot/teleoperators/phone/phone_processor.py b/src/lerobot/teleoperators/phone/phone_processor.py new file mode 100644 index 000000000..67e64c7d5 --- /dev/null +++ b/src/lerobot/teleoperators/phone/phone_processor.py @@ -0,0 +1,110 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import ProcessorStepRegistry, RobotAction, RobotActionProcessorStep +from lerobot.teleoperators.phone.config_phone import PhoneOS + + +@ProcessorStepRegistry.register("map_phone_action_to_robot_action") +@dataclass +class MapPhoneActionToRobotAction(RobotActionProcessorStep): + """ + Maps calibrated phone pose actions to standardized robot action inputs. + + This processor step acts as a bridge between the phone teleoperator's output + and the robot's expected action format. It remaps the phone's 6-DoF pose + (position and rotation) to the robot's target end-effector pose, applying + necessary axis inversions and swaps. It also interprets platform-specific + button presses to generate a gripper command. + + Attributes: + platform: The operating system of the phone (iOS or Android), used + to determine the correct button mappings for the gripper. + """ + + # TODO(Steven): Gripper vel could be output of phone_teleop directly + platform: PhoneOS + _enabled_prev: bool = field(default=False, init=False, repr=False) + + def action(self, action: RobotAction) -> RobotAction: + """ + Processes the phone action dictionary to create a robot action dictionary. + + Args: + act: The input action dictionary from the phone teleoperator. + + Returns: + A new action dictionary formatted for the robot controller. + + Raises: + ValueError: If 'pos' or 'rot' keys are missing from the input action. + """ + # Pop them from the action + enabled = bool(action.pop("phone.enabled")) + pos = action.pop("phone.pos") + rot = action.pop("phone.rot") + inputs = action.pop("phone.raw_inputs") + + if pos is None or rot is None: + raise ValueError("pos and rot must be present in action") + + rotvec = rot.as_rotvec() # Absolute orientation as rotvec + + # Map certain inputs to certain actions + if self.platform == PhoneOS.IOS: + gripper_vel = float(inputs.get("a3", 0.0)) + else: + a = float(inputs.get("reservedButtonA", 0.0)) + b = float(inputs.get("reservedButtonB", 0.0)) + gripper_vel = ( + a - b + ) # Positive if a is pressed, negative if b is pressed, 0 if both or neither are pressed + + # For some actions we need to invert the axis + action["enabled"] = enabled + action["target_x"] = -pos[1] if enabled else 0.0 + action["target_y"] = pos[0] if enabled else 0.0 + action["target_z"] = pos[2] if enabled else 0.0 + action["target_wx"] = rotvec[1] if enabled else 0.0 + action["target_wy"] = rotvec[0] if enabled else 0.0 + action["target_wz"] = -rotvec[2] if enabled else 0.0 + action["gripper_vel"] = gripper_vel # Still send gripper action when disabled + return action + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + for feat in ["enabled", "pos", "rot", "raw_inputs"]: + features[PipelineFeatureType.ACTION].pop(f"phone.{feat}", None) + + for feat in [ + "enabled", + "target_x", + "target_y", + "target_z", + "target_wx", + "target_wy", + "target_wz", + "gripper_vel", + ]: + features[PipelineFeatureType.ACTION][f"{feat}"] = PolicyFeature( + type=FeatureType.ACTION, shape=(1,) + ) + + return features diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py new file mode 100644 index 000000000..c90729efa --- /dev/null +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Docs: +# hebi: https://docs.hebi.us/tools.html#mobile-io +# teleop: https://github.com/SpesRobotics/teleop + +import logging +import threading +import time + +import hebi +import numpy as np +from teleop import Teleop + +from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS +from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.utils.rotation import Rotation + +logger = logging.getLogger(__name__) + + +class BasePhone: + _enabled: bool = False + _calib_pos: np.ndarray | None = None + _calib_rot_inv: Rotation | None = None + + def _reapply_position_calibration(self, pos: np.ndarray) -> None: + self._calib_pos = pos.copy() + + @property + def is_calibrated(self) -> bool: + return (self._calib_pos is not None) and (self._calib_rot_inv is not None) + + @property + def action_features(self) -> dict[str, type]: + return { + "phone.pos": np.ndarray, # shape (3,) + "phone.rot": Rotation, # scipy.spatial.transform.Rotation + "phone.raw_inputs": dict, # analogs/buttons or webXR meta + "phone.enabled": bool, + } + + @property + def feedback_features(self) -> dict[str, type]: + # No haptic or other feedback implemented yet + pass + + def configure(self) -> None: + # No additional configuration required for phone teleop + pass + + def send_feedback(self, feedback: dict[str, float]) -> None: + # We could add haptic feedback (vibrations) here, but it's not implemented yet + raise NotImplementedError + + +class IOSPhone(BasePhone, Teleoperator): + name = "ios_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._group = None + + @property + def is_connected(self) -> bool: + return self._group is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") + lookup = hebi.Lookup() + time.sleep(2.0) + group = lookup.get_group_from_names(["HEBI"], ["mobileIO"]) + if group is None: + raise RuntimeError("Mobile I/O not found — check name/family settings in the app.") + self._group = group + logger.info(f"{self} connected to HEBI group with {group.size} module(s).") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Press and hold B1 in the HEBI Mobile I/O app to capture this pose...\n") + position, rotation = self._wait_for_capture_trigger() + self._calib_pos = position.copy() + self._calib_rot_inv = rotation.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """ + Blocks execution until the calibration trigger is detected from the iOS device. + + This method enters a loop, continuously reading the phone's state. It waits for the user to press + and hold the 'B1' button in the HEBI Mobile I/O app. Once B1 is pressed, the loop breaks and + returns the phone's pose at that exact moment. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ + while True: + has_pose, position, rotation, fb_pose = self._read_current_pose() + if not has_pose: + time.sleep(0.01) + continue + + io = getattr(fb_pose, "io", None) + button_b = getattr(io, "b", None) if io is not None else None + button_b1_pressed = False + if button_b is not None: + button_b1_pressed = bool(button_b.get_int(1)) + if button_b1_pressed: + return position, rotation + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the instantaneous 6-DoF pose from the connected iOS device via the HEBI SDK. + + This method fetches the latest feedback packet from the HEBI group, extracts the ARKit + position and orientation, and converts them into a standard format. It also applies a + configured camera offset to adjust the pose from the camera's frame to the phone's + physical frame. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was successfully read. + - The 3D position as a NumPy array, or None if not available. + - The orientation as a `Rotation` object, or None if not available. + - The raw HEBI feedback object for accessing other data like button presses. + """ + fbk = self._group.get_next_feedback() + pose = fbk[0] + ar_pos = getattr(pose, "ar_position", None) + ar_quat = getattr(pose, "ar_orientation", None) + if ar_pos is None or ar_quat is None: + return False, None, None, None + # HEBI provides orientation in w, x, y, z format. + # Scipy's Rotation expects x, y, z, w. + quat_xyzw = np.concatenate((ar_quat[1:], [ar_quat[0]])) # wxyz to xyzw + rot = Rotation.from_quat(quat_xyzw) + pos = ar_pos - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def get_action(self) -> dict: + has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose() + if not has_pose or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + io = getattr(fb_pose, "io", None) + if io is not None: + bank_a, bank_b = io.a, io.b + if bank_a: + for ch in range(1, 9): + if bank_a.has_float(ch): + raw_inputs[f"a{ch}"] = float(bank_a.get_float(ch)) + if bank_b: + for ch in range(1, 9): + if bank_b.has_int(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_int(ch)) + elif hasattr(bank_b, "has_bool") and bank_b.has_bool(ch): + raw_inputs[f"b{ch}"] = int(bank_b.get_bool(ch)) + + enable = bool(raw_inputs.get("b1", 0)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_position) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_position - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rotation + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._group = None + + +class AndroidPhone(BasePhone, Teleoperator): + name = "android_phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + self._teleop = None + self._teleop_thread = None + self._latest_pose = None + self._latest_message = None + self._android_lock = threading.Lock() + + @property + def is_connected(self) -> bool: + return self._teleop is not None + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + logger.info("Starting teleop stream for Android...") + self._teleop = Teleop() + self._teleop.subscribe(self._android_callback) + self._teleop_thread = threading.Thread(target=self._teleop.run, daemon=True) + self._teleop_thread.start() + logger.info(f"{self} connected, teleop stream started.") + + self.calibrate() + + def calibrate(self) -> None: + print( + "Hold the phone so that: top edge points forward in same direction as the robot (robot +x) and screen points up (robot +z)" + ) + print("Touch and move on the WebXR page to capture this pose...\n") + + pos, rot = self._wait_for_capture_trigger() + self._calib_pos = pos.copy() + self._calib_rot_inv = rot.inv() + self._enabled = False + print("Calibration done\n") + + def _wait_for_capture_trigger(self) -> tuple[np.ndarray, Rotation]: + """ + Blocks execution until the calibration trigger is detected from the Android device. + + This method enters a loop, continuously checking the latest message received from the WebXR + session. It waits for the user to touch and move their finger on the screen, which generates + a `move` event. Once this event is detected, the loop breaks and returns the phone's current + pose. + + Returns: + A tuple containing the position (np.ndarray) and rotation (Rotation) of the phone at the + moment the trigger was activated. + """ + while True: + with self._android_lock: + msg = self._latest_message or {} + + if bool(msg.get("move", False)): + ok, pos, rot, _pose = self._read_current_pose() + if ok: + return pos, rot + + time.sleep(0.01) + + def _read_current_pose(self) -> tuple[bool, np.ndarray | None, Rotation | None, object | None]: + """ + Reads the latest 6-DoF pose received from the Android device's WebXR session. + + This method accesses the most recent pose data stored by the `_android_callback`. It uses a + thread lock to safely read the shared `_latest_pose` variable. The pose, a 4x4 matrix, is + then decomposed into position and rotation, and the configured camera offset is applied. + + Returns: + A tuple containing: + - A boolean indicating if a valid pose was available. + - The 3D position as a NumPy array, or None if no pose has been received yet. + - The orientation as a `Rotation` object, or None if no pose has been received. + - The raw 4x4 pose matrix as received from the teleop stream. + """ + with self._android_lock: + if self._latest_pose is None: + return False, None, None, None + p = self._latest_pose.copy() + pose = self._latest_pose + rot = Rotation.from_matrix(p[:3, :3]) + pos = p[:3, 3] - rot.apply(self.config.camera_offset) + return True, pos, rot, pose + + def _android_callback(self, pose: np.ndarray, message: dict) -> None: + """ + Callback function to handle incoming data from the Android teleop stream. + + This method is executed by the `teleop` package's subscriber thread whenever a new + pose and message are received from the WebXR session on the Android phone. It updates + the internal state (`_latest_pose` and `_latest_message`) with the new data. + A thread lock is used to ensure that these shared variables are updated atomically, + preventing race conditions with the main thread that reads them. + + Args: + pose: A 4x4 NumPy array representing the phone's transformation matrix. + message: A dictionary containing additional data, such as button presses or touch events. + """ + with self._android_lock: + self._latest_pose = pose + self._latest_message = message + + def get_action(self) -> dict: + ok, raw_pos, raw_rot, pose = self._read_current_pose() + if not ok or not self.is_calibrated: + return {} + + # Collect raw inputs (B1 / analogs on iOS, move/scale on Android) + raw_inputs: dict[str, float | int | bool] = {} + msg = self._latest_message or {} + raw_inputs["move"] = bool(msg.get("move", False)) + raw_inputs["scale"] = float(msg.get("scale", 1.0)) + raw_inputs["reservedButtonA"] = bool(msg.get("reservedButtonA", False)) + raw_inputs["reservedButtonB"] = bool(msg.get("reservedButtonB", False)) + + enable = bool(raw_inputs.get("move", False)) + + # Rising edge then re-capture calibration immediately from current raw pose + if enable and not self._enabled: + self._reapply_position_calibration(raw_pos) + + # Apply calibration + pos_cal = self._calib_rot_inv.apply(raw_pos - self._calib_pos) + rot_cal = self._calib_rot_inv * raw_rot + + self._enabled = enable + + return { + "phone.pos": pos_cal, + "phone.rot": rot_cal, + "phone.raw_inputs": raw_inputs, + "phone.enabled": self._enabled, + } + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._teleop = None + if self._teleop_thread and self._teleop_thread.is_alive(): + self._teleop_thread.join(timeout=1.0) + self._teleop_thread = None + self._latest_pose = None + + +class Phone(Teleoperator): + """ + Phone-based teleoperator using ARKit (iOS via HEBI Mobile I/O App) or the teleop Python package (Android via WebXR API). + For HEBI Mobile I/O we also expose 8 analog (a1-a8) and 8 digital (b1-b8) inputs. + + Press and hold **B1** to enable teleoperation. While enabled, the first B1 press + captures a reference pose and rotation, when disabled and pressed again the position is reapplied. + """ + + config_class = PhoneConfig + name = "phone" + + def __init__(self, config: PhoneConfig): + super().__init__(config) + self.config = config + + self._phone_impl: Teleoperator + + if self.config.phone_os == PhoneOS.IOS: + self._phone_impl = IOSPhone(config) + elif self.config.phone_os == PhoneOS.ANDROID: + self._phone_impl = AndroidPhone(config) + else: + raise ValueError(f"Invalid config phone_os: {self.config.phone_os}") + + @property + def is_connected(self) -> bool: + return self._phone_impl.is_connected + + def connect(self) -> None: + return self._phone_impl.connect() + + def calibrate(self) -> None: + return self._phone_impl.calibrate() + + @property + def is_calibrated(self) -> bool: + return self._phone_impl.is_calibrated + + @property + def action_features(self) -> dict[str, type]: + return self._phone_impl.action_features + + @property + def feedback_features(self) -> dict[str, type]: + return self._phone_impl.feedback_features + + def configure(self) -> None: + return self._phone_impl.configure() + + def get_action(self) -> dict: + return self._phone_impl.get_action() + + def send_feedback(self, feedback: dict[str, float]) -> None: + return self._phone_impl.send_feedback(feedback) + + def disconnect(self) -> None: + return self._phone_impl.disconnect() diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 02e6fd22c..bad7d9c37 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -12,10 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum + from .config import TeleoperatorConfig from .teleoperator import Teleoperator +class TeleopEvents(Enum): + """Shared constants for teleoperator events across teleoperators.""" + + SUCCESS = "success" + FAILURE = "failure" + RERECORD_EPISODE = "rerecord_episode" + IS_INTERVENTION = "is_intervention" + TERMINATE_EPISODE = "terminate_episode" + + def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: if config.type == "keyboard": from .keyboard import KeyboardTeleop diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 4bcc241da..47beb5746 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -22,6 +22,7 @@ import traceback from contextlib import nullcontext from copy import copy from functools import cache +from typing import Any import numpy as np import torch @@ -31,10 +32,25 @@ from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.robots import Robot def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + """ + Logs performance metrics for a single step of the robot control loop. + + This function formats and prints a single line of log information, including episode/frame counters, + total loop time (dt), and detailed timings for various robot and camera operations. It can also + highlight performance drops in yellow if the actual FPS is lower than the target FPS. + + Args: + robot: The `Robot` instance, used to access its internal logs for detailed timings. + dt_s: The total duration of the control loop step in seconds. + episode_index: The index of the current episode. + frame_index: The index of the current frame within the episode. + fps: The target frames per second, used to check for performance degradation. + """ log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -80,7 +96,16 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f @cache def is_headless(): - """Detects if python is running without a monitor.""" + """ + Detects if the Python script is running in a headless environment (e.g., without a display). + + This function attempts to import `pynput`, a library that requires a graphical environment. + If the import fails, it assumes the environment is headless. The result is cached to avoid + re-running the check. + + Returns: + True if the environment is determined to be headless, False otherwise. + """ try: import pynput # noqa @@ -101,10 +126,35 @@ def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], use_amp: bool, task: str | None = None, robot_type: str | None = None, ): + """ + Performs a single-step inference to predict a robot action from an observation. + + This function encapsulates the full inference pipeline: + 1. Prepares the observation by converting it to PyTorch tensors and adding a batch dimension. + 2. Runs the preprocessor pipeline on the observation. + 3. Feeds the processed observation to the policy to get a raw action. + 4. Runs the postprocessor pipeline on the raw action. + 5. Formats the final action by removing the batch dimension and moving it to the CPU. + + Args: + observation: A dictionary of NumPy arrays representing the robot's current observation. + policy: The `PreTrainedPolicy` model to use for action prediction. + device: The `torch.device` (e.g., 'cuda' or 'cpu') to run inference on. + preprocessor: The `PolicyProcessorPipeline` for preprocessing observations. + postprocessor: The `PolicyProcessorPipeline` for postprocessing actions. + use_amp: A boolean to enable/disable Automatic Mixed Precision for CUDA inference. + task: An optional string identifier for the task. + robot_type: An optional string identifier for the robot type. + + Returns: + A `torch.Tensor` containing the predicted action, ready for the robot. + """ observation = copy(observation) with ( torch.inference_mode(), @@ -122,10 +172,14 @@ def predict_action( observation["task"] = task if task else "" observation["robot_type"] = robot_type if robot_type else "" + observation = preprocessor(observation) + # Compute the next action with the policy # based on the current observation action = policy.select_action(observation) + action = postprocessor(action) + # Remove batch dimension action = action.squeeze(0) @@ -136,6 +190,18 @@ def predict_action( def init_keyboard_listener(): + """ + Initializes a non-blocking keyboard listener for real-time user interaction. + + This function sets up a listener for specific keys (right arrow, left arrow, escape) to control + the program flow during execution, such as stopping recording or exiting loops. It gracefully + handles headless environments where keyboard listening is not possible. + + Returns: + A tuple containing: + - The `pynput.keyboard.Listener` instance, or `None` if in a headless environment. + - A dictionary of event flags (e.g., `exit_early`) that are set by key presses. + """ # Allow to exit early while recording an episode or resetting the environment, # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. @@ -177,6 +243,19 @@ def init_keyboard_listener(): def sanity_check_dataset_name(repo_id, policy_cfg): + """ + Validates the dataset repository name against the presence of a policy configuration. + + This function enforces a naming convention: a dataset repository ID should start with "eval_" + if and only if a policy configuration is provided for evaluation purposes. + + Args: + repo_id: The Hugging Face Hub repository ID of the dataset. + policy_cfg: The configuration object for the policy, or `None`. + + Raises: + ValueError: If the naming convention is violated. + """ _, dataset_name = repo_id.split("/") # either repo_id doesnt start with "eval_" and there is no policy # or repo_id starts with "eval_" and there is a policy @@ -197,6 +276,21 @@ def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_robot_compatibility( dataset: LeRobotDataset, robot: Robot, fps: int, features: dict ) -> None: + """ + Checks if a dataset's metadata is compatible with the current robot and recording setup. + + This function compares key metadata fields (`robot_type`, `fps`, and `features`) from the + dataset against the current configuration to ensure that appended data will be consistent. + + Args: + dataset: The `LeRobotDataset` instance to check. + robot: The `Robot` instance representing the current hardware setup. + fps: The current recording frequency (frames per second). + features: The dictionary of features for the current recording session. + + Raises: + ValueError: If any of the checked metadata fields do not match. + """ fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 5c29b5a84..09e649372 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -58,6 +58,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _torch_available, _torch_version = is_package_available("torch", return_version=True) +_transformers_available = is_package_available("transformers") _gym_xarm_available = is_package_available("gym_xarm") _gym_aloha_available = is_package_available("gym_aloha") _gym_pusht_available = is_package_available("gym_pusht") diff --git a/src/lerobot/utils/rotation.py b/src/lerobot/utils/rotation.py new file mode 100644 index 000000000..41b652947 --- /dev/null +++ b/src/lerobot/utils/rotation.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom rotation utilities to replace scipy.spatial.transform.Rotation.""" + +import numpy as np + + +class Rotation: + """ + Custom rotation class that provides a subset of scipy.spatial.transform.Rotation functionality. + + Supports conversions between rotation vectors, rotation matrices, and quaternions. + """ + + def __init__(self, quat: np.ndarray) -> None: + """Initialize rotation from quaternion [x, y, z, w].""" + self._quat = np.asarray(quat, dtype=float) + # Normalize quaternion + norm = np.linalg.norm(self._quat) + if norm > 0: + self._quat = self._quat / norm + + @classmethod + def from_rotvec(cls, rotvec: np.ndarray) -> "Rotation": + """ + Create rotation from rotation vector using Rodrigues' formula. + + Args: + rotvec: Rotation vector [x, y, z] where magnitude is angle in radians + + Returns: + Rotation instance + """ + rotvec = np.asarray(rotvec, dtype=float) + angle = np.linalg.norm(rotvec) + + if angle < 1e-8: + # For very small angles, use identity quaternion + quat = np.array([0.0, 0.0, 0.0, 1.0]) + else: + axis = rotvec / angle + half_angle = angle / 2.0 + sin_half = np.sin(half_angle) + cos_half = np.cos(half_angle) + + # Quaternion [x, y, z, w] + quat = np.array([axis[0] * sin_half, axis[1] * sin_half, axis[2] * sin_half, cos_half]) + + return cls(quat) + + @classmethod + def from_matrix(cls, matrix: np.ndarray) -> "Rotation": + """ + Create rotation from 3x3 rotation matrix. + + Args: + matrix: 3x3 rotation matrix + + Returns: + Rotation instance + """ + matrix = np.asarray(matrix, dtype=float) + + # Shepherd's method for converting rotation matrix to quaternion + trace = np.trace(matrix) + + if trace > 0: + s = np.sqrt(trace + 1.0) * 2 # s = 4 * qw + qw = 0.25 * s + qx = (matrix[2, 1] - matrix[1, 2]) / s + qy = (matrix[0, 2] - matrix[2, 0]) / s + qz = (matrix[1, 0] - matrix[0, 1]) / s + elif matrix[0, 0] > matrix[1, 1] and matrix[0, 0] > matrix[2, 2]: + s = np.sqrt(1.0 + matrix[0, 0] - matrix[1, 1] - matrix[2, 2]) * 2 # s = 4 * qx + qw = (matrix[2, 1] - matrix[1, 2]) / s + qx = 0.25 * s + qy = (matrix[0, 1] + matrix[1, 0]) / s + qz = (matrix[0, 2] + matrix[2, 0]) / s + elif matrix[1, 1] > matrix[2, 2]: + s = np.sqrt(1.0 + matrix[1, 1] - matrix[0, 0] - matrix[2, 2]) * 2 # s = 4 * qy + qw = (matrix[0, 2] - matrix[2, 0]) / s + qx = (matrix[0, 1] + matrix[1, 0]) / s + qy = 0.25 * s + qz = (matrix[1, 2] + matrix[2, 1]) / s + else: + s = np.sqrt(1.0 + matrix[2, 2] - matrix[0, 0] - matrix[1, 1]) * 2 # s = 4 * qz + qw = (matrix[1, 0] - matrix[0, 1]) / s + qx = (matrix[0, 2] + matrix[2, 0]) / s + qy = (matrix[1, 2] + matrix[2, 1]) / s + qz = 0.25 * s + + quat = np.array([qx, qy, qz, qw]) + return cls(quat) + + @classmethod + def from_quat(cls, quat: np.ndarray) -> "Rotation": + """ + Create rotation from quaternion. + + Args: + quat: Quaternion [x, y, z, w] or [w, x, y, z] (specify convention in docstring) + This implementation expects [x, y, z, w] format + + Returns: + Rotation instance + """ + return cls(quat) + + def as_matrix(self) -> np.ndarray: + """ + Convert rotation to 3x3 rotation matrix. + + Returns: + 3x3 rotation matrix + """ + qx, qy, qz, qw = self._quat + + # Compute rotation matrix from quaternion + return np.array( + [ + [1 - 2 * (qy * qy + qz * qz), 2 * (qx * qy - qz * qw), 2 * (qx * qz + qy * qw)], + [2 * (qx * qy + qz * qw), 1 - 2 * (qx * qx + qz * qz), 2 * (qy * qz - qx * qw)], + [2 * (qx * qz - qy * qw), 2 * (qy * qz + qx * qw), 1 - 2 * (qx * qx + qy * qy)], + ], + dtype=float, + ) + + def as_rotvec(self) -> np.ndarray: + """ + Convert rotation to rotation vector. + + Returns: + Rotation vector [x, y, z] where magnitude is angle in radians + """ + qx, qy, qz, qw = self._quat + + # Ensure qw is positive for unique representation + if qw < 0: + qx, qy, qz, qw = -qx, -qy, -qz, -qw + + # Compute angle and axis + angle = 2.0 * np.arccos(np.clip(abs(qw), 0.0, 1.0)) + sin_half_angle = np.sqrt(1.0 - qw * qw) + + if sin_half_angle < 1e-8: + # For very small angles, use linearization: rotvec ≈ 2 * [qx, qy, qz] + return 2.0 * np.array([qx, qy, qz]) + + # Extract axis and scale by angle + axis = np.array([qx, qy, qz]) / sin_half_angle + return angle * axis + + def as_quat(self) -> np.ndarray: + """ + Get quaternion representation. + + Returns: + Quaternion [x, y, z, w] + """ + return self._quat.copy() + + def apply(self, vectors: np.ndarray, inverse: bool = False) -> np.ndarray: + """ + Apply this rotation to a set of vectors. + + This is equivalent to applying the rotation matrix to the vectors: + self.as_matrix() @ vectors (or self.as_matrix().T @ vectors if inverse=True). + + Args: + vectors: Array of shape (3,) or (N, 3) representing vectors in 3D space + inverse: If True, apply the inverse of the rotation. Default is False. + + Returns: + Rotated vectors with shape: + - (3,) if input was single vector with shape (3,) + - (N, 3) in all other cases + """ + vectors = np.asarray(vectors, dtype=float) + original_shape = vectors.shape + + # Handle single vector case - ensure it's 2D for matrix multiplication + if vectors.ndim == 1: + if len(vectors) != 3: + raise ValueError("Single vector must have length 3") + vectors = vectors.reshape(1, 3) + single_vector = True + elif vectors.ndim == 2: + if vectors.shape[1] != 3: + raise ValueError("Vectors must have shape (N, 3)") + single_vector = False + else: + raise ValueError("Vectors must be 1D or 2D array") + + # Get rotation matrix + rotation_matrix = self.as_matrix() + + # Apply inverse if requested (transpose for orthogonal rotation matrices) + if inverse: + rotation_matrix = rotation_matrix.T + + # Apply rotation: (N, 3) @ (3, 3).T -> (N, 3) + rotated_vectors = vectors @ rotation_matrix.T + + # Return original shape for single vector case + if single_vector and original_shape == (3,): + return rotated_vectors.flatten() + + return rotated_vectors + + def inv(self) -> "Rotation": + """ + Invert this rotation. + + Composition of a rotation with its inverse results in an identity transformation. + + Returns: + Rotation instance containing the inverse of this rotation + """ + qx, qy, qz, qw = self._quat + + # For a unit quaternion, the inverse is the conjugate: [-x, -y, -z, w] + inverse_quat = np.array([-qx, -qy, -qz, qw]) + + return Rotation(inverse_quat) + + def __mul__(self, other: "Rotation") -> "Rotation": + """ + Compose this rotation with another rotation using the * operator. + + The composition `r2 * r1` means "apply r1 first, then r2". + This is equivalent to applying rotation matrices: r2.as_matrix() @ r1.as_matrix() + + Args: + other: Another Rotation instance to compose with + + Returns: + Rotation instance representing the composition of rotations + """ + if not isinstance(other, Rotation): + return NotImplemented + + # Get quaternions [x, y, z, w] + x1, y1, z1, w1 = other._quat # Apply first + x2, y2, z2, w2 = self._quat # Apply second + + # Quaternion multiplication: q2 * q1 (apply q1 first, then q2) + composed_quat = np.array( + [ + w2 * x1 + x2 * w1 + y2 * z1 - z2 * y1, # x component + w2 * y1 - x2 * z1 + y2 * w1 + z2 * x1, # y component + w2 * z1 + x2 * y1 - y2 * x1 + z2 * w1, # z component + w2 * w1 - x2 * x1 - y2 * y1 - z2 * z1, # w component + ] + ) + + return Rotation(composed_quat) diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 2859fe057..be2eb8146 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import PolicyProcessorPipeline from lerobot.utils.random_utils import load_rng_state, save_rng_state @@ -74,6 +75,8 @@ def save_checkpoint( policy: PreTrainedPolicy, optimizer: Optimizer, scheduler: LRScheduler | None = None, + preprocessor: PolicyProcessorPipeline | None = None, + postprocessor: PolicyProcessorPipeline | None = None, ) -> None: """This function creates the following directory structure: @@ -81,7 +84,9 @@ def save_checkpoint( ├── pretrained_model/ │ ├── config.json # policy config │ ├── model.safetensors # policy weights - │ └── train_config.json # train config + │ ├── train_config.json # train config + │ ├── processor.json # processor config (if preprocessor provided) + │ └── step_*.safetensors # processor state files (if any) └── training_state/ ├── optimizer_param_groups.json # optimizer param groups ├── optimizer_state.safetensors # optimizer state @@ -95,10 +100,15 @@ def save_checkpoint( policy (PreTrainedPolicy): The policy to save. optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. + preprocessor: The preprocessor/pipeline to save. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if preprocessor is not None: + preprocessor.save_pretrained(pretrained_dir) + if postprocessor is not None: + postprocessor.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler) diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index f0f9aebb7..e6acc87de 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers import os from typing import Any @@ -28,19 +29,69 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None: rr.spawn(memory_limit=memory_limit) -def log_rerun_data(observation: dict[str | Any], action: dict[str | Any]): - for obs, val in observation.items(): - if isinstance(val, float): - rr.log(f"observation.{obs}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - if val.ndim == 1: - for i, v in enumerate(val): - rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v))) - else: - rr.log(f"observation.{obs}", rr.Image(val), static=True) - for act, val in action.items(): - if isinstance(val, float): - rr.log(f"action.{act}", rr.Scalar(val)) - elif isinstance(val, np.ndarray): - for i, v in enumerate(val): - rr.log(f"action.{act}_{i}", rr.Scalar(float(v))) +def _is_scalar(x): + return ( + isinstance(x, float) + or isinstance(x, numbers.Real) + or isinstance(x, (np.integer, np.floating)) + or (isinstance(x, np.ndarray) and x.ndim == 0) + ) + + +def log_rerun_data( + observation: dict[str, Any] | None = None, + action: dict[str, Any] | None = None, +) -> None: + """ + Logs observation and action data to Rerun for real-time visualization. + + This function iterates through the provided observation and action dictionaries and sends their contents + to the Rerun viewer. It handles different data types appropriately: + - Scalar values (floats, ints) are logged as `rr.Scalar`. + - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed + from CHW to HWC format and logged as `rr.Image`. + - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. + - Other multi-dimensional arrays are flattened and logged as individual scalars. + + Keys are automatically namespaced with "observation." or "action." if not already present. + + Args: + observation: An optional dictionary containing observation data to log. + action: An optional dictionary containing action data to log. + """ + if observation: + for k, v in observation.items(): + if v is None: + continue + key = k if str(k).startswith("observation.") else f"observation.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + arr = v + # Convert CHW -> HWC when needed + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4): + arr = np.transpose(arr, (1, 2, 0)) + if arr.ndim == 1: + for i, vi in enumerate(arr): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + rr.log(key, rr.Image(arr), static=True) + + if action: + for k, v in action.items(): + if v is None: + continue + key = k if str(k).startswith("action.") else f"action.{k}" + + if _is_scalar(v): + rr.log(key, rr.Scalar(float(v))) + elif isinstance(v, np.ndarray): + if v.ndim == 1: + for i, vi in enumerate(v): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) + else: + # Fall back to flattening higher-dimensional arrays + flat = v.flatten() + for i, vi in enumerate(flat): + rr.log(f"{key}_{i}", rr.Scalar(float(vi))) diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors index 8bd63e894..771af2445 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f3e4c8e85e146b043fd4e4984947c2a6f01627f174a19f18b5914cf690579d77 +oid sha256:ee0c29d3782aa1cadcf4dc6ed767d9460ff00fff9fc70b460502340b832eefcc size 5104 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors index 724d22b58..3e8df708e 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9b5f557e30aead3731c38cbd85af8c706395d8689a918ad88805b5a886245603 -size 33400 +oid sha256:ea76e6711959fd3f905ec2bdc306f488920f00ec99421e4870d05f6205eb323e +size 31672 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors index 6d912d81a..dd7d4d0e7 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e6625cabfeb4800abc80252cf9112a9271c154edd01eb291658f143c951610b +oid sha256:c2b8f8532c7a0b776de5e536b8b54e30b1a0c2e3d5cc25a2d86fe43e40ae5e8c size 515400 diff --git a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors index cc6b4a24b..5da67a1af 100644 --- a/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors +++ b/tests/artifacts/policies/aloha_sim_insertion_human_act_1000_steps/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:021562ee3e4814425e367ed0c144d6fbe2eb28838247085716cf0b58fd69a075 -size 33400 +oid sha256:eca0d87a699620e4fec7e68539b0be91e4cc933f6bf12032da52c182ab6f38cf +size 31672 diff --git a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors index 84e14b975..ef581727d 100644 --- a/tests/artifacts/policies/pusht_diffusion_/actions.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a32376dde65a1562403afd1db3e56c7e6b987ebaf6c3c601336e77155b9e608c +oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8 size 992 diff --git a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors index 542297910..e00ed3238 100644 --- a/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12ee532c53173d0361ebb979f087b229cc045aa3d9e6b94cfd4290af54fd1201 +oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002 size 47424 diff --git a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors index e91cd08b7..614cc754e 100644 --- a/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors +++ b/tests/artifacts/policies/pusht_diffusion_/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:010c01181b95625051276d69cb4209423c21f2e30a3fa9464ae67064a2ba4c22 -size 49120 +oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf +size 47424 diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 6ccb47c3e..b0ffa9a31 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -23,7 +23,7 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.factory import make_policy, make_policy_config +from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors from lerobot.utils.random_utils import set_seed @@ -37,7 +37,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): train_cfg.validate() # Needed for auto-setting some parameters dataset = make_dataset(train_cfg) + dataset_stats = dataset.meta.stats policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) + preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=dataset_stats) policy.train() optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy) @@ -49,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): ) batch = next(iter(dataloader)) + batch = preprocessor(batch) loss, output_dict = policy.forward(batch) + if output_dict is not None: output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} output_dict["loss"] = loss @@ -96,7 +100,12 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): else: actions_queue = train_cfg.policy.n_action_repeats - actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)} + actions = {} + for i in range(actions_queue): + unnormalized_action = policy.select_action(obs).contiguous() + action_robot = postprocessor(unnormalized_action) + actions[str(i)] = action_robot + return output_dict, grad_stats, param_stats, actions diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors index fa9bf06ab..e23eacffd 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c5edc5600d7206f027cb696a597bc99fcdd9073a15fa130b8031c52c0a7c134b +oid sha256:d640988f2269cf6aa03c8ee17f9d096edace83d837f90025011fafec5bf53c61 size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors index 8d90a671f..e665f73c6 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors index cde6c6dca..97d783580 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors index 692377d1f..3090b7051 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_mpc/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors index 7a0b165e2..5ce44048f 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/actions.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a9c08753ddc43b6c02a176418b81eb784146e59f4fc914591cbd3582ade392bb +oid sha256:2212ae7b910d14d723214f5af50985e419f7bd0f4261565ef48b1ef495443d6d size 200 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors index 8d90a671f..e665f73c6 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/grad_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a70e29263afdbff3a49d7041ff2d5065df75472b7c030cc8a5d12ab20d24cc10 +oid sha256:32ddf36af25791935b395c7641531cda14d5c4a2cf654a2e76ac45271665d07a size 16904 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors index cde6c6dca..97d783580 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/output_dict.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c49a5b4d4df92c9564009780f5e286ddfca84ca2b1753557024057b3b36afb8b +oid sha256:22a1031a2acfc36a455bff73ffbe097cfeb7742b6485e7422507e78d7a682703 size 164 diff --git a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors index 692377d1f..3090b7051 100644 --- a/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors +++ b/tests/artifacts/policies/xarm_lift_medium_tdmpc_use_policy/param_stats.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f8d19a86065937cffdd3ca49caef87c59e67d419b28f40f2817bad892dc3170 -size 36312 +oid sha256:b5dca7940998421ae58e9e26b2b2641b058d23b0270b7a147ebf85fbbdce7184 +size 35496 diff --git a/tests/conftest.py b/tests/conftest.py index e273da50f..245cde526 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ import traceback import pytest from serial import SerialException -from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from tests.utils import DEVICE # Import fixture modules as plugins @@ -83,7 +83,9 @@ def policy_feature_factory(): return _pf -def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: +def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None: assert isinstance(features, dict) - assert all(isinstance(k, str) for k in features.keys()) - assert all(isinstance(v, PolicyFeature) for v in features.values()) + assert all(isinstance(k, PipelineFeatureType) for k in features.keys()) + assert all(isinstance(v, dict) for v in features.values()) + assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values()) + assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values()) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py new file mode 100644 index 000000000..f1ffd800a --- /dev/null +++ b/tests/datasets/test_dataset_utils.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import Dataset +from huggingface_hub import DatasetCard + +from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch + + +def test_default_parameters(): + card = create_lerobot_dataset_card() + assert isinstance(card, DatasetCard) + assert card.data.tags == ["LeRobot"] + assert card.data.task_categories == ["robotics"] + assert card.data.configs == [ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ] + + +def test_with_tags(): + tags = ["tag1", "tag2"] + card = create_lerobot_dataset_card(tags=tags) + assert card.data.tags == ["LeRobot", "tag1", "tag2"] + + +def test_calculate_episode_data_index(): + dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "index": [0, 1, 2, 3, 4, 5], + "episode_index": [0, 0, 1, 2, 2, 2], + }, + ) + dataset.set_transform(hf_transform_to_torch) + episode_data_index = calculate_episode_data_index(dataset) + assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) + assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) + + +def test_merge_simple_vectors(): + g1 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.x", "ee.y"], + } + } + g2 = { + "action": { + "dtype": "float32", + "shape": (2,), + "names": ["ee.y", "ee.z"], + } + } + + out = combine_feature_dicts(g1, g2) + + assert "action" in out + assert out["action"]["dtype"] == "float32" + # Names merged with preserved order and de-dupuplication + assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + # Shape correctly recomputed from names length + assert out["action"]["shape"] == (3,) + + +def test_merge_multiple_groups_order_and_dedup(): + g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + + out = combine_feature_dicts(g1, g2, g3) + + assert out["action"]["names"] == ["a", "b", "c", "d"] + assert out["action"]["shape"] == (4,) + + +def test_non_vector_last_wins_for_images(): + # Non-vector (images) with same name should be overwritten by the last image specified + g1 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 480, 640), + "names": ["channels", "height", "width"], + } + } + g2 = { + "observation.images.front": { + "dtype": "image", + "shape": (3, 720, 1280), + "names": ["channels", "height", "width"], + } + } + + out = combine_feature_dicts(g1, g2) + assert out["observation.images.front"]["shape"] == (3, 720, 1280) + assert out["observation.images.front"]["dtype"] == "image" + + +def test_dtype_mismatch_raises(): + g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + + with pytest.raises(ValueError, match="dtype mismatch for 'action'"): + _ = combine_feature_dicts(g1, g2) + + +def test_non_dict_passthrough_last_wins(): + g1 = {"misc": 123} + g2 = {"misc": 456} + + out = combine_feature_dicts(g1, g2) + # For non-dict entries the last one wins + assert out["misc"] == 456 diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py deleted file mode 100644 index 91d661b3c..000000000 --- a/tests/datasets/test_utils.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from copy import deepcopy - -import torch -from datasets import Dataset -from huggingface_hub import DatasetCard - -from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.datasets.utils import ( - create_lerobot_dataset_card, - flatten_dict, - hf_transform_to_torch, - unflatten_dict, -) - - -def test_default_parameters(): - card = create_lerobot_dataset_card() - assert isinstance(card, DatasetCard) - assert card.data.tags == ["LeRobot"] - assert card.data.task_categories == ["robotics"] - assert card.data.configs == [ - { - "config_name": "default", - "data_files": "data/*/*.parquet", - } - ] - - -def test_with_tags(): - tags = ["tag1", "tag2"] - card = create_lerobot_dataset_card(tags=tags) - assert card.data.tags == ["LeRobot", "tag1", "tag2"] - - -def test_calculate_episode_data_index(): - dataset = Dataset.from_dict( - { - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - "index": [0, 1, 2, 3, 4, 5], - "episode_index": [0, 0, 1, 2, 2, 2], - }, - ) - dataset.set_transform(hf_transform_to_torch) - episode_data_index = calculate_episode_data_index(dataset) - assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) - assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6])) - - -def test_flatten_unflatten_dict(): - d = { - "obs": { - "min": 0, - "max": 1, - "mean": 2, - "std": 3, - }, - "action": { - "min": 4, - "max": 5, - "mean": 6, - "std": 7, - }, - } - - original_d = deepcopy(d) - d = unflatten_dict(flatten_dict(d)) - - # test equality between nested dicts - assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}" diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ef2d4ecd8..ef09bcd22 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -26,7 +26,7 @@ from safetensors.torch import load_file from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features @@ -39,8 +39,8 @@ from lerobot.policies.factory import ( get_policy_class, make_policy, make_policy_config, + make_pre_post_processors, ) -from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -154,6 +154,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # Check that we can make the policy object. dataset = make_dataset(train_cfg) + preprocessor, _ = make_pre_post_processors(train_cfg.policy, None) policy = make_policy(train_cfg.policy, ds_meta=dataset.meta) assert isinstance(policy, PreTrainedPolicy) @@ -227,6 +228,7 @@ def test_act_backbone_lr(): assert cfg.policy.optimizer_lr_backbone == 0.001 dataset = make_dataset(cfg) + preprocessor, _ = make_pre_post_processors(cfg.policy, None) policy = make_policy(cfg.policy, ds_meta=dataset.meta) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 @@ -266,108 +268,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) -@pytest.mark.parametrize("insert_temporal_dim", [False, True]) -def test_normalize(insert_temporal_dim): - """ - Test that normalize/unnormalize can run without exceptions when properly set up, and that they raise - an exception when the forward pass is called without the stats having been provided. - - TODO(rcadene, alexander-soare): This should also test that the normalization / unnormalization works as - expected. - """ - - input_features = { - "observation.image": PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 96, 96), - ), - "observation.state": PolicyFeature( - type=FeatureType.STATE, - shape=(10,), - ), - } - output_features = { - "action": PolicyFeature( - type=FeatureType.ACTION, - shape=(5,), - ), - } - - norm_map = { - "VISUAL": NormalizationMode.MEAN_STD, - "STATE": NormalizationMode.MIN_MAX, - "ACTION": NormalizationMode.MIN_MAX, - } - - dataset_stats = { - "observation.image": { - "mean": torch.randn(3, 1, 1), - "std": torch.randn(3, 1, 1), - "min": torch.randn(3, 1, 1), - "max": torch.randn(3, 1, 1), - }, - "observation.state": { - "mean": torch.randn(10), - "std": torch.randn(10), - "min": torch.randn(10), - "max": torch.randn(10), - }, - "action": { - "mean": torch.randn(5), - "std": torch.randn(5), - "min": torch.randn(5), - "max": torch.randn(5), - }, - } - - bsize = 2 - input_batch = { - "observation.image": torch.randn(bsize, 3, 96, 96), - "observation.state": torch.randn(bsize, 10), - } - output_batch = { - "action": torch.randn(bsize, 5), - } - - if insert_temporal_dim: - tdim = 4 - - for key in input_batch: - # [2,3,96,96] -> [2,tdim,3,96,96] - input_batch[key] = torch.stack([input_batch[key]] * tdim, dim=1) - - for key in output_batch: - output_batch[key] = torch.stack([output_batch[key]] * tdim, dim=1) - - # test without stats - normalize = Normalize(input_features, norm_map, stats=None) - with pytest.raises(AssertionError): - normalize(input_batch) - - # test with stats - normalize = Normalize(input_features, norm_map, stats=dataset_stats) - normalize(input_batch) - - # test loading pretrained models - new_normalize = Normalize(input_features, norm_map, stats=None) - new_normalize.load_state_dict(normalize.state_dict()) - new_normalize(input_batch) - - # test without stats - unnormalize = Unnormalize(output_features, norm_map, stats=None) - with pytest.raises(AssertionError): - unnormalize(output_batch) - - # test with stats - unnormalize = Unnormalize(output_features, norm_map, stats=dataset_stats) - unnormalize(output_batch) - - # test loading pretrained models - new_unnormalize = Unnormalize(output_features, norm_map, stats=None) - new_unnormalize.load_state_dict(unnormalize.state_dict()) - unnormalize(output_batch) - - @pytest.mark.parametrize("multikey", [True, False]) def test_multikey_construction(multikey: bool): """ @@ -467,6 +367,8 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs NOTE: If the test does not pass, and you don't change the policy, it is likely that the test artifact is out of date. For example, some PyTorch versions have different randomness, see this PR: https://github.com/huggingface/lerobot/pull/1127. + NOTE: If the test don't pass and you don't change the policy, and note the dependencies version, + and you changed your processor, you might have to update the test artifact. """ diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py new file mode 100644 index 000000000..f96f871aa --- /dev/null +++ b/tests/processor/test_act_processor.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for ACT policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.policies.act.processor_act import make_act_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default ACT configuration for testing.""" + config = ACTConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)}, + ACTION: {"mean": torch.zeros(4), "std": torch.ones(4)}, + } + + +def test_make_act_processor_basic(): + """Test basic creation of ACT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors(config, stats) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_act_processor_normalization(): + """Test that ACT processor correctly normalizes and unnormalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is normalized and batched + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized + assert postprocessed.shape == (1, 4) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_cuda(): + """Test ACT processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_accelerate_scenario(): + """Test ACT processor in simulated Accelerate scenario (data already on GPU).""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = {OBS_STATE: torch.randn(1, 7).to(device)} # Already batched and on GPU + action = torch.randn(1, 4).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU (not moved unnecessarily) + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_act_processor_multi_gpu(): + """Test ACT processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU (like in multi-GPU training) + device = torch.device("cuda:1") + observation = {OBS_STATE: torch.randn(1, 7).to(device)} + action = torch.randn(1, 4).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 (not moved to cuda:0) + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_act_processor_without_stats(): + """Test ACT processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors, but normalization won't have stats + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work (but won't normalize without stats) + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_act_processor_save_and_load(): + """Test saving and loading ACT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[TransitionKey.ACTION.value].shape == (1, 4) + + +def test_act_processor_device_placement_preservation(): + """Test that ACT processor preserves device placement correctly.""" + config = create_default_config() + stats = create_default_stats() + + # Test with CPU config + config.device = "cpu" + preprocessor, _ = make_act_pre_post_processors( + config, + stats, + ) + + # Process CPU data + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_STATE].device.type == "cpu" + assert processed[TransitionKey.ACTION.value].device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_mixed_precision(): + """Test ACT processor with mixed precision (float16).""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Modify the device processor to use float16 + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} + action = torch.randn(4, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_act_processor_batch_consistency(): + """Test that ACT processor handles different batch sizes correctly.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_act_pre_post_processors( + config, + stats, + ) + + # Test single sample (unbatched) + observation = {OBS_STATE: torch.randn(7)} + action = torch.randn(4) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed["observation.state"].shape[0] == 1 # Batched + + # Test already batched data + observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 + action_batched = torch.randn(8, 4) + transition_batched = create_transition(observation_batched, action_batched) + batch_batched = transition_to_batch(transition_batched) + + processed_batched = preprocessor(batch_batched) + assert processed_batched[OBS_STATE].shape[0] == 8 + assert processed_batched[TransitionKey.ACTION.value].shape[0] == 8 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_act_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_act_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data + observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32 + action = torch.randn(4, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 63894025d..631ad7899 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,11 +1,7 @@ import torch -from lerobot.processor.pipeline import ( - RobotProcessor, - TransitionKey, - _default_batch_to_transition, - _default_transition_to_batch, -) +from lerobot.processor import DataProcessorPipeline, TransitionKey +from lerobot.processor.converters import batch_to_transition, transition_to_batch def _dummy_batch(): @@ -24,7 +20,7 @@ def _dummy_batch(): def test_observation_grouping_roundtrip(): """Test that observation.* keys are properly grouped and ungrouped.""" - proc = RobotProcessor([]) + proc = DataProcessorPipeline([]) batch_in = _dummy_batch() batch_out = proc(batch_in) @@ -48,19 +44,19 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): - """Test that _default_batch_to_transition correctly groups observation.* keys.""" + """Test that batch_to_transition correctly groups observation.* keys.""" batch = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), "observation.state": [1, 2, 3, 4], - "action": "action_data", + "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, "next.truncated": False, "info": {"episode": 42}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) @@ -78,7 +74,7 @@ def test_batch_to_transition_observation_grouping(): assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) assert transition[TransitionKey.REWARD] == 1.5 assert transition[TransitionKey.DONE] assert not transition[TransitionKey.TRUNCATED] @@ -87,7 +83,7 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): - """Test that _default_transition_to_batch correctly flattens observation dict.""" + """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { "observation.image.top": torch.randn(1, 3, 128, 128), "observation.image.left": torch.randn(1, 3, 128, 128), @@ -104,7 +100,7 @@ def test_transition_to_batch_observation_flattening(): TransitionKey.COMPLEMENTARY_DATA: {}, } - batch = _default_transition_to_batch(transition) + batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch assert "observation.image.top" in batch @@ -127,28 +123,28 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": "action_data", + "action": torch.tensor([1.0, 2.0]), "next.reward": 2.0, "next.done": False, "next.truncated": True, "info": {"test": "no_obs"}, } - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Observation should be None when no observation.* keys assert transition[TransitionKey.OBSERVATION] is None # Check other fields - assert transition[TransitionKey.ACTION] == "action_data" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0])) assert transition[TransitionKey.REWARD] == 2.0 assert not transition[TransitionKey.DONE] assert transition[TransitionKey.TRUNCATED] assert transition[TransitionKey.INFO] == {"test": "no_obs"} # Round trip should work - reconstructed_batch = _default_transition_to_batch(transition) - assert reconstructed_batch["action"] == "action_data" + reconstructed_batch = transition_to_batch(transition) + assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -157,13 +153,13 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": "minimal_action"} + batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # Check observation assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} - assert transition[TransitionKey.ACTION] == "minimal_action" + assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults assert transition[TransitionKey.REWARD] == 0.0 @@ -173,9 +169,9 @@ def test_minimal_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["observation.state"] == "minimal_state" - assert reconstructed_batch["action"] == "minimal_action" + assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -186,7 +182,7 @@ def test_empty_batch(): """Test behavior with empty batch.""" batch = {} - transition = _default_batch_to_transition(batch) + transition = batch_to_transition(batch) # All fields should have defaults assert transition[TransitionKey.OBSERVATION] is None @@ -198,7 +194,7 @@ def test_empty_batch(): assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} # Round trip - reconstructed_batch = _default_transition_to_batch(transition) + reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch["action"] is None assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -219,8 +215,8 @@ def test_complex_nested_observation(): "info": {"episode_length": 200, "success": True}, } - transition = _default_batch_to_transition(batch) - reconstructed_batch = _default_transition_to_batch(transition) + transition = batch_to_transition(batch) + reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved original_obs_keys = {k for k in batch if k.startswith("observation.")} @@ -254,7 +250,7 @@ def test_custom_converter(): def to_tr(batch): # Custom converter that modifies the reward - tr = _default_batch_to_transition(batch) + tr = batch_to_transition(batch) # Double the reward reward = tr.get(TransitionKey.REWARD, 0.0) new_tr = tr.copy() @@ -262,10 +258,10 @@ def test_custom_converter(): return new_tr def to_batch(tr): - batch = _default_transition_to_batch(tr) + batch = transition_to_batch(tr) return batch - processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) + processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) batch = { "observation.state": torch.randn(1, 4), diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py new file mode 100644 index 000000000..f7cbafd27 --- /dev/null +++ b/tests/processor/test_batch_processor.py @@ -0,0 +1,1184 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import numpy as np +import pytest +import torch + +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition + + +def test_state_1d_to_2d(): + """Test that 1D state tensors get unsqueezed to 2D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.state + state_1d = torch.randn(7) + observation = {OBS_STATE: state_1d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_state = result[TransitionKey.OBSERVATION][OBS_STATE] + assert processed_state.shape == (1, 7) + assert torch.allclose(processed_state.squeeze(0), state_1d) + + +def test_env_state_1d_to_2d(): + """Test that 1D environment state tensors get unsqueezed to 2D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.environment_state + env_state_1d = torch.randn(10) + observation = {OBS_ENV_STATE: env_state_1d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_env_state = result[TransitionKey.OBSERVATION][OBS_ENV_STATE] + assert processed_env_state.shape == (1, 10) + assert torch.allclose(processed_env_state.squeeze(0), env_state_1d) + + +def test_image_3d_to_4d(): + """Test that 3D image tensors get unsqueezed to 4D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.image + image_3d = torch.randn(224, 224, 3) + observation = {OBS_IMAGE: image_3d} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_image = result[TransitionKey.OBSERVATION][OBS_IMAGE] + assert processed_image.shape == (1, 224, 224, 3) + assert torch.allclose(processed_image.squeeze(0), image_3d) + + +def test_multiple_images_3d_to_4d(): + """Test that 3D image tensors in observation.images.* get unsqueezed to 4D.""" + processor = AddBatchDimensionProcessorStep() + + # Test observation.images.camera1 and observation.images.camera2 + image1_3d = torch.randn(64, 64, 3) + image2_3d = torch.randn(128, 128, 3) + observation = { + f"{OBS_IMAGES}.camera1": image1_3d, + f"{OBS_IMAGES}.camera2": image2_3d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + processed_image1 = processed_obs[f"{OBS_IMAGES}.camera1"] + processed_image2 = processed_obs[f"{OBS_IMAGES}.camera2"] + + assert processed_image1.shape == (1, 64, 64, 3) + assert processed_image2.shape == (1, 128, 128, 3) + assert torch.allclose(processed_image1.squeeze(0), image1_3d) + assert torch.allclose(processed_image2.squeeze(0), image2_3d) + + +def test_already_batched_tensors_unchanged(): + """Test that already batched tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + state_2d = torch.randn(1, 7) + env_state_2d = torch.randn(1, 10) + image_4d = torch.randn(1, 224, 224, 3) + + observation = { + OBS_STATE: state_2d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_4d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_2d) + assert torch.allclose(processed_obs[OBS_ENV_STATE], env_state_2d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_4d) + + +def test_higher_dimensional_tensors_unchanged(): + """Test that tensors with more dimensions than expected remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors with more dimensions + state_3d = torch.randn(2, 7, 5) # More than 1D + image_5d = torch.randn(2, 3, 224, 224, 1) # More than 3D + + observation = { + OBS_STATE: state_3d, + OBS_IMAGE: image_5d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], state_3d) + assert torch.allclose(processed_obs[OBS_IMAGE], image_5d) + + +def test_non_tensor_values_unchanged(): + """Test that non-tensor values in observations remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + observation = { + OBS_STATE: [1, 2, 3], # List, not tensor + OBS_IMAGE: "not_a_tensor", # String + "custom_key": 42, # Integer + "another_key": {"nested": "dict"}, # Dict + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + processed_obs = result[TransitionKey.OBSERVATION] + + # Should remain unchanged + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "not_a_tensor" + assert processed_obs["custom_key"] == 42 + assert processed_obs["another_key"] == {"nested": "dict"} + + +def test_none_observation(): + """Test processor handles None observation gracefully.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(observation={}, action=torch.empty(0)) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + + +def test_empty_observation(): + """Test processor handles empty observation dict.""" + processor = AddBatchDimensionProcessorStep() + + observation = {} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + + +def test_mixed_observation(): + """Test processor with mixed observation containing various types and dimensions.""" + processor = AddBatchDimensionProcessorStep() + + state_1d = torch.randn(5) + env_state_2d = torch.randn(1, 8) # Already batched + image_3d = torch.randn(32, 32, 3) + other_tensor = torch.randn(3, 3, 3, 3) # 4D, should be unchanged + + observation = { + OBS_STATE: state_1d, + OBS_ENV_STATE: env_state_2d, + OBS_IMAGE: image_3d, + f"{OBS_IMAGES}.front": torch.randn(64, 64, 3), # 3D, should be batched + f"{OBS_IMAGES}.back": torch.randn(1, 64, 64, 3), # 4D, should be unchanged + "other_tensor": other_tensor, + "non_tensor": "string_value", + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check transformations + assert processed_obs[OBS_STATE].shape == (1, 5) + assert processed_obs[OBS_ENV_STATE].shape == (1, 8) # Unchanged + assert processed_obs[OBS_IMAGE].shape == (1, 32, 32, 3) + assert processed_obs[f"{OBS_IMAGES}.front"].shape == (1, 64, 64, 3) + assert processed_obs[f"{OBS_IMAGES}.back"].shape == (1, 64, 64, 3) # Unchanged + assert processed_obs["other_tensor"].shape == (3, 3, 3, 3) # Unchanged + assert processed_obs["non_tensor"] == "string_value" # Unchanged + + +def test_integration_with_robot_processor(): + """Test AddBatchDimensionProcessorStep integration with RobotProcessor.""" + to_batch_processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [to_batch_processor], to_transition=identity_transition, to_output=identity_transition + ) + + # Create unbatched observation + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(224, 224, 3), + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 224, 224, 3) + + +def test_serialization_methods(): + """Test get_config, state_dict, load_state_dict, and reset methods.""" + processor = AddBatchDimensionProcessorStep() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + assert config == {} + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + assert state == {} + + # Test load_state_dict (should not raise an error) + processor.load_state_dict({}) + + # Test reset (should not raise an error) + processor.reset() + + +def test_save_and_load_pretrained(): + """Test saving and loading AddBatchDimensionProcessorStep with RobotProcessor.""" + processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [processor], name="BatchPipeline", to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check config file exists + config_path = Path(tmp_dir) / "batchpipeline.json" + assert config_path.exists() + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="batchpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert loaded_pipeline.name == "BatchPipeline" + assert len(loaded_pipeline) == 1 + assert isinstance(loaded_pipeline.steps[0], AddBatchDimensionProcessorStep) + + # Test functionality of loaded processor + observation = {OBS_STATE: torch.randn(5)} + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = loaded_pipeline(transition) + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + + +def test_registry_functionality(): + """Test that AddBatchDimensionProcessorStep is properly registered.""" + # Check that the processor is registered + registered_class = ProcessorStepRegistry.get("to_batch_processor") + assert registered_class is AddBatchDimensionProcessorStep + + # Check that it's in the list of registered processors + assert "to_batch_processor" in ProcessorStepRegistry.list() + + +def test_registry_based_save_load(): + """Test saving and loading using registry name.""" + processor = AddBatchDimensionProcessorStep() + pipeline = DataProcessorPipeline( + [processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Verify the loaded processor works + observation = { + OBS_STATE: torch.randn(3), + OBS_IMAGE: torch.randn(100, 100, 3), + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs[OBS_STATE].shape == (1, 3) + assert processed_obs[OBS_IMAGE].shape == (1, 100, 100, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(): + """Test processor works with tensors on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors on GPU + state_1d = torch.randn(7, device="cuda") + image_3d = torch.randn(64, 64, 3, device="cuda") + + observation = { + OBS_STATE: state_1d, + OBS_IMAGE: image_3d, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check shapes and that tensors stayed on GPU + assert processed_obs[OBS_STATE].shape == (1, 7) + assert processed_obs[OBS_IMAGE].shape == (1, 64, 64, 3) + assert processed_obs[OBS_STATE].device.type == "cuda" + assert processed_obs[OBS_IMAGE].device.type == "cuda" + + +def test_processor_preserves_other_transition_keys(): + """Test that processor only modifies observation and preserves other transition keys.""" + processor = AddBatchDimensionProcessorStep() + + action = torch.randn(5) + reward = 1.5 + done = True + truncated = False + info = {"step": 10} + comp_data = {"extra": "data"} + + observation = {OBS_STATE: torch.randn(7)} + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + truncated=truncated, + info=info, + complementary_data=comp_data, + ) + + result = processor(transition) + + # Check that non-observation keys are preserved + assert torch.allclose(result[TransitionKey.ACTION], action) + assert result[TransitionKey.REWARD] == reward + assert result[TransitionKey.DONE] == done + assert result[TransitionKey.TRUNCATED] == truncated + assert result[TransitionKey.INFO] == info + assert result[TransitionKey.COMPLEMENTARY_DATA] == comp_data + + # Check that observation was processed + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + + +def test_edge_case_zero_dimensional_tensors(): + """Test processor handles 0D tensors (scalars) correctly.""" + processor = AddBatchDimensionProcessorStep() + + # 0D tensors should not be modified + scalar_tensor = torch.tensor(42.0) + + observation = { + OBS_STATE: scalar_tensor, + "scalar_value": scalar_tensor, + } + transition = create_transition(observation=observation, action=torch.empty(0)) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # 0D tensors should remain unchanged + assert torch.allclose(processed_obs[OBS_STATE], scalar_tensor) + assert torch.allclose(processed_obs["scalar_value"], scalar_tensor) + + +# Action-specific tests +def test_action_1d_to_2d(): + """Test that 1D action tensors get batch dimension added.""" + processor = AddBatchDimensionProcessorStep() + + # Create 1D action tensor + action_1d = torch.randn(4) + transition = create_transition(observation={}, action=action_1d) + + result = processor(transition) + + # Should add batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action_1d) + + +def test_action_already_batched(): + """Test that already batched action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Test various batch sizes + action_batched_1 = torch.randn(1, 4) + action_batched_5 = torch.randn(5, 4) + + # Single batch + transition = create_transition(action=action_batched_1, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_1) + + # Multiple batch + transition = create_transition(action=action_batched_5, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_batched_5) + + +def test_action_higher_dimensional(): + """Test that higher dimensional action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # 3D action tensor (e.g., sequence of actions) + action_3d = torch.randn(2, 4, 3) + transition = create_transition(action=action_3d, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_3d) + + # 4D action tensor + action_4d = torch.randn(2, 10, 4, 3) + transition = create_transition(action=action_4d, observation={}) + result = processor(transition) + assert torch.equal(result[TransitionKey.ACTION], action_4d) + + +def test_action_scalar_tensor(): + """Test that scalar (0D) action tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + action_scalar = torch.tensor(1.5) + transition = create_transition(action=action_scalar, observation={}) + result = processor(transition) + + # Should remain scalar + assert result[TransitionKey.ACTION].dim() == 0 + assert torch.equal(result[TransitionKey.ACTION], action_scalar) + + +def test_action_non_tensor_raises_error(): + """Test that non-tensor actions raise ValueError for PolicyAction processors.""" + processor = AddBatchDimensionProcessorStep() + + # List action should raise error + action_list = [0.1, 0.2, 0.3, 0.4] + transition = create_transition(action=action_list) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # Numpy array action should raise error + action_numpy = np.array([1, 2, 3, 4]) + transition = create_transition(action=action_numpy) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # String action should raise error + action_string = "forward" + transition = create_transition(action=action_string) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + # Dict action should raise error + action_dict = {"linear": [0.5, 0.0], "angular": 0.2} + transition = create_transition(action=action_dict) + with pytest.raises(ValueError, match="Action should be a PolicyAction type"): + processor(transition) + + +def test_action_none(): + """Test that empty action tensor is handled correctly.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(action=torch.empty(0), observation={}) + result = processor(transition) + # Empty 1D tensor becomes empty 2D tensor with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 0) + + +def test_action_with_observation(): + """Test action processing together with observation processing.""" + processor = AddBatchDimensionProcessorStep() + + # Both need batching + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Both should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_action_different_sizes(): + """Test action processing with various action dimensions.""" + processor = AddBatchDimensionProcessorStep() + + # Different action sizes (robot with different DOF) + action_sizes = [1, 2, 4, 7, 10, 20] + + for size in action_sizes: + action = torch.randn(size) + transition = create_transition(action=action, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, size) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_action_device_compatibility(): + """Test action processing on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # CUDA action + action_cuda = torch.randn(4, device="cuda") + transition = create_transition(action=action_cuda, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cuda" + + # CPU action + action_cpu = torch.randn(4, device="cpu") + transition = create_transition(action=action_cpu, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_action_dtype_preservation(): + """Test that action dtype is preserved during processing.""" + processor = AddBatchDimensionProcessorStep() + + # Different dtypes + dtypes = [torch.float32, torch.float64, torch.int32, torch.int64] + + for dtype in dtypes: + action = torch.randn(4).to(dtype) + transition = create_transition(action=action, observation={}) + result = processor(transition) + + assert result[TransitionKey.ACTION].dtype == dtype + assert result[TransitionKey.ACTION].shape == (1, 4) + + +def test_empty_action_tensor(): + """Test handling of empty action tensors.""" + processor = AddBatchDimensionProcessorStep() + + # Empty 1D tensor + action_empty = torch.tensor([]) + transition = create_transition(action=action_empty, observation={}) + result = processor(transition) + + # Should add batch dimension even to empty tensor + assert result[TransitionKey.ACTION].shape == (1, 0) + + # Empty 2D tensor (already batched) + action_empty_2d = torch.randn(1, 0) + transition = create_transition(action=action_empty_2d, observation={}) + result = processor(transition) + + # Should remain unchanged + assert result[TransitionKey.ACTION].shape == (1, 0) + + +# Task-specific tests +def test_task_string_to_list(): + """Test that string tasks get wrapped in lists to add batch dimension.""" + processor = AddBatchDimensionProcessorStep() + + # Create complementary data with string task + complementary_data = {"task": "pick_cube"} + transition = create_transition( + action=torch.empty(0), observation={}, complementary_data=complementary_data + ) + + result = processor(transition) + + # String task should be wrapped in list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["pick_cube"] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + +def test_task_string_validation(): + """Test that only string and list of strings are valid task values.""" + processor = AddBatchDimensionProcessorStep() + + # Valid string task - should be converted to list + complementary_data = {"task": "valid_task"} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["valid_task"] + + # Valid list of strings - should remain unchanged + complementary_data = {"task": ["task1", "task2"]} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["task1", "task2"] + + +def test_task_list_of_strings(): + """Test that lists of strings remain unchanged (already batched).""" + processor = AddBatchDimensionProcessorStep() + + # Test various list of strings + test_lists = [ + ["pick_cube"], # Single string in list + ["pick_cube", "place_cube"], # Multiple strings + ["task1", "task2", "task3"], # Three strings + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores"], # Mixed formats + ] + + for task_list in test_lists: + complementary_data = {"task": task_list} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged since it's already a list + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + + +def test_complementary_data_none(): + """Test processor handles None complementary_data gracefully.""" + processor = AddBatchDimensionProcessorStep() + + transition = create_transition(complementary_data=None, action=torch.empty(0), observation={}) + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_empty(): + """Test processor handles empty complementary_data dict.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = {} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_no_task(): + """Test processor handles complementary_data without task field.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "episode_id": 123, + "timestamp": 1234567890.0, + "extra_info": "some data", + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data == complementary_data + + +def test_complementary_data_mixed(): + """Test processor with mixed complementary_data containing task and other fields.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "task": "stack_blocks", + "episode_id": 456, + "difficulty": "hard", + "metadata": {"scene": "kitchen"}, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be batched + assert processed_comp_data["task"] == ["stack_blocks"] + + # Other fields should remain unchanged + assert processed_comp_data["episode_id"] == 456 + assert processed_comp_data["difficulty"] == "hard" + assert processed_comp_data["metadata"] == {"scene": "kitchen"} + + +def test_task_with_observation_and_action(): + """Test task processing together with observation and action processing.""" + processor = AddBatchDimensionProcessorStep() + + # All components need batching + observation = { + OBS_STATE: torch.randn(5), + OBS_IMAGE: torch.randn(32, 32, 3), + } + action = torch.randn(4) + complementary_data = {"task": "navigate_to_goal"} + + transition = create_transition( + observation=observation, action=action, complementary_data=complementary_data + ) + + result = processor(transition) + + # All should be batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 32, 32, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["navigate_to_goal"] + + +def test_task_comprehensive_string_cases(): + """Test task processing with comprehensive string cases and edge cases.""" + processor = AddBatchDimensionProcessorStep() + + # Test various string formats + string_tasks = [ + "pick_and_place", + "navigate", + "open_drawer", + "", # Empty string (valid but edge case) + "task with spaces", + "task_with_underscores", + "task-with-dashes", + "UPPERCASE_TASK", + "MixedCaseTask", + "task123", + "数字任务", # Unicode task + "🤖 robot task", # Emoji in task + "task\nwith\nnewlines", # Special characters + "task\twith\ttabs", + "task with 'quotes'", + 'task with "double quotes"', + ] + + # Test that all string tasks get properly batched + for task in string_tasks: + complementary_data = {"task": task} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == [task] + assert isinstance(processed_comp_data["task"], list) + assert len(processed_comp_data["task"]) == 1 + + # Test various list of strings (should remain unchanged) + list_tasks = [ + ["single_task"], + ["task1", "task2"], + ["pick", "place", "navigate"], + [], # Empty list + [""], # List with empty string + ["task with spaces", "task_with_underscores", "UPPERCASE"], + ["🤖 task", "数字任务", "normal_task"], # Mixed formats + ] + + for task_list in list_tasks: + complementary_data = {"task": task_list} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == task_list + assert isinstance(processed_comp_data["task"], list) + + +def test_task_preserves_other_keys(): + """Test that task processing preserves other keys in complementary_data.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "task": "clean_table", + "robot_id": "robot_123", + "motor_id": "motor_456", + "config": {"speed": "slow", "precision": "high"}, + "metrics": [1.0, 2.0, 3.0], + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Task should be processed + assert processed_comp_data["task"] == ["clean_table"] + + # All other keys should be preserved exactly + assert processed_comp_data["robot_id"] == "robot_123" + assert processed_comp_data["motor_id"] == "motor_456" + assert processed_comp_data["config"] == {"speed": "slow", "precision": "high"} + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + + +# Index and task_index specific tests +def test_index_scalar_to_1d(): + """Test that 0D index tensor gets unsqueezed to 1D.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D index tensor (scalar) + index_0d = torch.tensor(42, dtype=torch.int64) + complementary_data = {"index": index_0d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["index"][0] == 42 + + +def test_task_index_scalar_to_1d(): + """Test that 0D task_index tensor gets unsqueezed to 1D.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D task_index tensor (scalar) + task_index_0d = torch.tensor(7, dtype=torch.int64) + complementary_data = {"task_index": task_index_0d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"].dtype == torch.int64 + assert processed_comp_data["task_index"][0] == 7 + + +def test_index_and_task_index_together(): + """Test processing both index and task_index together.""" + processor = AddBatchDimensionProcessorStep() + + # Create 0D tensors for both + index_0d = torch.tensor(100, dtype=torch.int64) + task_index_0d = torch.tensor(3, dtype=torch.int64) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + "task": "pick_object", + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 100 + + # Check task_index + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 3 + + # Check task is also processed + assert processed_comp_data["task"] == ["pick_object"] + + +def test_index_already_batched(): + """Test that already batched index tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + index_1d = torch.tensor([42], dtype=torch.int64) + index_2d = torch.tensor([[42, 43]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"index": index_1d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d) + + # Test 2D + complementary_data = {"index": index_2d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d) + + +def test_task_index_already_batched(): + """Test that already batched task_index tensors remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + # Create already batched tensors + task_index_1d = torch.tensor([7], dtype=torch.int64) + task_index_2d = torch.tensor([[7, 8]], dtype=torch.int64) + + # Test 1D (already batched) + complementary_data = {"task_index": task_index_1d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d) + + # Test 2D + complementary_data = {"task_index": task_index_2d} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + result = processor(transition) + assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d) + + +def test_index_non_tensor_unchanged(): + """Test that non-tensor index values remain unchanged.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = { + "index": 42, # Plain int, not tensor + "task_index": [1, 2, 3], # List, not tensor + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"] == 42 + assert processed_comp_data["task_index"] == [1, 2, 3] + + +def test_index_dtype_preservation(): + """Test that index and task_index dtype is preserved during processing.""" + processor = AddBatchDimensionProcessorStep() + + # Test different dtypes + dtypes = [torch.int32, torch.int64, torch.long] + + for dtype in dtypes: + index_0d = torch.tensor(42, dtype=dtype) + task_index_0d = torch.tensor(7, dtype=dtype) + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].dtype == dtype + assert processed_comp_data["task_index"].dtype == dtype + + +def test_index_with_full_transition(): + """Test index/task_index processing with full transition data.""" + processor = AddBatchDimensionProcessorStep() + + # Create full transition with all components + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(64, 64, 3), + } + action = torch.randn(4) + complementary_data = { + "task": "navigate_to_goal", + "index": torch.tensor(1000, dtype=torch.int64), + "task_index": torch.tensor(5, dtype=torch.int64), + "episode_id": 123, + } + + transition = create_transition( + observation=observation, + action=action, + reward=0.5, + done=False, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components are processed correctly + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 7) + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 64, 64, 3) + assert result[TransitionKey.ACTION].shape == (1, 4) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate_to_goal"] + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["index"][0] == 1000 + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["task_index"][0] == 5 + assert processed_comp_data["episode_id"] == 123 # Non-tensor field unchanged + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_index_device_compatibility(): + """Test processor works with index/task_index tensors on different devices.""" + processor = AddBatchDimensionProcessorStep() + + # Create tensors on GPU + index_0d = torch.tensor(42, dtype=torch.int64, device="cuda") + task_index_0d = torch.tensor(7, dtype=torch.int64, device="cuda") + + complementary_data = { + "index": index_0d, + "task_index": task_index_0d, + } + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check shapes and that tensors stayed on GPU + assert processed_comp_data["index"].shape == (1,) + assert processed_comp_data["task_index"].shape == (1,) + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + +def test_empty_index_tensor(): + """Test handling of empty index tensors.""" + processor = AddBatchDimensionProcessorStep() + + # Empty 0D tensor doesn't make sense, but test empty 1D + index_empty = torch.tensor([], dtype=torch.int64) + complementary_data = {"index": index_empty} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + result = processor(transition) + + # Should remain unchanged (already 1D) + assert result[TransitionKey.COMPLEMENTARY_DATA]["index"].shape == (0,) + + +def test_action_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed action.""" + processor = AddBatchDimensionProcessorStep() + + action = torch.randn(4) + transition = create_transition(action=action, observation={}) + + # Store reference to original transition + original_transition = transition + + # Process + result = processor(transition) + + # Should be a different object (functional design, not in-place mutation) + assert result is not original_transition + # Original transition should remain unchanged + assert original_transition[TransitionKey.ACTION].shape == (4,) + # Result should have correctly processed action with batch dimension + assert result[TransitionKey.ACTION].shape == (1, 4) + assert torch.equal(result[TransitionKey.ACTION][0], action) + + +def test_task_processing_creates_new_transition(): + """Test that the processor creates a new transition object with correctly processed task.""" + processor = AddBatchDimensionProcessorStep() + + complementary_data = {"task": "sort_objects"} + transition = create_transition( + complementary_data=complementary_data, observation={}, action=torch.empty(0) + ) + + # Store reference to original transition and complementary_data + original_transition = transition + original_comp_data = complementary_data + + # Process + result = processor(transition) + + # Should be different transition object (functional design) + assert result is not original_transition + # The task should be processed correctly (wrapped in list) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["sort_objects"] + # Original complementary data is also modified (current behavior) + assert original_comp_data["task"] == "sort_objects" diff --git a/tests/processor/test_classifier_processor.py b/tests/processor/test_classifier_processor.py new file mode 100644 index 000000000..139e99bd7 --- /dev/null +++ b/tests/processor/test_classifier_processor.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Reward Classifier processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import OBS_IMAGE, OBS_STATE +from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor +from lerobot.processor import ( + DataProcessorPipeline, + DeviceProcessorStep, + IdentityProcessorStep, + NormalizerProcessorStep, + TransitionKey, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default Reward Classifier configuration for testing.""" + config = RewardClassifierConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "reward": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), # Classifier output + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, # No normalization for classifier output + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + OBS_IMAGE: {}, # No normalization for images + "reward": {}, # No normalization for classifier output + } + + +def test_make_classifier_processor_basic(): + """Test basic creation of Classifier processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Check processor names + assert preprocessor.name == "classifier_preprocessor" + assert postprocessor.name == "classifier_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 3 + assert isinstance(preprocessor.steps[0], NormalizerProcessorStep) # For input features + assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) # For output features + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], DeviceProcessorStep) + assert isinstance(postprocessor.steps[1], IdentityProcessorStep) + + +def test_classifier_processor_normalization(): + """Test that Classifier processor correctly normalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create test data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) # Dummy action/reward + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is processed + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_cuda(): + """Test Classifier processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that output is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_accelerate_scenario(): + """Test Classifier processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(10).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(1).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_classifier_processor_multi_gpu(): + """Test Classifier processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(10).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(1).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_classifier_processor_without_stats(): + """Test Classifier processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_classifier_processor(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_classifier_processor_save_and_load(): + """Test saving and loading Classifier processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="classifier_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(1) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (10,) + assert processed[OBS_IMAGE].shape == (3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1,) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_classifier_processor_mixed_precision(): + """Test Classifier processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor(config, stats) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(10, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(1, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_classifier_processor_batch_data(): + """Test Classifier processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Test with batched data + batch_size = 16 + observation = { + OBS_STATE: torch.randn(batch_size, 10), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 1) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 1) + + +def test_classifier_processor_postprocessor_identity(): + """Test that Classifier postprocessor uses IdentityProcessor correctly.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_classifier_processor( + config, + stats, + ) + + # Create test data for postprocessor + reward = torch.tensor([[0.8], [0.3], [0.9]]) # Batch of rewards/predictions + transition = create_transition(action=reward) + + _ = transition_to_batch(transition) + + # Process through postprocessor + processed = postprocessor(reward) + + # IdentityProcessor should leave values unchanged (except device) + assert torch.allclose(processed.cpu(), reward.cpu()) + assert processed.device.type == "cpu" diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py new file mode 100644 index 000000000..fc91951de --- /dev/null +++ b/tests/processor/test_converters.py @@ -0,0 +1,292 @@ +import numpy as np +import pytest +import torch + +from lerobot.processor import TransitionKey +from lerobot.processor.converters import ( + batch_to_transition, + create_transition, + to_tensor, + transition_to_batch, +) + + +# Tests for the unified to_tensor function +def test_to_tensor_numpy_arrays(): + """Test to_tensor with various numpy arrays.""" + # Regular numpy array + arr = np.array([1.0, 2.0, 3.0]) + result = to_tensor(arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Different numpy dtypes should convert to float32 by default + int_arr = np.array([1, 2, 3], dtype=np.int64) + result = to_tensor(int_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # uint8 arrays (previously "preserved") should now convert + uint8_arr = np.array([100, 150, 200], dtype=np.uint8) + result = to_tensor(uint8_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0])) + + +def test_to_tensor_numpy_scalars(): + """Test to_tensor with numpy scalars (0-dimensional arrays).""" + # numpy float32 scalar + scalar = np.float32(3.14) + result = to_tensor(scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 # Should be 0-dimensional tensor + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + # numpy int32 scalar + int_scalar = np.int32(42) + result = to_tensor(int_scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + +def test_to_tensor_python_scalars(): + """Test to_tensor with Python scalars.""" + # Python int + result = to_tensor(42) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + # Python float + result = to_tensor(3.14) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + +def test_to_tensor_sequences(): + """Test to_tensor with lists and tuples.""" + # List + result = to_tensor([1, 2, 3]) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Tuple + result = to_tensor((4.5, 5.5, 6.5)) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5])) + + +def test_to_tensor_existing_tensors(): + """Test to_tensor with existing PyTorch tensors.""" + # Tensor with same dtype should pass through with potential device change + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = to_tensor(tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, tensor) + + # Tensor with different dtype should convert + int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64) + result = to_tensor(int_tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dictionaries(): + """Test to_tensor with nested dictionaries.""" + # Simple dictionary + data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42} + result = to_tensor(data) + assert isinstance(result, dict) + assert isinstance(result["mean"], torch.Tensor) + assert isinstance(result["std"], torch.Tensor) + assert isinstance(result["count"], torch.Tensor) + assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["std"], torch.tensor([1.0, 2.0])) + assert result["count"].item() == pytest.approx(42.0) + + # Nested dictionary + nested = { + "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, + "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + } + result = to_tensor(nested) + assert isinstance(result, dict) + assert isinstance(result["action"], dict) + assert isinstance(result["observation"], dict) + assert isinstance(result["action"]["mean"], torch.Tensor) + assert isinstance(result["observation"]["mean"], torch.Tensor) + assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + + +def test_to_tensor_none_filtering(): + """Test that None values are filtered out from dictionaries.""" + data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}} + result = to_tensor(data) + assert "none_value" not in result + assert "also_none" not in result["nested"] + assert "valid" in result + assert "valid" in result["nested"] + assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dtype_parameter(): + """Test to_tensor with different dtype parameters.""" + arr = np.array([1, 2, 3]) + + # Default dtype (float32) + result = to_tensor(arr) + assert result.dtype == torch.float32 + + # Explicit float32 + result = to_tensor(arr, dtype=torch.float32) + assert result.dtype == torch.float32 + + # Float64 + result = to_tensor(arr, dtype=torch.float64) + assert result.dtype == torch.float64 + + # Preserve original dtype + float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + result = to_tensor(float64_arr, dtype=None) + assert result.dtype == torch.float64 + + +def test_to_tensor_device_parameter(): + """Test to_tensor with device parameter.""" + arr = np.array([1.0, 2.0, 3.0]) + + # CPU device (default) + result = to_tensor(arr, device="cpu") + assert result.device.type == "cpu" + + # CUDA device (if available) + if torch.cuda.is_available(): + result = to_tensor(arr, device="cuda") + assert result.device.type == "cuda" + + +def test_to_tensor_empty_dict(): + """Test to_tensor with empty dictionary.""" + result = to_tensor({}) + assert isinstance(result, dict) + assert len(result) == 0 + + +def test_to_tensor_unsupported_type(): + """Test to_tensor with unsupported types raises TypeError.""" + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor("unsupported_string") + + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor(object()) + + +def test_batch_to_transition_with_index_fields(): + """Test that batch_to_transition handles index and task_index fields correctly.""" + + # Create batch with index and task_index fields + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "next.reward": 1.5, + "next.done": False, + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + } + + transition = batch_to_transition(batch) + + # Check basic transition structure + assert TransitionKey.OBSERVATION in transition + assert TransitionKey.ACTION in transition + assert TransitionKey.COMPLEMENTARY_DATA in transition + + # Check that index and task_index are in complementary_data + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + assert "index" in comp_data + assert "task_index" in comp_data + assert "task" in comp_data + + # Verify values + assert torch.equal(comp_data["index"], batch["index"]) + assert torch.equal(comp_data["task_index"], batch["task_index"]) + assert comp_data["task"] == batch["task"] + + +def testtransition_to_batch_with_index_fields(): + """Test that transition_to_batch handles index and task_index fields correctly.""" + + # Create transition with index and task_index in complementary_data + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + reward=1.5, + done=False, + complementary_data={ + "task": ["navigate"], + "index": torch.tensor([100], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + }, + ) + + batch = transition_to_batch(transition) + + # Check that index and task_index are in the batch + assert "index" in batch + assert "task_index" in batch + assert "task" in batch + + # Verify values + assert torch.equal(batch["index"], transition[TransitionKey.COMPLEMENTARY_DATA]["index"]) + assert torch.equal(batch["task_index"], transition[TransitionKey.COMPLEMENTARY_DATA]["task_index"]) + assert batch["task"] == transition[TransitionKey.COMPLEMENTARY_DATA]["task"] + + +def test_batch_to_transition_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Batch without index/task_index + batch = { + "observation.state": torch.randn(1, 7), + "action": torch.randn(1, 4), + "task": ["pick_cube"], + } + + transition = batch_to_transition(batch) + comp_data = transition[TransitionKey.COMPLEMENTARY_DATA] + + # Should have task but not index/task_index + assert "task" in comp_data + assert "index" not in comp_data + assert "task_index" not in comp_data + + +def test_transition_to_batch_without_index_fields(): + """Test that conversion works without index and task_index fields.""" + + # Transition without index/task_index + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data={"task": ["navigate"]}, + ) + + batch = transition_to_batch(transition) + + # Should have task but not index/task_index + assert "task" in batch + assert "index" not in batch + assert "task_index" not in batch diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py new file mode 100644 index 000000000..ba00bde4d --- /dev/null +++ b/tests/processor/test_device_processor.py @@ -0,0 +1,1161 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey +from lerobot.processor.converters import create_transition, identity_transition + + +def test_basic_functionality(): + """Test basic device processor functionality on CPU.""" + processor = DeviceProcessorStep(device="cpu") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CPU + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.ACTION].device.type == "cpu" + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_cuda_functionality(): + """Test device processor functionality on CUDA.""" + processor = DeviceProcessorStep(device="cuda") + + # Create a transition with CPU tensors + observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + action = torch.randn(5) + reward = torch.tensor(1.0) + done = torch.tensor(False) + truncated = torch.tensor(False) + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + assert result[TransitionKey.TRUNCATED].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_specific_cuda_device(): + """Test device processor with specific CUDA device.""" + processor = DeviceProcessorStep(device="cuda:0") + + observation = {"observation.state": torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].device.index == 0 + + +def test_non_tensor_values(): + """Test that non-tensor values are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.state": torch.randn(10), + "observation.metadata": {"key": "value"}, # Non-tensor data + "observation.list": [1, 2, 3], # Non-tensor data + } + action = torch.randn(5) + info = {"episode": 1, "step": 42} + + transition = create_transition(observation=observation, action=action, info=info) + + result = processor(transition) + + # Check tensors are processed + assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.ACTION], torch.Tensor) + + # Check non-tensor values are preserved + assert result[TransitionKey.OBSERVATION]["observation.metadata"] == {"key": "value"} + assert result[TransitionKey.OBSERVATION]["observation.list"] == [1, 2, 3] + assert result[TransitionKey.INFO] == {"episode": 1, "step": 42} + + +def test_none_values(): + """Test handling of None values.""" + processor = DeviceProcessorStep(device="cpu") + + # Test with None observation + transition = create_transition(observation=None, action=torch.randn(5)) + result = processor(transition) + assert result[TransitionKey.OBSERVATION] is None + assert result[TransitionKey.ACTION].device.type == "cpu" + + # Test with None action + transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + result = processor(transition) + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.ACTION] is None + + +def test_empty_observation(): + """Test handling of empty observation dictionary.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition(observation={}, action=torch.randn(5)) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION] == {} + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_scalar_tensors(): + """Test handling of scalar tensors.""" + processor = DeviceProcessorStep(device="cpu") + + observation = {"observation.scalar": torch.tensor(1.5)} + action = torch.tensor(2.0) + reward = torch.tensor(0.5) + + transition = create_transition(observation=observation, action=action, reward=reward) + + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.scalar"].item() == 1.5 + assert result[TransitionKey.ACTION].item() == 2.0 + assert result[TransitionKey.REWARD].item() == 0.5 + + +def test_dtype_preservation(): + """Test that tensor dtypes are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float16) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_shape_preservation(): + """Test that tensor shapes are preserved.""" + processor = DeviceProcessorStep(device="cpu") + + observation = { + "observation.1d": torch.randn(10), + "observation.2d": torch.randn(5, 10), + "observation.3d": torch.randn(3, 224, 224), + "observation.4d": torch.randn(2, 3, 224, 224), + } + action = torch.randn(2, 5, 3) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.1d"].shape == (10,) + assert result[TransitionKey.OBSERVATION]["observation.2d"].shape == (5, 10) + assert result[TransitionKey.OBSERVATION]["observation.3d"].shape == (3, 224, 224) + assert result[TransitionKey.OBSERVATION]["observation.4d"].shape == (2, 3, 224, 224) + assert result[TransitionKey.ACTION].shape == (2, 5, 3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_mixed_devices(): + """Test handling of tensors already on different devices.""" + processor = DeviceProcessorStep(device="cuda") + + # Create tensors on different devices + observation = { + "observation.cpu": torch.randn(5), # CPU + "observation.cuda": torch.randn(5).cuda(), # Already on CUDA + } + action = torch.randn(3).cuda() # Already on CUDA + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # All should be on CUDA + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.cuda"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_non_blocking_flag(): + """Test that non_blocking flag is set correctly.""" + # CPU processor should have non_blocking=False + cpu_processor = DeviceProcessorStep(device="cpu") + assert cpu_processor.non_blocking is False + + if torch.cuda.is_available(): + # CUDA processor should have non_blocking=True + cuda_processor = DeviceProcessorStep(device="cuda") + assert cuda_processor.non_blocking is True + + cuda_0_processor = DeviceProcessorStep(device="cuda:0") + assert cuda_0_processor.non_blocking is True + + +def test_serialization_methods(): + """Test get_config, state_dict, and load_state_dict methods.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device) + + # Test get_config + config = processor.get_config() + assert config == {"device": device, "float_dtype": None} + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should be no-op) + processor.load_state_dict({}) + assert processor.device == device + + # Test reset (should be no-op) + processor.reset() + assert processor.device == device + + +def test_features(): + """Test that features returns features unchanged.""" + processor = DeviceProcessorStep(device="cpu") + + features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) + }, + PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + } + + result = processor.transform_features(features) + assert result == features + assert result is features # Should return the same object + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor.""" + from lerobot.constants import OBS_STATE + from lerobot.processor import AddBatchDimensionProcessorStep + + # Create a pipeline with DeviceProcessorStep + device_processor = DeviceProcessorStep(device="cpu") + batch_processor = AddBatchDimensionProcessorStep() + + processor = DataProcessorPipeline( + steps=[batch_processor, device_processor], + name="test_pipeline", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are batched and on correct device + # The result has TransitionKey.OBSERVATION as the key, with observation.state inside + assert result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" + assert result[TransitionKey.ACTION].shape[0] == 1 # Batched + assert result[TransitionKey.ACTION].device.type == "cpu" + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with DeviceProcessorStep.""" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device, float_dtype="float16") + robot_processor = DataProcessorPipeline(steps=[processor], name="device_test_processor") + + with tempfile.TemporaryDirectory() as tmpdir: + # Save + robot_processor.save_pretrained(tmpdir) + + # Load + loaded_processor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="device_test_processor.json" + ) + + assert len(loaded_processor.steps) == 1 + loaded_device_processor = loaded_processor.steps[0] + assert isinstance(loaded_device_processor, DeviceProcessorStep) + # Use getattr to access attributes safely + assert ( + getattr(loaded_device_processor, "device", None) == device.split(":")[0] + ) # Device normalizes cuda:0 to cuda + assert getattr(loaded_device_processor, "float_dtype", None) == "float16" + + +def test_registry_functionality(): + """Test that DeviceProcessorStep is properly registered.""" + from lerobot.processor import ProcessorStepRegistry + + # Check that DeviceProcessorStep is registered + registered_class = ProcessorStepRegistry.get("device_processor") + assert registered_class is DeviceProcessorStep + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_performance_with_large_tensors(): + """Test performance with large tensors and non_blocking flag.""" + processor = DeviceProcessorStep(device="cuda") + + # Create large tensors + observation = { + "observation.large_image": torch.randn(10, 3, 512, 512), # Large image batch + "observation.features": torch.randn(10, 2048), # Large feature vector + } + action = torch.randn(10, 100) # Large action space + + transition = create_transition(observation=observation, action=action) + + # Process should not raise any errors + result = processor(transition) + + # Verify all tensors are on CUDA + assert result[TransitionKey.OBSERVATION]["observation.large_image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.features"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + +def test_reward_done_truncated_types(): + """Test handling of different types for reward, done, and truncated.""" + processor = DeviceProcessorStep(device="cpu") + + # Test with scalar values (not tensors) + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=1.0, # float + done=False, # bool + truncated=True, # bool + ) + + result = processor(transition) + + # Non-tensor values should be preserved as-is + assert result[TransitionKey.REWARD] == 1.0 + assert result[TransitionKey.DONE] is False + assert result[TransitionKey.TRUNCATED] is True + + # Test with tensor values + transition = create_transition( + observation={"observation.state": torch.randn(5)}, + action=torch.randn(3), + reward=torch.tensor(1.0), + done=torch.tensor(False), + truncated=torch.tensor(True), + ) + + result = processor(transition) + + # Tensor values should be moved to device + assert isinstance(result[TransitionKey.REWARD], torch.Tensor) + assert isinstance(result[TransitionKey.DONE], torch.Tensor) + assert isinstance(result[TransitionKey.TRUNCATED], torch.Tensor) + assert result[TransitionKey.REWARD].device.type == "cpu" + assert result[TransitionKey.DONE].device.type == "cpu" + assert result[TransitionKey.TRUNCATED].device.type == "cpu" + + +def test_complementary_data_preserved(): + """Test that complementary_data is preserved unchanged.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": "pick_object", + "episode_id": 42, + "metadata": {"sensor": "camera_1"}, + "observation_is_pad": torch.tensor([False, False, True]), # This should be moved to device + } + + transition = create_transition( + observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + ) + + result = processor(transition) + + # Check that complementary_data is preserved + assert TransitionKey.COMPLEMENTARY_DATA in result + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick_object" + assert result[TransitionKey.COMPLEMENTARY_DATA]["episode_id"] == 42 + assert result[TransitionKey.COMPLEMENTARY_DATA]["metadata"] == {"sensor": "camera_1"} + # Note: Currently DeviceProcessorStep doesn't process tensors in complementary_data + # This is intentional as complementary_data is typically metadata + + +def test_float_dtype_conversion(): + """Test float dtype conversion functionality.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float16") + + # Create tensors of different types + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + "observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64), + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), + } + action = torch.randn(3, dtype=torch.float32) + reward = torch.tensor(1.0, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action, reward=reward) + result = processor(transition) + + # Check that float tensors are converted to float16 + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check that non-float tensors are preserved + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + + +def test_float_dtype_none(): + """Test that when float_dtype is None, no dtype conversion occurs.""" + processor = DeviceProcessorStep(device="cpu", float_dtype=None) + + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.float64": torch.randn(5, dtype=torch.float64), + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that dtypes are preserved when float_dtype is None + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_bfloat16(): + """Test conversion to bfloat16.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") + + observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.ACTION].dtype == torch.bfloat16 + + +def test_float_dtype_float64(): + """Test conversion to float64.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float64") + + observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + action = torch.randn(3, dtype=torch.float32) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.ACTION].dtype == torch.float64 + + +def test_float_dtype_invalid(): + """Test that invalid float_dtype raises ValueError.""" + with pytest.raises(ValueError, match="Invalid float_dtype 'invalid_dtype'"): + DeviceProcessorStep(device="cpu", float_dtype="invalid_dtype") + + +def test_float_dtype_aliases(): + """Test that dtype aliases work correctly.""" + # Test 'half' alias for float16 + processor_half = DeviceProcessorStep(device="cpu", float_dtype="half") + assert processor_half._target_float_dtype == torch.float16 + + # Test 'float' alias for float32 + processor_float = DeviceProcessorStep(device="cpu", float_dtype="float") + assert processor_float._target_float_dtype == torch.float32 + + # Test 'double' alias for float64 + processor_double = DeviceProcessorStep(device="cpu", float_dtype="double") + assert processor_double._target_float_dtype == torch.float64 + + +def test_float_dtype_with_mixed_tensors(): + """Test float dtype conversion with mixed tensor types.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float32") + + observation = { + "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert + "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert + } + action = torch.randn(5, dtype=torch.float16) # Should convert + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check conversions + assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged + assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged + assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted + + +def test_float_dtype_serialization(): + """Test that float_dtype is properly serialized in get_config.""" + device = "cuda" if torch.cuda.is_available() else "cpu" + processor = DeviceProcessorStep(device=device, float_dtype="float16") + config = processor.get_config() + + assert config == {"device": device, "float_dtype": "float16"} + + # Test with None float_dtype + processor_none = DeviceProcessorStep(device="cpu", float_dtype=None) + config_none = processor_none.get_config() + + assert config_none == {"device": "cpu", "float_dtype": None} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_float_dtype_with_cuda(): + """Test float dtype conversion combined with CUDA device.""" + processor = DeviceProcessorStep(device="cuda", float_dtype="float16") + + # Create tensors on CPU with different dtypes + observation = { + "observation.float32": torch.randn(5, dtype=torch.float32), + "observation.int64": torch.tensor([1, 2, 3], dtype=torch.int64), + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors are on CUDA and float types are converted + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + + assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 # Unchanged + + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].dtype == torch.float16 + + +def test_complementary_data_index_fields(): + """Test processing of index and task_index fields in complementary_data.""" + processor = DeviceProcessorStep(device="cpu") + + # Create transition with index and task_index in complementary_data + complementary_data = { + "task": ["pick_cube"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "episode_id": 123, # Non-tensor field + } + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + action=torch.randn(1, 4), + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check that tensors in complementary_data are processed + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check index tensor + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert processed_comp_data["index"].device.type == "cpu" + assert torch.equal(processed_comp_data["index"], complementary_data["index"]) + + # Check task_index tensor + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + assert processed_comp_data["task_index"].device.type == "cpu" + assert torch.equal(processed_comp_data["task_index"], complementary_data["task_index"]) + + # Check non-tensor fields remain unchanged + assert processed_comp_data["task"] == ["pick_cube"] + assert processed_comp_data["episode_id"] == 123 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_index_fields_cuda(): + """Test moving index and task_index fields to CUDA.""" + processor = DeviceProcessorStep(device="cuda:0") + + # Create CPU tensors + complementary_data = { + "index": torch.tensor([100, 101], dtype=torch.int64), + "task_index": torch.tensor([5], dtype=torch.int64), + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors moved to CUDA + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["index"].device.index == 0 + assert processed_comp_data["task_index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.index == 0 + + +def test_complementary_data_without_index_fields(): + """Test that complementary_data without index/task_index fields works correctly.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": ["navigate"], + "episode_id": 456, + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + # Should process without errors and preserve non-tensor fields + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["task"] == ["navigate"] + assert processed_comp_data["episode_id"] == 456 + + +def test_complementary_data_mixed_tensors(): + """Test complementary_data with mix of tensors and non-tensors.""" + processor = DeviceProcessorStep(device="cpu") + + complementary_data = { + "task": ["pick_and_place"], + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "metrics": [1.0, 2.0, 3.0], # List, not tensor + "config": {"speed": "fast"}, # Dict + "episode_id": 789, # Int + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check tensors are processed + assert isinstance(processed_comp_data["index"], torch.Tensor) + assert isinstance(processed_comp_data["task_index"], torch.Tensor) + + # Check non-tensors remain unchanged + assert processed_comp_data["task"] == ["pick_and_place"] + assert processed_comp_data["metrics"] == [1.0, 2.0, 3.0] + assert processed_comp_data["config"] == {"speed": "fast"} + assert processed_comp_data["episode_id"] == 789 + + +def test_complementary_data_float_dtype_conversion(): + """Test that float dtype conversion doesn't affect int tensors in complementary_data.""" + processor = DeviceProcessorStep(device="cpu", float_dtype="float16") + + complementary_data = { + "index": torch.tensor([42], dtype=torch.int64), + "task_index": torch.tensor([3], dtype=torch.int64), + "float_tensor": torch.tensor([1.5, 2.5], dtype=torch.float32), # Should be converted + } + transition = create_transition(complementary_data=complementary_data) + + result = processor(transition) + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Int tensors should keep their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + # Float tensor should be converted + assert processed_comp_data["float_tensor"].dtype == torch.float16 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_complementary_data_full_pipeline_cuda(): + """Test full transition with complementary_data on CUDA.""" + processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") + + # Create full transition with mixed CPU tensors + observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + action = torch.randn(1, 4, dtype=torch.float32) + reward = torch.tensor(1.5, dtype=torch.float32) + done = torch.tensor(False) + complementary_data = { + "task": ["reach_target"], + "index": torch.tensor([1000], dtype=torch.int64), + "task_index": torch.tensor([10], dtype=torch.int64), + } + + transition = create_transition( + observation=observation, + action=action, + reward=reward, + done=done, + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check all components moved to CUDA + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.REWARD].device.type == "cuda" + assert result[TransitionKey.DONE].device.type == "cuda" + + # Check complementary_data tensors + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].device.type == "cuda" + assert processed_comp_data["task_index"].device.type == "cuda" + + # Check float conversion happened for float tensors + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + assert result[TransitionKey.REWARD].dtype == torch.float16 + + # Check int tensors kept their dtype + assert processed_comp_data["index"].dtype == torch.int64 + assert processed_comp_data["task_index"].dtype == torch.int64 + + +def test_complementary_data_empty(): + """Test empty complementary_data handling.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data={}, + ) + + result = processor(transition) + + # Should have empty dict + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_complementary_data_none(): + """Test None complementary_data handling.""" + processor = DeviceProcessorStep(device="cpu") + + transition = create_transition( + observation={"observation.state": torch.randn(1, 7)}, + complementary_data=None, + ) + + result = processor(transition) + + # Complementary data should not be in the result (same as input) + assert result[TransitionKey.COMPLEMENTARY_DATA] == {} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_preserves_gpu_placement(): + """Test that DeviceProcessorStep preserves GPU placement when tensor is already on GPU.""" + processor = DeviceProcessorStep(device="cuda:0") + + # Create tensors already on GPU + observation = { + "observation.state": torch.randn(10).cuda(), # Already on GPU + "observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU + } + action = torch.randn(5).cuda() # Already on GPU + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check that tensors remain on their original GPU + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + # Verify no unnecessary copies were made (same data pointer) + assert torch.equal( + result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"] + ) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_preservation(): + """Test that DeviceProcessorStep preserves placement on different GPUs in multi-GPU setup.""" + # Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input) + processor_gpu = DeviceProcessorStep(device="cuda:0") + + # Create tensors on cuda:1 (simulating Accelerate placement) + cuda1_device = torch.device("cuda:1") + observation = { + "observation.state": torch.randn(10).to(cuda1_device), + "observation.image": torch.randn(3, 224, 224).to(cuda1_device), + } + action = torch.randn(5).to(cuda1_device) + + transition = create_transition(observation=observation, action=action) + result = processor_gpu(transition) + + # Check that tensors remain on cuda:1 (not moved to cuda:0) + assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device + assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device + assert result[TransitionKey.ACTION].device == cuda1_device + + # Test 2: GPU-to-CPU should move to CPU (not preserve GPU) + processor_cpu = DeviceProcessorStep(device="cpu") + + transition_gpu = create_transition( + observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() + ) + result_cpu = processor_cpu(transition_gpu) + + # Check that tensors are moved to CPU + assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result_cpu[TransitionKey.ACTION].device.type == "cpu" + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_with_cpu_tensors(): + """Test that CPU tensors are moved to configured device even in multi-GPU context.""" + # Processor configured for cuda:1 + processor = DeviceProcessorStep(device="cuda:1") + + # Mix of CPU and GPU tensors + observation = { + "observation.cpu": torch.randn(10), # CPU tensor + "observation.gpu0": torch.randn(10).cuda(0), # Already on cuda:0 + "observation.gpu1": torch.randn(10).cuda(1), # Already on cuda:1 + } + action = torch.randn(5) # CPU tensor + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # CPU tensor should move to configured device (cuda:1) + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 1 + assert result[TransitionKey.ACTION].device.type == "cuda" + assert result[TransitionKey.ACTION].device.index == 1 + + # GPU tensors should stay on their original devices + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1 + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_multi_gpu_with_float_dtype(): + """Test float dtype conversion works correctly with multi-GPU preservation.""" + processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") + + # Create float tensors on different GPUs + observation = { + "observation.gpu0": torch.randn(5, dtype=torch.float32).cuda(0), + "observation.gpu1": torch.randn(5, dtype=torch.float32).cuda(1), + "observation.cpu": torch.randn(5, dtype=torch.float32), # CPU + } + + transition = create_transition(observation=observation) + result = processor(transition) + + # Check device placement + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1 + assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 0 # Moved to cuda:0 + + # Check dtype conversion happened for all + assert result[TransitionKey.OBSERVATION]["observation.gpu0"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.gpu1"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.cpu"].dtype == torch.float16 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_simulated_accelerate_scenario(): + """Test a scenario simulating how Accelerate would use the processor.""" + # Simulate different processes getting different GPU assignments + for gpu_id in range(min(torch.cuda.device_count(), 2)): + # Each "process" has a processor configured for cuda:0 + # but data comes in already placed on the process's GPU + processor = DeviceProcessorStep(device="cuda:0") + + # Simulate data already placed by Accelerate + device = torch.device(f"cuda:{gpu_id}") + observation = {"observation.state": torch.randn(1, 10).to(device)} + action = torch.randn(1, 5).to(device) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Verify data stays on the GPU where Accelerate placed it + assert result[TransitionKey.OBSERVATION]["observation.state"].device == device + assert result[TransitionKey.ACTION].device == device + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_policy_processor_integration(): + """Test integration with policy processors - input on GPU, output on CPU.""" + from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + from lerobot.constants import ACTION, OBS_STATE + from lerobot.processor import ( + AddBatchDimensionProcessorStep, + NormalizerProcessorStep, + UnnormalizerProcessorStep, + ) + + # Create features and stats + features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + + stats = { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + ACTION: {"mean": torch.zeros(5), "std": torch.ones(5)}, + } + + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD} + + # Create input processor (preprocessor) that moves to GPU + input_processor = DataProcessorPipeline( + steps=[ + NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats), + AddBatchDimensionProcessorStep(), + DeviceProcessorStep(device="cuda"), + ], + name="test_preprocessor", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Create output processor (postprocessor) that moves to CPU + output_processor = DataProcessorPipeline( + steps=[ + DeviceProcessorStep(device="cpu"), + UnnormalizerProcessorStep(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats), + ], + name="test_postprocessor", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test data on CPU + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation=observation, action=action) + + # Process through input processor + input_result = input_processor(transition) + + # Verify tensors are on GPU and batched + # The result has TransitionKey.OBSERVATION as the key, with observation.state inside + assert input_result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert input_result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 + assert input_result[TransitionKey.ACTION].device.type == "cuda" + assert input_result[TransitionKey.ACTION].shape[0] == 1 + + # Simulate model output on GPU + model_output = create_transition(action=torch.randn(1, 5).cuda()) + + # Process through output processor + output_result = output_processor(model_output) + + # Verify action is back on CPU and unnormalized + assert output_result[TransitionKey.ACTION].device.type == "cpu" + assert output_result[TransitionKey.ACTION].shape == (1, 5) + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_float64_compatibility(): + """Test MPS device compatibility with float64 tensors (automatic conversion to float32).""" + processor = DeviceProcessorStep(device="mps") + + # Create tensors with different dtypes, including float64 which MPS doesn't support + observation = { + "observation.float64": torch.randn(5, dtype=torch.float64), # Should be converted to float32 + "observation.float32": torch.randn(5, dtype=torch.float32), # Should remain float32 + "observation.float16": torch.randn(5, dtype=torch.float16), # Should remain float16 + "observation.int64": torch.randint(0, 10, (5,), dtype=torch.int64), # Should remain int64 + "observation.bool": torch.tensor([True, False, True], dtype=torch.bool), # Should remain bool + } + action = torch.randn(3, dtype=torch.float64) # Should be converted to float32 + reward = torch.tensor(1.0, dtype=torch.float64) # Should be converted to float32 + done = torch.tensor(False, dtype=torch.bool) # Should remain bool + truncated = torch.tensor(True, dtype=torch.bool) # Should remain bool + + transition = create_transition( + observation=observation, action=action, reward=reward, done=done, truncated=truncated + ) + + result = processor(transition) + + # Check that all tensors are on MPS device + assert result[TransitionKey.OBSERVATION]["observation.float64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float16"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.int64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.bool"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + assert result[TransitionKey.REWARD].device.type == "mps" + assert result[TransitionKey.DONE].device.type == "mps" + assert result[TransitionKey.TRUNCATED].device.type == "mps" + + # Check that float64 tensors were automatically converted to float32 + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float32 + assert result[TransitionKey.ACTION].dtype == torch.float32 + assert result[TransitionKey.REWARD].dtype == torch.float32 + + # Check that other dtypes were preserved + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float32 + assert result[TransitionKey.OBSERVATION]["observation.float16"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.int64"].dtype == torch.int64 + assert result[TransitionKey.OBSERVATION]["observation.bool"].dtype == torch.bool + assert result[TransitionKey.DONE].dtype == torch.bool + assert result[TransitionKey.TRUNCATED].dtype == torch.bool + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_float64_with_complementary_data(): + """Test MPS float64 conversion with complementary_data tensors.""" + processor = DeviceProcessorStep(device="mps") + + # Create complementary_data with float64 tensors + complementary_data = { + "task": ["pick_object"], + "index": torch.tensor([42], dtype=torch.int64), # Should remain int64 + "task_index": torch.tensor([3], dtype=torch.int64), # Should remain int64 + "float64_tensor": torch.tensor([1.5, 2.5], dtype=torch.float64), # Should convert to float32 + "float32_tensor": torch.tensor([3.5], dtype=torch.float32), # Should remain float32 + } + + transition = create_transition( + observation={"observation.state": torch.randn(5, dtype=torch.float64)}, + action=torch.randn(3, dtype=torch.float64), + complementary_data=complementary_data, + ) + + result = processor(transition) + + # Check that all tensors are on MPS device + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + + processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert processed_comp_data["index"].device.type == "mps" + assert processed_comp_data["task_index"].device.type == "mps" + assert processed_comp_data["float64_tensor"].device.type == "mps" + assert processed_comp_data["float32_tensor"].device.type == "mps" + + # Check dtype conversions + assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted + assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted + assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged + assert processed_comp_data["index"].dtype == torch.int64 # Unchanged + assert processed_comp_data["task_index"].dtype == torch.int64 # Unchanged + + # Check non-tensor data preserved + assert processed_comp_data["task"] == ["pick_object"] + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_with_explicit_float_dtype(): + """Test MPS device with explicit float_dtype setting.""" + # Test that explicit float_dtype still works on MPS + processor = DeviceProcessorStep(device="mps", float_dtype="float16") + + observation = { + "observation.float64": torch.randn( + 5, dtype=torch.float64 + ), # First converted to float32, then to float16 + "observation.float32": torch.randn(5, dtype=torch.float32), # Converted to float16 + "observation.int32": torch.randint(0, 10, (5,), dtype=torch.int32), # Should remain int32 + } + action = torch.randn(3, dtype=torch.float64) + + transition = create_transition(observation=observation, action=action) + result = processor(transition) + + # Check device placement + assert result[TransitionKey.OBSERVATION]["observation.float64"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.float32"].device.type == "mps" + assert result[TransitionKey.OBSERVATION]["observation.int32"].device.type == "mps" + assert result[TransitionKey.ACTION].device.type == "mps" + + # Check that all float tensors end up as float16 (the target dtype) + assert result[TransitionKey.OBSERVATION]["observation.float64"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION]["observation.float32"].dtype == torch.float16 + assert result[TransitionKey.ACTION].dtype == torch.float16 + + # Check that non-float tensors are preserved + assert result[TransitionKey.OBSERVATION]["observation.int32"].dtype == torch.int32 + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_mps_serialization(): + """Test that MPS device processor can be serialized and loaded correctly.""" + processor = DeviceProcessorStep(device="mps", float_dtype="float32") + + # Test get_config + config = processor.get_config() + assert config == {"device": "mps", "float_dtype": "float32"} + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should be no-op) + processor.load_state_dict({}) + assert processor.device == "mps" diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py new file mode 100644 index 000000000..5d280f9cc --- /dev/null +++ b/tests/processor/test_diffusion_processor.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Diffusion policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default Diffusion configuration for testing.""" + config = DiffusionConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(7,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(7), "std": torch.ones(7)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_diffusion_processor_basic(): + """Test basic creation of Diffusion processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_diffusion_processor_with_images(): + """Test Diffusion processor with image observations.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create test data with images + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is batched + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_cuda(): + """Test Diffusion processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_accelerate_scenario(): + """Test Diffusion processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 7).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_diffusion_processor_multi_gpu(): + """Test Diffusion processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 7).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_diffusion_processor_without_stats(): + """Test Diffusion processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_diffusion_processor_save_and_load(): + """Test saving and loading Diffusion processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 7) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +def test_diffusion_processor_identity_normalization(): + """Test that images with IDENTITY normalization are not normalized.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Create test data + image_value = torch.rand(3, 224, 224) * 255 # Large values + observation = { + OBS_STATE: torch.randn(7), + OBS_IMAGE: image_value.clone(), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Image should not be normalized (IDENTITY mode) + # Just batched + assert torch.allclose(processed[OBS_IMAGE][0], image_value, rtol=1e-5) + + +def test_diffusion_processor_batch_consistency(): + """Test Diffusion processor with different batch sizes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_diffusion_pre_post_processors( + config, + stats, + ) + + # Test with different batch sizes + for batch_size in [1, 8, 32]: + observation = { + OBS_STATE: torch.randn(batch_size, 7) if batch_size > 1 else torch.randn(7), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224) if batch_size > 1 else torch.randn(3, 224, 224), + } + action = torch.randn(batch_size, 6) if batch_size > 1 else torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + + # Check correct batch size + expected_batch = batch_size if batch_size > 1 else 1 + assert processed[OBS_STATE].shape[0] == expected_batch + assert processed[OBS_IMAGE].shape[0] == expected_batch + assert processed[TransitionKey.ACTION.value].shape[0] == expected_batch + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_diffusion_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_diffusion_pre_post_processors(config, stats) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(7, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py new file mode 100644 index 000000000..6bed8289d --- /dev/null +++ b/tests/processor/test_migration_detection.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for processor migration detection functionality. +""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError + + +def test_is_processor_config_valid_configs(): + """Test processor config detection with valid configurations.""" + valid_configs = [ + {"steps": []}, # Empty steps + {"steps": [{"class": "MyClass"}]}, # Class-based step + {"steps": [{"registry_name": "my_step"}]}, # Registry-based step + {"steps": [{"class": "A"}, {"registry_name": "B"}]}, # Mixed + {"name": "Test", "steps": [{"class": "MyClass"}]}, # With name + ] + + for i, config in enumerate(valid_configs): + assert DataProcessorPipeline._is_processor_config(config), ( + f"Valid config {i} should be detected as processor config: {config}" + ) + + +def test_is_processor_config_invalid_configs(): + """Test processor config detection with invalid configurations.""" + invalid_configs = [ + {}, # No steps field + {"steps": "not a list"}, # Steps is not a list + {"steps": [{}]}, # Step without class or registry_name + {"steps": ["not a dict"]}, # Step is not a dict + {"steps": [{"other_field": "value"}]}, # Step with wrong fields + {"other_field": "value"}, # Completely different structure + ] + + for i, config in enumerate(invalid_configs): + assert not DataProcessorPipeline._is_processor_config(config), ( + f"Invalid config {i} should not be detected as processor config: {config}" + ) + + +def test_should_suggest_migration_with_processor_config(): + """Test that migration is NOT suggested when processor config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a valid processor config + processor_config = { + "name": "TestProcessor", + "steps": [ + { + "class": "lerobot.processor.normalize.NormalizeStep", + "config": {"mean": 0.0, "std": 1.0}, + } + ], + } + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + # Should NOT suggest migration (processor config exists) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_with_empty_processor_config(): + """Test that migration is NOT suggested when empty processor config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create an empty processor config + empty_processor_config = { + "name": "EmptyProcessor", + "steps": [], # Empty steps is valid + } + + with open(tmp_path / "empty_processor.json", "w") as f: + json.dump(empty_processor_config, f) + + # Should NOT suggest migration (processor config exists, even if empty) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_with_model_config_only(): + """Test that migration IS suggested when only model config exists.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config (like old LeRobot format) + model_config = { + "type": "act", + "input_features": {"observation.state": {"shape": [7]}}, + "output_features": {"action": {"shape": [7]}}, + "hidden_dim": 256, + "n_obs_steps": 1, + "n_action_steps": 1, + } + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # SHOULD suggest migration (model config exists but no processor) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_should_suggest_migration_no_json_files(): + """Test that migration is NOT suggested when no JSON files exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some non-JSON files + with open(tmp_path / "model.safetensors", "w") as f: + f.write("fake model data") + + with open(tmp_path / "README.md", "w") as f: + f.write("# Model README") + + # Should NOT suggest migration (no JSON files) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_random_json_files(): + """Test that migration IS suggested when JSON files exist but none are processor configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some random JSON file (not a processor config) + random_config = {"some_field": "some_value", "another_field": 123} + + with open(tmp_path / "random.json", "w") as f: + json.dump(random_config, f) + + # SHOULD suggest migration (JSON files exist but none are processor configs) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_should_suggest_migration_mixed_configs(): + """Test that migration is NOT suggested when processor config exists alongside other configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create both a processor config and a model config + processor_config = {"name": "TestProcessor", "steps": [{"registry_name": "normalize_step"}]} + + model_config = {"type": "diffusion", "hidden_dim": 512} + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should NOT suggest migration (processor config exists) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert not result + + +def test_should_suggest_migration_invalid_json(): + """Test that invalid JSON is handled gracefully.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create an invalid JSON file + with open(tmp_path / "invalid.json", "w") as f: + f.write("{ invalid json") + + # Create a valid non-processor config + model_config = {"type": "act"} + with open(tmp_path / "model.json", "w") as f: + json.dump(model_config, f) + + # SHOULD suggest migration (invalid JSON is ignored, but we have non-processor JSON) + result = DataProcessorPipeline._should_suggest_migration(tmp_path) + assert result + + +def test_from_pretrained_multiple_json_files_migration_error(): + """Test that multiple JSON files trigger ProcessorMigrationError.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create multiple non-processor configs + model_config = {"type": "act", "hidden_dim": 128} + train_config = {"batch_size": 32, "lr": 0.001} + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + with open(tmp_path / "train_config.json", "w") as f: + json.dump(train_config, f) + + # Should raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") + + # Check the error details + error = exc_info.value + assert str(tmp_path) in str(error.model_path) + assert "migrate_policy_normalization.py" in error.migration_command + assert "not a valid processor configuration" in error.original_error + + +def test_from_pretrained_no_processor_config_migration_error(): + """Test that missing processor config triggers ProcessorMigrationError.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config but no processor + model_config = {"type": "diffusion", "hidden_dim": 256} + + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") + + # Check the error details + error = exc_info.value + assert str(tmp_path) in str(error.model_path) + assert "migrate_policy_normalization.py" in error.migration_command + assert "not a valid processor configuration" in error.original_error + + +def test_from_pretrained_valid_processor_no_migration_error(): + """Test that valid processor config does NOT trigger migration error.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a valid processor config + processor_config = { + "name": "TestProcessor", + "steps": [], # Empty is valid + } + + with open(tmp_path / "processor.json", "w") as f: + json.dump(processor_config, f) + + # Should succeed and create pipeline + pipeline = DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json") + assert pipeline is not None + assert pipeline.name == "TestProcessor" + assert len(pipeline) == 0 + + +def test_from_pretrained_no_json_files_no_migration_error(): + """Test that directories with no JSON files don't trigger migration errors.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create some non-JSON files + with open(tmp_path / "model.safetensors", "w") as f: + f.write("fake model data") + + # Should raise FileNotFoundError (config file not found) + with pytest.raises(FileNotFoundError, match="not found in directory"): + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="processor.json") + + +def test_processor_migration_error_creation(): + """Test that ProcessorMigrationError is created correctly.""" + model_path = "/path/to/model" + migration_command = "python migrate.py --path /path/to/model" + original_error = "Config not found" + + error = ProcessorMigrationError(model_path, migration_command, original_error) + + assert error.model_path == model_path + assert error.migration_command == migration_command + assert error.original_error == original_error + assert model_path in str(error) + assert migration_command in str(error) + assert original_error in str(error) + + +def test_processor_migration_error_attributes(): + """Test that ProcessorMigrationError has correct attributes.""" + model_path = Path("/test/path") + migration_command = "python test.py" + original_error = "Test error" + + error = ProcessorMigrationError(model_path, migration_command, original_error) + + # Test that attributes are accessible + assert hasattr(error, "model_path") + assert hasattr(error, "migration_command") + assert hasattr(error, "original_error") + + # Test that it's still an Exception + assert isinstance(error, Exception) + + +def test_migration_suggestion_raises_error(): + """Test that migration suggestion always raises ProcessorMigrationError.""" + with pytest.raises(ProcessorMigrationError) as exc_info: + DataProcessorPipeline._suggest_processor_migration("/test/path", "Test error") + + error = exc_info.value + assert "/test/path" in str(error.model_path) + assert "Test error" in error.original_error + assert "migrate_policy_normalization.py" in error.migration_command + + +def test_migration_error_always_raised_for_invalid_configs(): + """Test that ProcessorMigrationError is always raised for invalid configs.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a model config + model_config = {"type": "test", "param": "value"} + with open(tmp_path / "config.json", "w") as f: + json.dump(model_config, f) + + # Should always raise ProcessorMigrationError + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline.from_pretrained(tmp_path, config_filename="config.json") diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 26aea56c7..5d7791919 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -20,27 +20,16 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.processor.normalize_processor import ( - NormalizerProcessor, - UnnormalizerProcessor, - _convert_stats_to_tensors, +from lerobot.processor import ( + DataProcessorPipeline, + IdentityProcessorStep, + NormalizerProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, + hotswap_stats, ) -from lerobot.processor.pipeline import RobotProcessor, TransitionKey - - -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } +from lerobot.processor.converters import create_transition, identity_transition, to_tensor +from lerobot.utils.utils import auto_select_torch_device def test_numpy_conversion(): @@ -50,7 +39,7 @@ def test_numpy_conversion(): "std": np.array([0.2, 0.2, 0.2]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) @@ -65,7 +54,7 @@ def test_tensor_conversion(): "std": torch.tensor([1.0, 1.0]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert tensor_stats["action"]["mean"].dtype == torch.float32 assert tensor_stats["action"]["std"].dtype == torch.float32 @@ -78,7 +67,7 @@ def test_scalar_conversion(): "std": 0.1, } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) @@ -91,7 +80,7 @@ def test_list_conversion(): "max": [1.0, 1.0, 2.0], } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) @@ -104,7 +93,7 @@ def test_unsupported_type(): } } with pytest.raises(TypeError, match="Unsupported type"): - _convert_stats_to_tensors(stats) + to_tensor(stats) # Helper functions to create feature maps and norm maps @@ -122,7 +111,7 @@ def _create_observation_norm_map(): } -# Fixtures for observation normalisation tests using NormalizerProcessor +# Fixtures for observation normalisation tests using NormalizerProcessorStep @pytest.fixture def observation_stats(): return { @@ -139,10 +128,10 @@ def observation_stats(): @pytest.fixture def observation_normalizer(observation_stats): - """Return a NormalizerProcessor that only has observation stats (no action).""" + """Return a NormalizerProcessorStep that only has observation stats (no action).""" features = _create_observation_features() norm_map = _create_observation_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) def test_mean_std_normalization(observation_normalizer): @@ -179,8 +168,11 @@ def test_min_max_normalization(observation_normalizer): def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor( - features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + normalizer = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=observation_stats, + normalize_observation_keys={"observation.image"}, ) observation = { @@ -202,7 +194,7 @@ def test_selective_normalization(observation_stats): def test_device_compatibility(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), } @@ -231,7 +223,7 @@ def test_from_lerobot_dataset(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } - normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats assert "observation.image" in normalizer._tensor_stats @@ -241,11 +233,12 @@ def test_from_lerobot_dataset(): def test_state_dict_save_load(observation_normalizer): # Save state state_dict = observation_normalizer.state_dict() + print("State dict:", state_dict) # Create new normalizer and load state features = _create_observation_features() norm_map = _create_observation_norm_map() - new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) new_normalizer.load_state_dict(state_dict) # Test that it works the same @@ -296,7 +289,7 @@ def _create_action_norm_map_min_max(): def test_mean_std_unnormalization(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) @@ -314,7 +307,7 @@ def test_mean_std_unnormalization(action_stats_mean_std): def test_min_max_unnormalization(action_stats_min_max): features = _create_action_features() norm_map = _create_action_norm_map_min_max() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_min_max} ) @@ -337,14 +330,14 @@ def test_min_max_unnormalization(action_stats_min_max): assert torch.allclose(unnormalized_action, expected) -def test_numpy_action_input(action_stats_mean_std): +def test_tensor_action_input(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) - normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32) transition = create_transition(action=normalized_action) unnormalized_transition = unnormalizer(transition) @@ -358,7 +351,7 @@ def test_numpy_action_input(action_stats_mean_std): def test_none_action(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() - unnormalizer = UnnormalizerProcessor( + unnormalizer = UnnormalizerProcessorStep( features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} ) @@ -374,11 +367,11 @@ def test_action_from_lerobot_dataset(): mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} - unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) assert "mean" in unnormalizer._tensor_stats["action"] -# Fixtures for NormalizerProcessor tests +# Fixtures for NormalizerProcessorStep tests @pytest.fixture def full_stats(): return { @@ -417,7 +410,7 @@ def _create_full_norm_map(): def normalizer_processor(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) + return NormalizerProcessorStep(features=features, norm_map=norm_map, stats=full_stats) def test_combined_normalization(normalizer_processor): @@ -461,11 +454,11 @@ def test_processor_from_lerobot_dataset(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - processor = NormalizerProcessor.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_keys={"observation.image"} + processor = NormalizerProcessorStep.from_lerobot_dataset( + mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} ) - assert processor.normalize_keys == {"observation.image"} + assert processor.normalize_observation_keys == {"observation.image"} assert "observation.image" in processor._tensor_stats assert "action" in processor._tensor_stats @@ -473,13 +466,17 @@ def test_processor_from_lerobot_dataset(full_stats): def test_get_config(full_stats): features = _create_full_features() norm_map = _create_full_norm_map() - processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + processor = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_keys": ["observation.image"], + "normalize_observation_keys": ["observation.image"], "eps": 1e-6, "features": { "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, @@ -497,7 +494,9 @@ def test_get_config(full_stats): def test_integration_with_robot_processor(normalizer_processor): """Test integration with RobotProcessor pipeline""" - robot_processor = RobotProcessor([normalizer_processor]) + robot_processor = DataProcessorPipeline( + [normalizer_processor], to_transition=identity_transition, to_output=identity_transition + ) observation = { "observation.image": torch.tensor([0.7, 0.5, 0.3]), @@ -526,7 +525,7 @@ def test_empty_observation(): stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) transition = create_transition() result = normalizer(transition) @@ -537,7 +536,7 @@ def test_empty_observation(): def test_empty_stats(): features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) observation = {"observation.image": torch.tensor([0.5])} transition = create_transition(observation=observation) @@ -553,7 +552,7 @@ def test_partial_stats(): stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = {"observation.image": torch.tensor([0.7])} transition = create_transition(observation=observation) @@ -568,7 +567,7 @@ def test_missing_action_stats_no_error(): features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key assert "action" not in processor._tensor_stats @@ -577,19 +576,23 @@ def test_serialization_roundtrip(full_stats): """Test that features and norm_map can be serialized and deserialized correctly.""" features = _create_full_features() norm_map = _create_full_norm_map() - original_processor = NormalizerProcessor( - features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + original_processor = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=full_stats, + normalize_observation_keys={"observation.image"}, + eps=1e-6, ) # Get config (serialization) config = original_processor.get_config() # Create a new processor from the config (deserialization) - new_processor = NormalizerProcessor( + new_processor = NormalizerProcessorStep( features=config["features"], norm_map=config["norm_map"], stats=full_stats, - normalize_keys=set(config["normalize_keys"]), + normalize_observation_keys=set(config["normalize_observation_keys"]), eps=config["eps"], ) @@ -620,9 +623,1299 @@ def test_serialization_roundtrip(full_stats): assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # Verify features and norm_map are correctly reconstructed - assert new_processor.features.keys() == original_processor.features.keys() - for key in new_processor.features: - assert new_processor.features[key].type == original_processor.features[key].type - assert new_processor.features[key].shape == original_processor.features[key].shape + assert ( + new_processor.transform_features(features).keys() + == original_processor.transform_features(features).keys() + ) + for key in new_processor.transform_features(features): + assert ( + new_processor.transform_features(features)[key].type + == original_processor.transform_features(features)[key].type + ) + assert ( + new_processor.transform_features(features)[key].shape + == original_processor.transform_features(features)[key].shape + ) assert new_processor.norm_map == original_processor.norm_map + + +# Identity normalization tests +def test_identity_normalization_observations(): + """Test that IDENTITY mode skips normalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_normalization_actions(): + """Test that IDENTITY mode skips normalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([1.0, -0.5]) + transition = create_transition(action=action) + + normalized_transition = normalizer(transition) + + # Action should remain unchanged + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + + +def test_identity_unnormalization_observations(): + """Test that IDENTITY mode skips unnormalization for observations.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode + FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + } + transition = create_transition(observation=observation) + + unnormalized_transition = unnormalizer(transition) + unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + + # State should be unnormalized (MIN_MAX) + # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 + # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 + expected_state = torch.tensor([0.0, -1.0]) + assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + + +def test_identity_unnormalization_actions(): + """Test that IDENTITY mode skips unnormalization for actions.""" + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} + stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} + + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + action = torch.tensor([0.5, -0.8]) # Normalized values + transition = create_transition(action=action) + + unnormalized_transition = unnormalizer(transition) + + # Action should remain unchanged + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_with_missing_stats(): + """Test that IDENTITY mode works even when stats are missing.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = {} # No stats provided + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Both should work without errors and return unchanged data + normalized_transition = normalizer(transition) + unnormalized_transition = unnormalizer(transition) + + assert torch.allclose( + normalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) + assert torch.allclose( + unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], + observation["observation.image"], + ) + assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) + + +def test_identity_mixed_with_other_modes(): + """Test IDENTITY mode mixed with other normalization modes.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + action = torch.tensor([0.5, 0.0]) + transition = create_transition(observation=observation, action=action) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + normalized_action = normalized_transition[TransitionKey.ACTION] + + # Image should remain unchanged (IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + # Action should be normalized (MIN_MAX) to [-1, 1] + # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 + # 2 * (0.0 - (-1)) / (1 - (-1)) - 1 = 2 * 1.0 / 2 - 1 = 0.0 + expected_action = torch.tensor([0.5, 0.0]) + assert torch.allclose(normalized_action, expected_action) + + +def test_identity_defaults_when_not_in_norm_map(): + """Test that IDENTITY is used as default when feature type not in norm_map.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + # VISUAL not specified, should default to IDENTITY + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([1.0, -0.5]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Image should remain unchanged (defaults to IDENTITY) + assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + + # State should be normalized (explicitly MEAN_STD) + expected_state = torch.tensor([1.0, -0.5]) + assert torch.allclose(normalized_obs["observation.state"], expected_state) + + +def test_identity_roundtrip(): + """Test that IDENTITY normalization and unnormalization are true inverses.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.IDENTITY, + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_action = torch.tensor([0.5, -0.2]) + original_transition = create_transition(observation=original_observation, action=original_action) + + # Normalize then unnormalize + normalized = normalizer(original_transition) + roundtrip = unnormalizer(normalized) + + # Should be identical to original + assert torch.allclose( + roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] + ) + assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) + + +def test_identity_config_serialization(): + """Test that IDENTITY mode is properly saved and loaded in config.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Get config + config = normalizer.get_config() + + # Check that IDENTITY is properly serialized + assert config["norm_map"]["VISUAL"] == "IDENTITY" + assert config["norm_map"]["ACTION"] == "MEAN_STD" + + # Create new processor from config (simulating load) + new_normalizer = NormalizerProcessorStep( + features=config["features"], + norm_map=config["norm_map"], + stats=stats, + eps=config["eps"], + ) + + # Test that both work the same way + observation = {"observation.image": torch.tensor([0.7])} + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + result1 = normalizer(transition) + result2 = new_normalizer(transition) + + # Results should be identical + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) + + +# def test_unsupported_normalization_mode_error(): +# """Test that unsupported normalization modes raise appropriate errors.""" +# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} + +# # Create an invalid norm_map (this would never happen in practice, but tests error handling) +# from enum import Enum + +# class InvalidMode(str, Enum): +# INVALID = "INVALID" + +# # We can't actually pass an invalid enum to the processor due to type checking, +# # but we can test the error by manipulating the norm_map after creation +# norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} +# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} + +# normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + +# # Manually inject an invalid mode to test error handling +# normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" + +# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# transition = create_transition(observation=observation) + +# with pytest.raises(ValueError, match="Unsupported normalization mode"): +# normalizer(transition) + + +def test_hotswap_stats_basic_functionality(): + """Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps.""" + # Create initial stats + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # Create new stats for hotswapping + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + # Create features and norm_map + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create processors + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + identity = IdentityProcessorStep() + + # Create robot processor + robot_processor = DataProcessorPipeline(steps=[normalizer, unnormalizer, identity]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that normalizer and unnormalizer have new stats + assert new_processor.steps[0].stats == new_stats + assert new_processor.steps[1].stats == new_stats + + # Check that tensor stats are updated correctly + expected_tensor_stats = to_tensor(new_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + torch.testing.assert_close( + new_processor.steps[0]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + torch.testing.assert_close( + new_processor.steps[1]._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_hotswap_stats_deep_copy(): + """Test that hotswap_stats creates a deep copy and doesn't modify the original processor.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + original_processor = DataProcessorPipeline(steps=[normalizer]) + + # Store reference to original stats + original_stats_reference = original_processor.steps[0].stats + original_tensor_stats_reference = original_processor.steps[0]._tensor_stats + + # Hotswap stats + new_processor = hotswap_stats(original_processor, new_stats) + + # Original processor should be unchanged + assert original_processor.steps[0].stats is original_stats_reference + assert original_processor.steps[0]._tensor_stats is original_tensor_stats_reference + assert original_processor.steps[0].stats == initial_stats + + # New processor should have new stats + assert new_processor.steps[0].stats == new_stats + assert new_processor.steps[0].stats is not original_stats_reference + + # Processors should be different objects + assert new_processor is not original_processor + assert new_processor.steps[0] is not original_processor.steps[0] + + +def test_hotswap_stats_only_affects_normalizer_steps(): + """Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps.""" + stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + # Create mixed steps + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + identity = IdentityProcessorStep() + + robot_processor = DataProcessorPipeline(steps=[normalizer, identity, unnormalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that only normalizer and unnormalizer steps are affected + assert new_processor.steps[0].stats == new_stats # normalizer + assert new_processor.steps[2].stats == new_stats # unnormalizer + + # Identity processor should remain unchanged (and it doesn't have stats attribute) + assert not hasattr(new_processor.steps[1], "stats") + + +def test_hotswap_stats_empty_stats(): + """Test hotswap_stats with empty stats dictionary.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + empty_stats = {} + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap with empty stats + new_processor = hotswap_stats(robot_processor, empty_stats) + + # Should update to empty stats + assert new_processor.steps[0].stats == empty_stats + assert new_processor.steps[0]._tensor_stats == {} + + +def test_hotswap_stats_no_normalizer_steps(): + """Test hotswap_stats with a processor that has no normalizer/unnormalizer steps.""" + stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + # Create processor with only identity steps + robot_processor = DataProcessorPipeline(steps=[IdentityProcessorStep(), IdentityProcessorStep()]) + + # Hotswap stats - should work without error + new_processor = hotswap_stats(robot_processor, stats) + + # Should return a different object (deep copy) + assert new_processor is not robot_processor + + # Steps should be deep copied but unchanged + assert len(new_processor.steps) == len(robot_processor.steps) + for i, step in enumerate(new_processor.steps): + assert step is not robot_processor.steps[i] # Different objects + assert isinstance(step, type(robot_processor.steps[i])) # Same type + + +def test_hotswap_stats_preserves_other_attributes(): + """Test that hotswap_stats preserves other processor attributes like features and norm_map.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalize_observation_keys = {"observation.image"} + eps = 1e-6 + + normalizer = NormalizerProcessorStep( + features=features, + norm_map=norm_map, + stats=initial_stats, + normalize_observation_keys=normalize_observation_keys, + eps=eps, + ) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that other attributes are preserved + new_normalizer = new_processor.steps[0] + assert new_normalizer.features == features + assert new_normalizer.norm_map == norm_map + assert new_normalizer.normalize_observation_keys == normalize_observation_keys + assert new_normalizer.eps == eps + + # But stats should be updated + assert new_normalizer.stats == new_stats + + +def test_hotswap_stats_multiple_normalizer_types(): + """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, + } + + new_stats = { + "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + # Create multiple normalizers and unnormalizers + normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer1 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + unnormalizer2 = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + + robot_processor = DataProcessorPipeline(steps=[normalizer1, unnormalizer1, normalizer2, unnormalizer2]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # All normalizer/unnormalizer steps should be updated + for step in new_processor.steps: + assert step.stats == new_stats + + # Check tensor stats conversion + expected_tensor_stats = to_tensor(new_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + torch.testing.assert_close( + step._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_hotswap_stats_with_different_data_types(): + """Test hotswap_stats with various data types in stats.""" + initial_stats = { + "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + } + + # New stats with different data types (int, float, list, tuple) + new_stats = { + "observation.image": { + "mean": [0.3, 0.4, 0.5], # list + "std": (0.1, 0.2, 0.3), # tuple + "min": 0, # int + "max": 1.0, # float + }, + "action": { + "mean": np.array([0.1, 0.2]), # numpy array + "std": torch.tensor([0.5, 0.6]), # torch tensor + }, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + robot_processor = DataProcessorPipeline(steps=[normalizer]) + + # Hotswap stats + new_processor = hotswap_stats(robot_processor, new_stats) + + # Check that stats are updated + assert new_processor.steps[0].stats == new_stats + + # Check that tensor conversion worked correctly + tensor_stats = new_processor.steps[0]._tensor_stats + assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor) + assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["action"]["std"], torch.Tensor) + + # Check values + torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5])) + torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3])) + torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0)) + torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) + + +def test_hotswap_stats_functional_test(): + """Test that hotswapped processor actually works functionally.""" + # Create test data + observation = { + "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), + } + action = torch.tensor([0.5, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Initial stats + initial_stats = { + "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # New stats + new_stats = { + "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, + "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create original processor + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=initial_stats) + original_processor = DataProcessorPipeline( + steps=[normalizer], to_transition=identity_transition, to_output=identity_transition + ) + + # Process with original stats + original_result = original_processor(transition) + + # Hotswap stats + new_processor = hotswap_stats(original_processor, new_stats) + + # Process with new stats + new_result = new_processor(transition) + + # Results should be different since normalization changed + assert not torch.allclose( + original_result["observation"]["observation.image"], + new_result["observation"]["observation.image"], + rtol=1e-3, + atol=1e-3, + ) + assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3) + + # Verify that the new processor is actually using the new stats by checking internal state + assert new_processor.steps[0].stats == new_stats + assert torch.allclose( + new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2]) + ) + assert torch.allclose( + new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2]) + ) + assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) + assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) + + # Test that normalization actually happens (output should not equal input) + assert not torch.allclose( + new_result["observation"]["observation.image"], observation["observation.image"] + ) + assert not torch.allclose(new_result["action"], action) + + +def test_zero_std_uses_eps(): + """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([0.5])} # equals mean + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + + +def test_min_equals_max_maps_to_minus_one(): + """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} + stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) + + observation = {"observation.state": torch.tensor([2.0])} + out = normalizer(create_transition(observation=observation)) + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + + +def test_action_normalized_despite_normalize_observation_keys(): + """Action normalization is independent of normalize_observation_keys filter for observations.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} + ) + + transition = create_transition( + observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + ) + out = normalizer(transition) + # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 + assert torch.allclose(out[TransitionKey.ACTION], torch.tensor([1.0, 1.0])) + + +def test_unnormalize_observations_mean_std_and_min_max(): + features = { + "observation.ms": PolicyFeature(FeatureType.STATE, (2,)), + "observation.mm": PolicyFeature(FeatureType.STATE, (2,)), + } + # Build two processors: one mean/std and one min/max + unnorm_ms = UnnormalizerProcessorStep( + features={"observation.ms": features["observation.ms"]}, + norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD}, + stats={"observation.ms": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}}, + ) + unnorm_mm = UnnormalizerProcessorStep( + features={"observation.mm": features["observation.mm"]}, + norm_map={FeatureType.STATE: NormalizationMode.MIN_MAX}, + stats={"observation.mm": {"min": np.array([0.0, -2.0]), "max": np.array([2.0, 2.0])}}, + ) + + tr = create_transition( + observation={ + "observation.ms": torch.tensor([0.0, 0.0]), # → mean + "observation.mm": torch.tensor([0.0, 0.0]), # → mid-point + } + ) + out_ms = unnorm_ms(tr)[TransitionKey.OBSERVATION]["observation.ms"] + out_mm = unnorm_mm(tr)[TransitionKey.OBSERVATION]["observation.mm"] + assert torch.allclose(out_ms, torch.tensor([1.0, -1.0])) + assert torch.allclose(out_mm, torch.tensor([1.0, 0.0])) # mid of [0,2] and [-2,2] + + +def test_unknown_observation_keys_ignored(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + tr = create_transition(observation=obs) + out = normalizer(tr) + + # Unknown key should pass through unchanged and not be tracked + assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.unknown"], obs["observation.unknown"]) + + +def test_batched_action_normalization(): + features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1] + out = normalizer(create_transition(action=actions))[TransitionKey.ACTION] + expected = torch.tensor([[0.0, 0.0], [1.0, 1.0]]) + assert torch.allclose(out, expected) + + +def test_complementary_data_preservation(): + features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + comp = {"existing": 123} + tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + out = normalizer(tr) + new_comp = out[TransitionKey.COMPLEMENTARY_DATA] + assert new_comp["existing"] == 123 + + +def test_roundtrip_normalize_unnormalize_non_identity(): + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} + stats = { + "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, + } + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Add a time dimension in action for broadcasting check (B,T,D) + obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] + + tr = create_transition(observation=obs, action=act) + out = unnormalizer(normalizer(tr)) + + assert torch.allclose( + out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 + ) + assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) + + +def test_dtype_adaptation_bfloat16_input_float32_normalizer(): + """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output""" + features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = { + "observation.state": { + "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), + "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), + } + } + + # Create normalizer configured with float32 dtype + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, dtype=torch.float32 + ) + + # Verify initial configuration + assert normalizer.dtype == torch.float32 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.float32 + + # Create bfloat16 input tensor + observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} + transition = create_transition(observation=observation) + + # Process the transition + result = normalizer(transition) + + # Verify that: + # 1. Stats were automatically adapted to bfloat16 + assert normalizer.dtype == torch.bfloat16 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.bfloat16 + + # 2. Output is in bfloat16 + output_tensor = result[TransitionKey.OBSERVATION]["observation.state"] + assert output_tensor.dtype == torch.bfloat16 + + # 3. Normalization was applied correctly (mean should be close to original - mean) / std + expected = ( + torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16) + - torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.bfloat16) + ) / torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.bfloat16) + assert torch.allclose(output_tensor, expected, atol=1e-2) # bfloat16 has lower precision + + +def test_stats_override_preservation_in_load_state_dict(): + """ + Test that explicitly provided stats are preserved during load_state_dict. + + This tests the fix for the bug where stats provided via overrides were + being overwritten when load_state_dict was called. + """ + # Create original stats + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + # Create override stats (what user wants to use) + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer with override stats (simulating from_pretrained with overrides) + override_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=override_stats) + + # Verify that the override stats are initially set correctly + assert set(override_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + assert override_normalizer._stats_explicitly_provided is True + + # This is the critical test: load_state_dict should NOT overwrite the override stats + override_normalizer.load_state_dict(saved_state_dict) + + # After loading state_dict, stats should still be the override stats, not the original stats + # Check that loaded stats match override stats + assert set(override_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(override_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + override_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + # Compare individual arrays to avoid numpy array comparison ambiguity + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + override_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Verify that _tensor_stats are also correctly set to match the override stats + expected_tensor_stats = to_tensor(override_stats) + for key in expected_tensor_stats: + for stat_name in expected_tensor_stats[key]: + if isinstance(expected_tensor_stats[key][stat_name], torch.Tensor): + torch.testing.assert_close( + override_normalizer._tensor_stats[key][stat_name], expected_tensor_stats[key][stat_name] + ) + + +def test_stats_without_override_loads_normally(): + """ + Test that when stats are not explicitly provided (normal case), + load_state_dict works as before. + """ + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + # Create a normalizer with original stats and save its state + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + saved_state_dict = original_normalizer.state_dict() + + # Create a new normalizer without stats (simulating normal from_pretrained) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + + # Verify that stats are not explicitly provided + assert new_normalizer._stats_explicitly_provided is False + + # Load state dict - this should work normally and load the saved stats + new_normalizer.load_state_dict(saved_state_dict) + + # Stats should now match the original stats (normal behavior) + # Check that all keys and values match + assert set(new_normalizer.stats.keys()) == set(original_stats.keys()) + for key in original_stats: + assert set(new_normalizer.stats[key].keys()) == set(original_stats[key].keys()) + for stat_name in original_stats[key]: + np.testing.assert_allclose( + new_normalizer.stats[key][stat_name], original_stats[key][stat_name], rtol=1e-6, atol=1e-6 + ) + + +def test_stats_explicit_provided_flag_detection(): + """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + } + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + # Test 1: Explicitly provided stats (non-empty dict) + stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + assert normalizer1._stats_explicitly_provided is True + + # Test 2: Empty stats dict + normalizer2 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + assert normalizer2._stats_explicitly_provided is False + + # Test 3: None stats + normalizer3 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=None) + assert normalizer3._stats_explicitly_provided is False + + # Test 4: Stats not provided (defaults to None) + normalizer4 = NormalizerProcessorStep(features=features, norm_map=norm_map) + assert normalizer4._stats_explicitly_provided is False + + +def test_pipeline_from_pretrained_with_stats_overrides(): + """ + Test the actual use case: DataProcessorPipeline.from_pretrained with stat overrides. + + This is an integration test that verifies the fix works in the real scenario + where users provide stat overrides when loading a pipeline. + """ + import tempfile + + # Create test data + features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + original_stats = { + "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + } + + override_stats = { + "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + } + + # Create and save a pipeline with the original stats + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=original_stats) + identity = IdentityProcessorStep() + original_pipeline = DataProcessorPipeline(steps=[normalizer, identity], name="test_pipeline") + + with tempfile.TemporaryDirectory() as temp_dir: + # Save the pipeline + original_pipeline.save_pretrained(temp_dir) + + # Load the pipeline with stat overrides + overrides = {"normalizer_processor": {"stats": override_stats}} + + loaded_pipeline = DataProcessorPipeline.from_pretrained( + temp_dir, config_filename="test_pipeline.json", overrides=overrides + ) + + # The critical test: the loaded pipeline should use override stats, not original stats + loaded_normalizer = loaded_pipeline.steps[0] + assert isinstance(loaded_normalizer, NormalizerProcessorStep) + + # Check that loaded stats match override stats + assert set(loaded_normalizer.stats.keys()) == set(override_stats.keys()) + for key in override_stats: + assert set(loaded_normalizer.stats[key].keys()) == set(override_stats[key].keys()) + for stat_name in override_stats[key]: + np.testing.assert_array_equal( + loaded_normalizer.stats[key][stat_name], override_stats[key][stat_name] + ) + + # Verify stats don't match original stats + for key in override_stats: + for stat_name in override_stats[key]: + assert not np.array_equal( + loaded_normalizer.stats[key][stat_name], original_stats[key][stat_name] + ), f"Stats for {key}.{stat_name} should not match original stats" + + # Test that the override stats are actually used in processing + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + # Process with override pipeline + override_result = loaded_pipeline(transition) + + # Create a reference pipeline with override stats for comparison + reference_normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=override_stats + ) + reference_pipeline = DataProcessorPipeline( + steps=[reference_normalizer, identity], + to_transition=identity_transition, + to_output=identity_transition, + ) + _ = reference_pipeline(transition) + + # The critical part was verified above: loaded_normalizer.stats == override_stats + # This confirms that override stats are preserved during load_state_dict. + # Let's just verify the pipeline processes data successfully. + assert "action" in override_result + assert isinstance(override_result["action"], torch.Tensor) + + +def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): + """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" + from lerobot.processor import DeviceProcessorStep + + features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))} + norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} + stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} + + # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32) + device_processor = DeviceProcessorStep(device=str(auto_select_torch_device()), float_dtype="bfloat16") + normalizer = NormalizerProcessorStep( + features=features, norm_map=norm_map, stats=stats, dtype=torch.float32 + ) + + # Verify initial normalizer configuration + assert normalizer.dtype == torch.float32 + + # Create CPU input + observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} + transition = create_transition(observation=observation) + + # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA + processed_1 = device_processor(transition) + intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"] + assert intermediate_tensor.dtype == torch.bfloat16 + assert intermediate_tensor.device.type == str(auto_select_torch_device()) + + # Step 2: NormalizerProcessor receives bfloat16 input and adapts + final_result = normalizer(processed_1) + final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"] + + # Verify final output is bfloat16 (automatic adaptation worked) + assert final_tensor.dtype == torch.bfloat16 + assert final_tensor.device.type == str(auto_select_torch_device()) + + # Verify normalizer adapted its internal state + assert normalizer.dtype == torch.bfloat16 + for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + assert stat_tensor.dtype == torch.bfloat16 + assert stat_tensor.device.type == str(auto_select_torch_device()) + + +def test_stats_reconstruction_after_load_state_dict(): + """ + Test that stats dict is properly reconstructed from _tensor_stats after loading. + + This test ensures the bug where stats became empty after loading is fixed. + The bug occurred when: + 1. Only _tensor_stats were saved via state_dict() + 2. stats field became empty {} after loading + 3. Calling to() method or hotswap_stats would fail because they depend on self.stats + """ + + # Create normalizer with stats + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 2.0]), + }, + } + + original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Save state dict (simulating save/load) + state_dict = original_normalizer.state_dict() + + # Create new normalizer with empty stats (simulating load) + new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) + + # Before fix: this would cause stats to remain empty + new_normalizer.load_state_dict(state_dict) + + # Verify that stats dict is properly reconstructed from _tensor_stats + assert new_normalizer.stats is not None + assert new_normalizer.stats != {} + + # Check that all expected keys are present + assert "observation.image" in new_normalizer.stats + assert "observation.state" in new_normalizer.stats + assert "action" in new_normalizer.stats + + # Check that values are correct (converted back from tensors) + np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2]) + np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0]) + np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0]) + np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) + np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) + + # Test that methods that depend on self.stats work correctly after loading + # This would fail before the bug fix because self.stats was empty + + # Test 1: to() method should work without crashing + try: + new_normalizer.to(device="cpu", dtype=torch.float32) + # If we reach here, the bug is fixed + except (KeyError, AttributeError) as e: + pytest.fail(f"to() method failed after loading state_dict: {e}") + + # Test 2: hotswap_stats should work + new_stats = { + "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, + "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, + "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, + } + + pipeline = DataProcessorPipeline([new_normalizer]) + try: + new_pipeline = hotswap_stats(pipeline, new_stats) + # If we reach here, hotswap_stats worked correctly + assert new_pipeline.steps[0].stats == new_stats + except (KeyError, AttributeError) as e: + pytest.fail(f"hotswap_stats failed after loading state_dict: {e}") + + # Test 3: The normalizer should work functionally the same as the original + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition(observation=observation, action=action) + + original_result = original_normalizer(transition) + new_result = new_normalizer(transition) + + # Results should be identical (within floating point precision) + torch.testing.assert_close( + original_result[TransitionKey.OBSERVATION]["observation.image"], + new_result[TransitionKey.OBSERVATION]["observation.image"], + ) + torch.testing.assert_close( + original_result[TransitionKey.OBSERVATION]["observation.state"], + new_result[TransitionKey.OBSERVATION]["observation.state"], + ) + torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION]) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index e48b6bc08..57f32482d 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -18,31 +18,16 @@ import numpy as np import pytest import torch -from lerobot.configs.types import FeatureType +from lerobot.configs.types import FeatureType, PipelineFeatureType from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE -from lerobot.processor import VanillaObservationProcessor -from lerobot.processor.pipeline import TransitionKey +from lerobot.processor import TransitionKey, VanillaObservationProcessorStep +from lerobot.processor.converters import create_transition from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - def test_process_single_image(): """Test processing a single image.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create a mock image (H, W, C) format, uint8 image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -68,7 +53,7 @@ def test_process_single_image(): def test_process_image_dict(): """Test processing multiple images in a dictionary.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create mock images image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -91,7 +76,7 @@ def test_process_image_dict(): def test_process_batched_image(): """Test processing already batched images.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create a batched image (B, H, W, C) image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) @@ -108,7 +93,7 @@ def test_process_batched_image(): def test_invalid_image_format(): """Test error handling for invalid image formats.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test wrong channel order (channels first) image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) @@ -121,7 +106,7 @@ def test_invalid_image_format(): def test_invalid_image_dtype(): """Test error handling for invalid image dtype.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test wrong dtype image = np.random.rand(64, 64, 3).astype(np.float32) @@ -134,7 +119,7 @@ def test_invalid_image_dtype(): def test_no_pixels_in_observation(): """Test processor when no pixels are in observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -149,9 +134,9 @@ def test_no_pixels_in_observation(): def test_none_observation(): """Test processor with None observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() - transition = create_transition() + transition = create_transition(observation={}) result = processor(transition) assert result == transition @@ -159,7 +144,7 @@ def test_none_observation(): def test_serialization_methods(): """Test serialization methods.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Test get_config config = processor.get_config() @@ -178,7 +163,7 @@ def test_serialization_methods(): def test_process_environment_state(): """Test processing environment_state.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) observation = {"environment_state": env_state} @@ -199,7 +184,7 @@ def test_process_environment_state(): def test_process_agent_pos(): """Test processing agent_pos.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -220,7 +205,7 @@ def test_process_agent_pos(): def test_process_batched_states(): """Test processing already batched states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) @@ -238,7 +223,7 @@ def test_process_batched_states(): def test_process_both_states(): """Test processing both environment_state and agent_pos.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() env_state = np.array([1.0, 2.0], dtype=np.float32) agent_pos = np.array([0.5, -0.5], dtype=np.float32) @@ -263,7 +248,7 @@ def test_process_both_states(): def test_no_states_in_observation(): """Test processor when no states are in observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {"other_data": np.array([1, 2, 3])} transition = create_transition(observation=observation) @@ -277,7 +262,7 @@ def test_no_states_in_observation(): def test_complete_observation_processing(): """Test processing a complete observation with both images and states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create mock data image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -314,7 +299,7 @@ def test_complete_observation_processing(): def test_image_only_processing(): """Test processing observation with only images.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) observation = {"pixels": image} @@ -329,7 +314,7 @@ def test_image_only_processing(): def test_state_only_processing(): """Test processing observation with only states.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() agent_pos = np.array([1.0, 2.0], dtype=np.float32) observation = {"agent_pos": agent_pos} @@ -344,7 +329,7 @@ def test_state_only_processing(): def test_empty_observation(): """Test processing empty observation.""" - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() observation = {} transition = create_transition(observation=observation) @@ -360,7 +345,7 @@ def test_equivalent_to_original_function(): # Import the original function for comparison from lerobot.envs.utils import preprocess_observation - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create test data similar to what the original function expects image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) @@ -387,7 +372,7 @@ def test_equivalent_with_image_dict(): """Test equivalence with dictionary of images.""" from lerobot.envs.utils import preprocess_observation - processor = VanillaObservationProcessor() + processor = VanillaObservationProcessorStep() # Create test data with multiple cameras image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) @@ -410,77 +395,133 @@ def test_equivalent_with_image_dict(): torch.testing.assert_close(original_result[key], processor_result[key]) -def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] - assert "pixels" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_IMAGE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] + == features[PipelineFeatureType.OBSERVATION]["pixels"] + ) + assert "pixels" not in out[PipelineFeatureType.OBSERVATION] + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_observation_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] - assert "observation.pixels" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_IMAGE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] + == features[PipelineFeatureType.OBSERVATION]["observation.pixels"] + ) + assert "observation.pixels" not in out[PipelineFeatureType.OBSERVATION] + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_image_processor_features_multi_camera_and_prefixed(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "keep": policy_feature_factory(FeatureType.ENV, (7,)), + PipelineFeatureType.OBSERVATION: { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] - assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] - assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] - assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out - assert out["keep"] == features["keep"] + assert ( + f"{OBS_IMAGES}.front" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] + == features[PipelineFeatureType.OBSERVATION]["pixels.front"] + ) + assert ( + f"{OBS_IMAGES}.wrist" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.wrist"] + == features[PipelineFeatureType.OBSERVATION]["pixels.wrist"] + ) + assert ( + f"{OBS_IMAGES}.rear" in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.rear"] + == features[PipelineFeatureType.OBSERVATION]["observation.pixels.rear"] + ) + assert ( + "pixels.front" not in out[PipelineFeatureType.OBSERVATION] + and "pixels.wrist" not in out[PipelineFeatureType.OBSERVATION] + and "observation.pixels.rear" not in out[PipelineFeatureType.OBSERVATION] + ) + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): - processor = VanillaObservationProcessor() +def test_state_processor_features_environment_and_agent_pos(policy_feature_factory): + processor = VanillaObservationProcessorStep() features = { - "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), - "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), - "keep": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out - assert out["keep"] == features["keep"] + assert ( + OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] + == features[PipelineFeatureType.OBSERVATION]["environment_state"] + ) + assert ( + OBS_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_STATE] + == features[PipelineFeatureType.OBSERVATION]["agent_pos"] + ) + assert ( + "environment_state" not in out[PipelineFeatureType.OBSERVATION] + and "agent_pos" not in out[PipelineFeatureType.OBSERVATION] + ) + assert out[PipelineFeatureType.OBSERVATION]["keep"] == features[PipelineFeatureType.OBSERVATION]["keep"] assert_contract_is_typed(out) -def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): - proc = VanillaObservationProcessor() +def test_state_processor_features_prefixed_inputs(policy_feature_factory): + proc = VanillaObservationProcessorStep() features = { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), - "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + PipelineFeatureType.OBSERVATION: { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + }, } - out = proc.feature_contract(features.copy()) + out = proc.transform_features(features.copy()) - assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] - assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] - assert "environment_state" not in out and "agent_pos" not in out + assert ( + OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] + == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] + ) + assert ( + OBS_STATE in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION][OBS_STATE] + == features[PipelineFeatureType.OBSERVATION]["observation.agent_pos"] + ) + assert ( + "environment_state" not in out[PipelineFeatureType.OBSERVATION] + and "agent_pos" not in out[PipelineFeatureType.OBSERVATION] + ) assert_contract_is_typed(out) diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py new file mode 100644 index 000000000..c481cb18f --- /dev/null +++ b/tests/processor/test_pi0_processor.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for PI0 policy processor.""" + +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + EnvTransition, + NormalizerProcessorStep, + ProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + +def create_default_config(): + """Create a default PI0 configuration for testing.""" + config = PI0Config() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + config.tokenizer_max_length = 128 + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_pi0_processor_basic(): + """Test basic creation of PI0 processor.""" + config = create_default_config() + stats = create_default_stats() + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 6 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor) + # Step 3 would be TokenizerProcessorStep but it's mocked + assert isinstance(preprocessor.steps[4], DeviceProcessorStep) + assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_pi0_newline_processor_single_task(): + """Test Pi0NewLineProcessor with single task string.""" + processor = Pi0NewLineProcessor() + + # Test with task that doesn't have newline + transition = create_transition(complementary_data={"task": "test task"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + # Test with task that already has newline + transition = create_transition(complementary_data={"task": "test task\n"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + +def test_pi0_newline_processor_list_of_tasks(): + """Test Pi0NewLineProcessor with list of task strings.""" + processor = Pi0NewLineProcessor() + + # Test with list of tasks + tasks = ["task1", "task2\n", "task3"] + transition = create_transition(complementary_data={"task": tasks}) + result = processor(transition) + expected = ["task1\n", "task2\n", "task3\n"] + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected + + +def test_pi0_newline_processor_empty_transition(): + """Test Pi0NewLineProcessor with empty transition.""" + processor = Pi0NewLineProcessor() + + # Test with no complementary_data + transition = create_transition() + result = processor(transition) + assert result == transition + + # Test with complementary_data but no task + transition = create_transition(complementary_data={"other": "data"}) + result = processor(transition) + assert result == transition + + # Test with None task + transition = create_transition(complementary_data={"task": None}) + result = processor(transition) + assert result == transition + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_cuda(): + """Test PI0 processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(10), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action, complementary_data={"task": "test task"}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_accelerate_scenario(): + """Test PI0 processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 10).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_pi0_processor_multi_gpu(): + """Test PI0 processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "google/paligemma-3b-pt-224"} + + def transform_features(self, features): + return features + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 10).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 6).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_pi0_processor_without_stats(): + """Test PI0 processor creation without dataset statistics.""" + config = create_default_config() + + # Mock the tokenizer processor + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, postprocessor = make_pi0_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + +def test_pi0_newline_processor_state_dict(): + """Test Pi0NewLineProcessor state dict methods.""" + processor = Pi0NewLineProcessor() + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should do nothing) + processor.load_state_dict({}) + + # Test reset (should do nothing) + processor.reset() + + # Test get_config + config = processor.get_config() + assert config == {} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pi0_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + stats = create_default_stats() + config.device = "cuda" + + with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): + preprocessor, _ = make_pi0_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5) + normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10 + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6 + transition = create_transition( + observation, action, complementary_data={"task": "test bfloat16 adaptation"} + ) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 5665d5a7d..0d17fed00 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -25,29 +25,21 @@ import pytest import torch import torch.nn as nn -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor -from lerobot.processor.pipeline import TransitionKey +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features +from lerobot.processor import ( + DataProcessorPipeline, + EnvTransition, + ProcessorStep, + ProcessorStepRegistry, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info if info is not None else {}, - TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, - } - - @dataclass -class MockStep: +class MockStep(ProcessorStep): """Mock pipeline step for testing - demonstrates best practices. This example shows the proper separation: @@ -90,13 +82,15 @@ class MockStep: def reset(self) -> None: self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @dataclass -class MockStepWithoutOptionalMethods: +class MockStepWithoutOptionalMethods(ProcessorStep): """Mock step that only implements the required __call__ method.""" multiplier: float = 2.0 @@ -112,13 +106,15 @@ class MockStepWithoutOptionalMethods: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @dataclass -class MockStepWithTensorState: +class MockStepWithTensorState(ProcessorStep): """Mock step demonstrating mixed JSON attributes and tensor state.""" name: str = "tensor_step" @@ -168,14 +164,16 @@ class MockStepWithTensorState: self.running_mean.zero_() self.running_count.zero_() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features def test_empty_pipeline(): """Test pipeline with no steps.""" - pipeline = RobotProcessor() + pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition) transition = create_transition() result = pipeline(transition) @@ -187,7 +185,7 @@ def test_empty_pipeline(): def test_single_step_pipeline(): """Test pipeline with a single step.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step], to_transition=identity_transition, to_output=identity_transition) transition = create_transition() result = pipeline(transition) @@ -204,7 +202,9 @@ def test_multiple_steps_pipeline(): """Test pipeline with multiple steps.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline( + [step1, step2], to_transition=identity_transition, to_output=identity_transition + ) transition = create_transition() result = pipeline(transition) @@ -216,7 +216,7 @@ def test_multiple_steps_pipeline(): def test_invalid_transition_format(): """Test pipeline with invalid transition format.""" - pipeline = RobotProcessor([MockStep()]) + pipeline = DataProcessorPipeline([MockStep()]) # Test with wrong type (tuple instead of dict) with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): @@ -231,7 +231,7 @@ def test_step_through(): """Test step_through method with dict input.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) transition = create_transition() @@ -252,7 +252,7 @@ def test_step_through_with_dict(): """Test step_through method with dict input.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) batch = { "observation.image": None, @@ -291,7 +291,7 @@ def test_step_through_with_dict(): def test_step_through_no_hooks(): """Test that step_through doesn't execute hooks.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) hook_calls = [] @@ -326,7 +326,7 @@ def test_indexing(): """Test pipeline indexing.""" step1 = MockStep("step1") step2 = MockStep("step2") - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) # Test integer indexing assert pipeline[0] is step1 @@ -334,7 +334,7 @@ def test_indexing(): # Test slice indexing sub_pipeline = pipeline[0:1] - assert isinstance(sub_pipeline, RobotProcessor) + assert isinstance(sub_pipeline, DataProcessorPipeline) assert len(sub_pipeline) == 1 assert sub_pipeline[0] is step1 @@ -342,7 +342,7 @@ def test_indexing(): def test_hooks(): """Test before/after step hooks.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) before_calls = [] after_calls = [] @@ -366,7 +366,7 @@ def test_hooks(): def test_unregister_hooks(): """Test unregistering hooks from the pipeline.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Test before_step_hook before_calls = [] @@ -405,7 +405,7 @@ def test_unregister_hooks(): def test_unregister_nonexistent_hook(): """Test error handling when unregistering hooks that don't exist.""" - pipeline = RobotProcessor([MockStep()]) + pipeline = DataProcessorPipeline([MockStep()]) def some_hook(idx: int, transition: EnvTransition): pass @@ -423,7 +423,7 @@ def test_unregister_nonexistent_hook(): def test_multiple_hooks_and_selective_unregister(): """Test registering multiple hooks and selectively unregistering them.""" - pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) + pipeline = DataProcessorPipeline([MockStep("step1"), MockStep("step2")]) calls_1 = [] calls_2 = [] @@ -469,7 +469,7 @@ def test_multiple_hooks_and_selective_unregister(): def test_hook_execution_order_documentation(): """Test and document that hooks are executed sequentially in registration order.""" - pipeline = RobotProcessor([MockStep("step")]) + pipeline = DataProcessorPipeline([MockStep("step")]) execution_order = [] @@ -521,7 +521,7 @@ def test_save_and_load_pretrained(): step1.counter = 5 step2.counter = 10 - pipeline = RobotProcessor([step1, step2], name="TestPipeline") + pipeline = DataProcessorPipeline([step1, step2], name="TestPipeline") with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline @@ -543,7 +543,7 @@ def test_save_and_load_pretrained(): assert config["steps"][1]["config"]["counter"] == 10 # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="testpipeline.json") assert loaded_pipeline.name == "TestPipeline" assert len(loaded_pipeline) == 2 @@ -556,7 +556,9 @@ def test_save_and_load_pretrained(): def test_step_without_optional_methods(): """Test pipeline with steps that don't implement optional methods.""" step = MockStepWithoutOptionalMethods(multiplier=3.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline( + [step], to_transition=identity_transition, to_output=identity_transition + ) # Identity for EnvTransition input/output transition = create_transition(reward=2.0) result = pipeline(transition) @@ -569,14 +571,16 @@ def test_step_without_optional_methods(): # Save/load should work even without optional methods with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) assert len(loaded_pipeline) == 1 def test_mixed_json_and_tensor_state(): """Test step with both JSON attributes and tensor state.""" step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Process some transitions with rewards for i in range(10): @@ -592,13 +596,15 @@ def test_mixed_json_and_tensor_state(): pipeline.save_pretrained(tmp_dir) # Check that both config and state files were created - config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" - state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors" + config_path = Path(tmp_dir) / "dataprocessorpipeline.json" # Default name is "RobotProcessor" + state_path = Path(tmp_dir) / "dataprocessorpipeline_step_0.safetensors" assert config_path.exists() assert state_path.exists() # Load and verify - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) loaded_step = loaded_pipeline.steps[0] # Check JSON attributes were restored @@ -611,7 +617,7 @@ def test_mixed_json_and_tensor_state(): assert torch.allclose(loaded_step.running_mean, step.running_mean) -class MockModuleStep(nn.Module): +class MockModuleStep(ProcessorStep, nn.Module): """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" def __init__(self, input_dim: int = 10, hidden_dim: int = 5): @@ -651,23 +657,25 @@ class MockModuleStep(nn.Module): def state_dict(self) -> dict[str, torch.Tensor]: """Override to return all module parameters and buffers.""" # Get the module's state dict (includes all parameters and buffers) - return super().state_dict() + return nn.Module.state_dict(self) def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: """Override to load all module parameters and buffers.""" # Use the module's load_state_dict - super().load_state_dict(state) + nn.Module.load_state_dict(self, state) def reset(self) -> None: self.running_mean.zero_() self.counter = 0 - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features -class MockNonModuleStepWithState: +class MockNonModuleStepWithState(ProcessorStep): """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. This tests the state_dict/load_state_dict path for regular classes. @@ -744,14 +752,16 @@ class MockNonModuleStepWithState: self.step_count.zero_() self.history.clear() - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features # Tests for overrides functionality @dataclass -class MockStepWithNonSerializableParam: +class MockStepWithNonSerializableParam(ProcessorStep): """Mock step that requires a non-serializable parameter.""" def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): @@ -799,14 +809,16 @@ class MockStepWithNonSerializableParam: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @ProcessorStepRegistry.register("registered_mock_step") @dataclass -class RegisteredMockStep: +class RegisteredMockStep(ProcessorStep): """Mock step registered in the registry.""" value: int = 42 @@ -838,8 +850,10 @@ class RegisteredMockStep: def reset(self) -> None: pass - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features @@ -859,7 +873,7 @@ def test_from_pretrained_with_overrides(): env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) registered_step = RegisteredMockStep(value=100, device="cpu") - pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + pipeline = DataProcessorPipeline([env_step, registered_step], name="TestOverrides") with tempfile.TemporaryDirectory() as tmp_dir: # Save the pipeline @@ -877,7 +891,13 @@ def test_from_pretrained_with_overrides(): "registered_mock_step": {"device": "cuda", "value": 200}, } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="testoverrides.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Verify the pipeline was loaded correctly assert len(loaded_pipeline) == 2 @@ -903,7 +923,7 @@ def test_from_pretrained_with_partial_overrides(): step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -913,7 +933,13 @@ def test_from_pretrained_with_partial_overrides(): # The current implementation applies overrides to ALL steps with the same class name # Both steps will get the override - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) transition = create_transition(reward=1.0) result = loaded_pipeline(transition) @@ -927,7 +953,7 @@ def test_from_pretrained_with_partial_overrides(): def test_from_pretrained_invalid_override_key(): """Test that invalid override keys raise KeyError.""" step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -936,13 +962,15 @@ def test_from_pretrained_invalid_override_key(): overrides = {"NonExistentStep": {"param": "value"}} with pytest.raises(KeyError, match="Override keys.*do not match any step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_from_pretrained_multiple_invalid_override_keys(): """Test that multiple invalid override keys are reported.""" step = MockStepWithNonSerializableParam() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -951,7 +979,9 @@ def test_from_pretrained_multiple_invalid_override_keys(): overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) error_msg = str(exc_info.value) assert "NonExistentStep1" in error_msg @@ -962,7 +992,7 @@ def test_from_pretrained_multiple_invalid_override_keys(): def test_from_pretrained_registered_step_override(): """Test overriding registered steps using registry names.""" registered_step = RegisteredMockStep(value=50, device="cpu") - pipeline = RobotProcessor([registered_step]) + pipeline = DataProcessorPipeline([registered_step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -970,7 +1000,13 @@ def test_from_pretrained_registered_step_override(): # Override using registry name overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Test that overrides were applied transition = create_transition() @@ -986,7 +1022,7 @@ def test_from_pretrained_mixed_registered_and_unregistered(): unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) registered_step = RegisteredMockStep(value=10, device="cpu") - pipeline = RobotProcessor([unregistered_step, registered_step]) + pipeline = DataProcessorPipeline([unregistered_step, registered_step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -998,7 +1034,13 @@ def test_from_pretrained_mixed_registered_and_unregistered(): "registered_mock_step": {"value": 777}, } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides=overrides, + to_transition=identity_transition, + to_output=identity_transition, + ) # Test both steps transition = create_transition(reward=2.0) @@ -1013,13 +1055,18 @@ def test_from_pretrained_mixed_registered_and_unregistered(): def test_from_pretrained_no_overrides(): """Test that from_pretrained works without overrides (backward compatibility).""" step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load without overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) assert len(loaded_pipeline) == 1 @@ -1033,13 +1080,19 @@ def test_from_pretrained_no_overrides(): def test_from_pretrained_empty_overrides(): """Test that from_pretrained works with empty overrides dict.""" step = MockStepWithNonSerializableParam(multiplier=2.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load with empty overrides - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={}, + to_transition=identity_transition, + to_output=identity_transition, + ) assert len(loaded_pipeline) == 1 @@ -1053,7 +1106,7 @@ def test_from_pretrained_empty_overrides(): def test_from_pretrained_override_instantiation_error(): """Test that instantiation errors with overrides are properly reported.""" step = MockStepWithNonSerializableParam(multiplier=1.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1066,13 +1119,15 @@ def test_from_pretrained_override_instantiation_error(): } with pytest.raises(ValueError, match="Failed to instantiate processor step"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_from_pretrained_with_state_and_overrides(): """Test that overrides work correctly with steps that have tensor state.""" step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) # Process some data to create state for i in range(10): @@ -1090,7 +1145,9 @@ def test_from_pretrained_with_state_and_overrides(): } } - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) loaded_step = loaded_pipeline.steps[0] # Check that config overrides were applied @@ -1109,7 +1166,7 @@ def test_from_pretrained_override_error_messages(): """Test that error messages for override failures are helpful.""" step1 = MockStepWithNonSerializableParam(name="step1") step2 = RegisteredMockStep() - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1118,7 +1175,9 @@ def test_from_pretrained_override_error_messages(): overrides = {"WrongStepName": {"param": "value"}} with pytest.raises(KeyError) as exc_info: - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) error_msg = str(exc_info.value) assert "WrongStepName" in error_msg @@ -1129,20 +1188,20 @@ def test_from_pretrained_override_error_messages(): def test_repr_empty_processor(): """Test __repr__ with empty processor.""" - pipeline = RobotProcessor() + pipeline = DataProcessorPipeline() repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=0: [])" assert repr_str == expected def test_repr_single_step(): """Test __repr__ with single step.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=1: [MockStep])" assert repr_str == expected @@ -1150,18 +1209,18 @@ def test_repr_multiple_steps_under_limit(): """Test __repr__ with 2-3 steps (all shown).""" step1 = MockStep("step1") step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=2: [MockStep, MockStepWithoutOptionalMethods])" assert repr_str == expected # Test with 3 steps (boundary case) step3 = MockStepWithTensorState() - pipeline = RobotProcessor([step1, step2, step3]) + pipeline = DataProcessorPipeline([step1, step2, step3]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" assert repr_str == expected @@ -1173,30 +1232,30 @@ def test_repr_many_steps_truncated(): step4 = MockModuleStep() step5 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4, step5]) + pipeline = DataProcessorPipeline([step1, step2, step3, step4, step5]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" assert repr_str == expected def test_repr_with_custom_name(): """Test __repr__ with custom processor name.""" step = MockStep("test_step") - pipeline = RobotProcessor([step], name="CustomProcessor") + pipeline = DataProcessorPipeline([step], name="CustomProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='CustomProcessor', steps=1: [MockStep])" assert repr_str == expected def test_repr_with_seed(): """Test __repr__ with seed parameter.""" step = MockStep("test_step") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) repr_str = repr(pipeline) - expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='DataProcessorPipeline', steps=1: [MockStep])" assert repr_str == expected @@ -1204,20 +1263,22 @@ def test_repr_with_custom_name_and_seed(): """Test __repr__ with both custom name and seed.""" step1 = MockStep("step1") step2 = MockStepWithoutOptionalMethods() - pipeline = RobotProcessor([step1, step2], name="MyProcessor") + pipeline = DataProcessorPipeline([step1, step2], name="MyProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + expected = ( + "DataProcessorPipeline(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + ) assert repr_str == expected def test_repr_without_seed(): """Test __repr__ when seed is explicitly None (should not show seed).""" step = MockStep("test_step") - pipeline = RobotProcessor([step], name="TestProcessor") + pipeline = DataProcessorPipeline([step], name="TestProcessor") repr_str = repr(pipeline) - expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" + expected = "DataProcessorPipeline(name='TestProcessor', steps=1: [MockStep])" assert repr_str == expected @@ -1228,10 +1289,10 @@ def test_repr_various_step_types(): step3 = MockModuleStep() step4 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") + pipeline = DataProcessorPipeline([step1, step2, step3, step4], name="MixedSteps") repr_str = repr(pipeline) - expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" assert repr_str == expected @@ -1242,10 +1303,10 @@ def test_repr_edge_case_long_names(): step3 = MockStepWithTensorState() step4 = MockNonModuleStepWithState() - pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") + pipeline = DataProcessorPipeline([step1, step2, step3, step4], name="LongNames") repr_str = repr(pipeline) - expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + expected = "DataProcessorPipeline(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" assert repr_str == expected @@ -1253,7 +1314,7 @@ def test_repr_edge_case_long_names(): def test_save_with_custom_config_filename(): """Test saving processor with custom config filename.""" step = MockStep("test") - pipeline = RobotProcessor([step], name="TestProcessor") + pipeline = DataProcessorPipeline([step], name="TestProcessor") with tempfile.TemporaryDirectory() as tmp_dir: # Save with custom filename @@ -1269,16 +1330,18 @@ def test_save_with_custom_config_filename(): assert config["name"] == "TestProcessor" # Load with specific filename - loaded = RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="my_custom_config.json") assert loaded.name == "TestProcessor" def test_multiple_processors_same_directory(): """Test saving multiple processors to the same directory with different config files.""" # Create different processors - preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") + preprocessor = DataProcessorPipeline([MockStep("pre1"), MockStep("pre2")], name="preprocessor") - postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") + postprocessor = DataProcessorPipeline( + [MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor" + ) with tempfile.TemporaryDirectory() as tmp_dir: # Save both to same directory @@ -1290,8 +1353,8 @@ def test_multiple_processors_same_directory(): assert (Path(tmp_dir) / "postprocessor.json").exists() # Load them back - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + loaded_pre = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="postprocessor.json") assert loaded_pre.name == "preprocessor" assert loaded_post.name == "postprocessor" @@ -1299,31 +1362,34 @@ def test_multiple_processors_same_directory(): assert len(loaded_post) == 1 -def test_auto_detect_single_config(): - """Test automatic config detection when there's only one JSON file.""" +def test_explicit_config_filename_loading(): + """Test explicit config filename loading (no more auto-detection).""" step = MockStepWithTensorState() - pipeline = RobotProcessor([step], name="SingleConfig") + pipeline = DataProcessorPipeline([step], name="SingleConfig") with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) - # Load without specifying config_filename - loaded = RobotProcessor.from_pretrained(tmp_dir) + # Load with explicit config_filename (now required) + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="singleconfig.json") assert loaded.name == "SingleConfig" -def test_error_multiple_configs_no_filename(): - """Test error when multiple configs exist and no filename specified.""" - proc1 = RobotProcessor([MockStep()], name="processor1") - proc2 = RobotProcessor([MockStep()], name="processor2") +def test_explicit_config_selection_with_multiple_configs(): + """Test explicit config selection when multiple configs exist.""" + proc1 = DataProcessorPipeline([MockStep()], name="processor1") + proc2 = DataProcessorPipeline([MockStep()], name="processor2") with tempfile.TemporaryDirectory() as tmp_dir: proc1.save_pretrained(tmp_dir) proc2.save_pretrained(tmp_dir) - # Should raise error - with pytest.raises(ValueError, match="Multiple .json files found"): - RobotProcessor.from_pretrained(tmp_dir) + # Can load specific configs explicitly + loaded1 = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor1.json") + loaded2 = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor2.json") + + assert loaded1.name == "processor1" + assert loaded2.name == "processor2" def test_state_file_naming_with_indices(): @@ -1333,7 +1399,7 @@ def test_state_file_naming_with_indices(): step2 = MockStepWithTensorState(name="norm2", window_size=10) step3 = MockModuleStep(input_dim=5) - pipeline = RobotProcessor([step1, step2, step3]) + pipeline = DataProcessorPipeline([step1, step2, step3]) # Process some data to create state for i in range(5): @@ -1349,9 +1415,9 @@ def test_state_file_naming_with_indices(): # Files should be named with pipeline name prefix and indices expected_names = [ - "robotprocessor_step_0.safetensors", - "robotprocessor_step_1.safetensors", - "robotprocessor_step_2.safetensors", + "dataprocessorpipeline_step_0.safetensors", + "dataprocessorpipeline_step_1.safetensors", + "dataprocessorpipeline_step_2.safetensors", ] actual_names = [f.name for f in state_files] assert actual_names == expected_names @@ -1363,7 +1429,7 @@ def test_state_file_naming_with_registry(): # Register a test step @ProcessorStepRegistry.register("test_stateful_step") @dataclass - class TestStatefulStep: + class TestStatefulStep(ProcessorStep): value: int = 0 def __init__(self, value: int = 0): @@ -1382,15 +1448,17 @@ def test_state_file_naming_with_registry(): def load_state_dict(self, state): self.state_tensor = state["state_tensor"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: # Create pipeline with registered steps step1 = TestStatefulStep(1) step2 = TestStatefulStep(2) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1401,8 +1469,8 @@ def test_state_file_naming_with_registry(): # Should include pipeline name, index and registry name expected_names = [ - "robotprocessor_step_0_test_stateful_step.safetensors", - "robotprocessor_step_1_test_stateful_step.safetensors", + "dataprocessorpipeline_step_0_test_stateful_step.safetensors", + "dataprocessorpipeline_step_1_test_stateful_step.safetensors", ] actual_names = [f.name for f in state_files] assert actual_names == expected_names @@ -1418,7 +1486,7 @@ def test_override_with_nested_config(): @ProcessorStepRegistry.register("complex_config_step") @dataclass - class ComplexConfigStep: + class ComplexConfigStep(ProcessorStep): name: str = "complex" simple_param: int = 42 nested_config: dict = None @@ -1439,21 +1507,26 @@ def test_override_with_nested_config(): def get_config(self): return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = ComplexConfigStep() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load with nested override - loaded = RobotProcessor.from_pretrained( + loaded = DataProcessorPipeline.from_pretrained( tmp_dir, + config_filename="dataprocessorpipeline.json", overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}, + to_transition=identity_transition, + to_output=identity_transition, ) # Test that override worked @@ -1467,14 +1540,15 @@ def test_override_with_nested_config(): def test_override_preserves_defaults(): """Test that overrides only affect specified parameters.""" step = MockStepWithNonSerializableParam(name="test", multiplier=2.0) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override only one parameter - loaded = RobotProcessor.from_pretrained( + loaded = DataProcessorPipeline.from_pretrained( tmp_dir, + config_filename="dataprocessorpipeline.json", overrides={ "MockStepWithNonSerializableParam": { "multiplier": 5.0 # Only override multiplier @@ -1491,7 +1565,7 @@ def test_override_preserves_defaults(): def test_override_type_validation(): """Test that type errors in overrides are caught properly.""" step = MockStepWithTensorState(learning_rate=0.01) - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1504,7 +1578,9 @@ def test_override_type_validation(): } with pytest.raises(ValueError, match="Failed to instantiate"): - RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json", overrides=overrides + ) def test_override_with_callables(): @@ -1512,7 +1588,7 @@ def test_override_with_callables(): @ProcessorStepRegistry.register("callable_step") @dataclass - class CallableStep: + class CallableStep(ProcessorStep): name: str = "callable_step" transform_fn: Any = None @@ -1531,13 +1607,15 @@ def test_override_with_callables(): def get_config(self): return {"name": self.name} - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = CallableStep() - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1551,8 +1629,12 @@ def test_override_with_callables(): return x # Load with callable override - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"callable_step": {"transform_fn": double_values}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"callable_step": {"transform_fn": double_values}}, + to_transition=identity_transition, + to_output=identity_transition, ) # Test it works @@ -1567,14 +1649,16 @@ def test_override_multiple_same_class_warning(): """Test behavior when multiple steps of same class exist.""" step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) - pipeline = RobotProcessor([step1, step2]) + pipeline = DataProcessorPipeline([step1, step2]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override affects all instances of the class - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}}, ) # Both steps get the same override @@ -1589,7 +1673,7 @@ def test_override_multiple_same_class_warning(): def test_config_filename_special_characters(): """Test config filenames with special characters are sanitized.""" # Processor name with special characters - pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars") + pipeline = DataProcessorPipeline([MockStep()], name="My/Processor\\With:Special*Chars") with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) @@ -1607,10 +1691,10 @@ def test_state_file_naming_with_multiple_processors(): """Test that state files are properly prefixed with pipeline names to avoid conflicts.""" # Create two processors with state step1 = MockStepWithTensorState(name="norm", window_size=5) - preprocessor = RobotProcessor([step1], name="PreProcessor") + preprocessor = DataProcessorPipeline([step1], name="PreProcessor") step2 = MockStepWithTensorState(name="norm", window_size=10) - postprocessor = RobotProcessor([step2], name="PostProcessor") + postprocessor = DataProcessorPipeline([step2], name="PostProcessor") # Process some data to create state for i in range(3): @@ -1630,8 +1714,8 @@ def test_state_file_naming_with_multiple_processors(): assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists() # Load both back and verify they work correctly - loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") - loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + loaded_pre = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="postprocessor.json") assert loaded_pre.name == "PreProcessor" assert loaded_post.name == "PostProcessor" @@ -1644,7 +1728,7 @@ def test_override_with_device_strings(): @ProcessorStepRegistry.register("device_aware_step") @dataclass - class DeviceAwareStep: + class DeviceAwareStep(ProcessorStep): device: str = "cpu" def __init__(self, device: str = "cpu"): @@ -1663,21 +1747,25 @@ def test_override_with_device_strings(): def load_state_dict(self, state): self.buffer = state["buffer"] - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - # We do not test feature_contract here + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # We do not test features here return features try: step = DeviceAwareStep(device="cpu") - pipeline = RobotProcessor([step]) + pipeline = DataProcessorPipeline([step]) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Override device if torch.cuda.is_available(): - loaded = RobotProcessor.from_pretrained( - tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}} + loaded = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"device_aware_step": {"device": "cuda:0"}}, ) loaded_step = loaded.steps[0] @@ -1691,20 +1779,27 @@ def test_override_with_device_strings(): def test_from_pretrained_nonexistent_path(): """Test error handling when loading from non-existent sources.""" - from huggingface_hub.errors import HfHubHTTPError, HFValidationError + from huggingface_hub.errors import HfHubHTTPError - # Test with an invalid repo ID (too many slashes) - caught by HF validation - with pytest.raises(HFValidationError): - RobotProcessor.from_pretrained("/path/that/does/not/exist") + # Test with an invalid local path - should raise FileNotFoundError + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained("/path/that/does/not/exist", config_filename="processor.json") - # Test with a non-existent but valid Hub repo format + # Test with a path that doesn't exist as a directory + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained("user/repo/extra/path", config_filename="processor.json") + + # Test with a non-existent Hub repo with pytest.raises((FileNotFoundError, HfHubHTTPError)): - RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo") + DataProcessorPipeline.from_pretrained( + "nonexistent-user/nonexistent-repo", config_filename="processor.json" + ) # Test with a local directory that exists but has no config files with tempfile.TemporaryDirectory() as tmp_dir: - with pytest.raises(FileNotFoundError, match="No .json configuration files found"): - RobotProcessor.from_pretrained(tmp_dir) + # Since the directory exists but has no config, it will raise FileNotFoundError + with pytest.raises(FileNotFoundError): + DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json") def test_save_load_with_custom_converter_functions(): @@ -1733,13 +1828,15 @@ def test_save_load_with_custom_converter_functions(): } # Create processor with custom converters - pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output) + pipeline = DataProcessorPipeline( + [MockStep()], to_transition=custom_to_transition, to_output=custom_to_output + ) with tempfile.TemporaryDirectory() as tmp_dir: pipeline.save_pretrained(tmp_dir) # Load - should use default converters - loaded = RobotProcessor.from_pretrained(tmp_dir) + loaded = DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="dataprocessorpipeline.json") # Verify it uses default converters by checking with standard batch format batch = { @@ -1753,35 +1850,39 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) - assert "observation.image" in result # Standard format preserved + # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict + assert "observation.image" in result class NonCompliantStep: - """Intentionally non-compliant: missing feature_contract.""" + """Intentionally non-compliant: missing features.""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition -def test_construction_rejects_step_without_feature_contract(): - with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): - RobotProcessor([NonCompliantStep()]) - - -class NonCallableStep: +class NonCallableStep(ProcessorStep): """Intentionally non-compliant: missing __call__.""" - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return features def test_construction_rejects_step_without_call(): - with pytest.raises(TypeError, match=r"must define __call__"): - RobotProcessor([NonCallableStep()]) + """Test that DataProcessorPipeline rejects steps that don't inherit from ProcessorStep.""" + with pytest.raises( + TypeError, match=r"Can't instantiate abstract class NonCallableStep with abstract method __call_" + ): + DataProcessorPipeline([NonCallableStep()]) + + with pytest.raises(TypeError, match=r"must inherit from ProcessorStep"): + DataProcessorPipeline([NonCompliantStep()]) @dataclass -class FeatureContractAddStep: +class FeatureContractAddStep(ProcessorStep): """Adds a PolicyFeature""" key: str = "a" @@ -1790,39 +1891,47 @@ class FeatureContractAddStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.value + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION][self.key] = self.value return features @dataclass -class FeatureContractMutateStep: +class FeatureContractMutateStep(ProcessorStep): """Mutates a PolicyFeature""" key: str = "a" - fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 + fn: Callable[[PolicyFeature | None], PolicyFeature] = identity_transition # noqa: E731 def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features[self.key] = self.fn(features.get(self.key)) + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION][self.key] = self.fn( + features[PipelineFeatureType.OBSERVATION].get(self.key) + ) return features @dataclass -class FeatureContractBadReturnStep: +class FeatureContractBadReturnStep(ProcessorStep): """Returns a non-dict""" def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: return ["not-a-dict"] @dataclass -class FeatureContractRemoveStep: +class FeatureContractRemoveStep(ProcessorStep): """Removes a PolicyFeature""" key: str @@ -1830,32 +1939,39 @@ class FeatureContractRemoveStep: def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: - features.pop(self.key, None) + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + features[PipelineFeatureType.OBSERVATION].pop(self.key, None) return features -def test_feature_contract_orders_and_merges(policy_feature_factory): - p = RobotProcessor( +def test_features_orders_and_merges(policy_feature_factory): + p = DataProcessorPipeline( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), ] ) - out = p.feature_contract({}) - - assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) - assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) + out = p.transform_features({PipelineFeatureType.OBSERVATION: {}}) + assert out[PipelineFeatureType.OBSERVATION]["a"].type == FeatureType.STATE and out[ + PipelineFeatureType.OBSERVATION + ]["a"].shape == (3,) + assert out[PipelineFeatureType.OBSERVATION]["b"].type == FeatureType.ENV and out[ + PipelineFeatureType.OBSERVATION + ]["b"].shape == (2,) assert_contract_is_typed(out) -def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): +def test_features_respects_initial_without_mutation(policy_feature_factory): initial = { - "seed": policy_feature_factory(FeatureType.STATE, (7,)), - "nested": policy_feature_factory(FeatureType.ENV, (0,)), + PipelineFeatureType.OBSERVATION: { + "seed": policy_feature_factory(FeatureType.STATE, (7,)), + "nested": policy_feature_factory(FeatureType.ENV, (0,)), + } } - p = RobotProcessor( + p = DataProcessorPipeline( [ FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), FeatureContractMutateStep( @@ -1863,57 +1979,224 @@ def test_feature_contract_respects_initial_without_mutation(policy_feature_facto ), ] ) - out = p.feature_contract(initial_features=initial) + out = p.transform_features(initial_features=initial) - assert out["seed"].shape == (8,) - assert out["nested"].shape == (5,) + assert out[PipelineFeatureType.OBSERVATION]["seed"].shape == (8,) + assert out[PipelineFeatureType.OBSERVATION]["nested"].shape == (5,) # Initial dict must be preserved - assert initial["seed"].shape == (7,) - assert initial["nested"].shape == (0,) + assert initial[PipelineFeatureType.OBSERVATION]["seed"].shape == (7,) + assert initial[PipelineFeatureType.OBSERVATION]["nested"].shape == (0,) assert_contract_is_typed(out) -def test_feature_contract_type_error_on_bad_step(): - p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) - with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): - _ = p.feature_contract({}) - - -def test_feature_contract_execution_order_tracking(): - class Track: +def test_features_execution_order_tracking(): + class Track(ProcessorStep): def __init__(self, label): self.label = label def __call__(self, transition: EnvTransition) -> EnvTransition: return transition - def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: code = {"A": 1, "B": 2, "C": 3}[self.label] - pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) - features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) + pf = features[PipelineFeatureType.OBSERVATION].get( + "order", PolicyFeature(type=FeatureType.ENV, shape=()) + ) + features[PipelineFeatureType.OBSERVATION]["order"] = PolicyFeature( + type=pf.type, shape=pf.shape + (code,) + ) return features - out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) - assert out["order"].shape == (1, 2, 3) + out = DataProcessorPipeline([Track("A"), Track("B"), Track("C")]).transform_features( + initial_features={PipelineFeatureType.OBSERVATION: {}} + ) + assert out[PipelineFeatureType.OBSERVATION]["order"].shape == (1, 2, 3) -def test_feature_contract_remove_key(policy_feature_factory): - p = RobotProcessor( +def test_features_remove_key(policy_feature_factory): + p = DataProcessorPipeline( [ FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), FeatureContractRemoveStep("a"), ] ) - out = p.feature_contract({}) - assert "a" not in out + out = p.transform_features({PipelineFeatureType.OBSERVATION: {}}) + assert "a" not in out[PipelineFeatureType.OBSERVATION] -def test_feature_contract_remove_from_initial(policy_feature_factory): +def test_features_remove_from_initial(policy_feature_factory): initial = { - "keep": policy_feature_factory(FeatureType.STATE, (1,)), - "drop": policy_feature_factory(FeatureType.STATE, (1,)), + PipelineFeatureType.OBSERVATION: { + "keep": policy_feature_factory(FeatureType.STATE, (1,)), + "drop": policy_feature_factory(FeatureType.STATE, (1,)), + }, } - p = RobotProcessor([FeatureContractRemoveStep("drop")]) - out = p.feature_contract(initial_features=initial) - assert "drop" not in out and out["keep"] == initial["keep"] + p = DataProcessorPipeline([FeatureContractRemoveStep("drop")]) + out = p.transform_features(initial_features=initial) + assert ( + "drop" not in out[PipelineFeatureType.OBSERVATION] + and out[PipelineFeatureType.OBSERVATION]["keep"] == initial[PipelineFeatureType.OBSERVATION]["keep"] + ) + + +@dataclass +class AddActionEEAndJointFeatures(ProcessorStep): + """Adds both EE and JOINT action features.""" + + def __call__(self, tr): + return tr + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # EE features + features[PipelineFeatureType.ACTION]["action.ee.x"] = float + features[PipelineFeatureType.ACTION]["action.ee.y"] = float + # JOINT features + features[PipelineFeatureType.ACTION]["action.j1.pos"] = float + features[PipelineFeatureType.ACTION]["action.j2.pos"] = float + return features + + +@dataclass +class AddObservationStateFeatures(ProcessorStep): + """Adds state features (and optionally an image spec to test precedence).""" + + add_front_image: bool = False + front_image_shape: tuple = (240, 320, 3) + + def __call__(self, tr): + return tr + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + # State features (mix EE and a joint state) + features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float + features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float + if self.add_front_image: + features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape + return features + + +def test_aggregate_joint_action_only(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures()]) + initial = {PipelineFeatureType.OBSERVATION: {"front": (480, 640, 3)}, PipelineFeatureType.ACTION: {}} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features=initial, + use_videos=True, + patterns=["action.j1.pos", "action.j2.pos"], + ) + + # Expect only "action" with joint names + assert "action" in out and "observation.state" not in out + assert out["action"]["dtype"] == "float32" + assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} + assert out["action"]["shape"] == (len(out["action"]["names"]),) + + +def test_aggregate_ee_action_and_observation_with_videos(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures(), AddObservationStateFeatures()]) + initial = {"front": (480, 640, 3), "side": (720, 1280, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, + use_videos=True, + patterns=["action.ee", "observation.state"], + ) + + # Action should pack only EE names + assert "action" in out + assert set(out["action"]["names"]) == {"ee.x", "ee.y"} + assert out["action"]["dtype"] == "float32" + + # Observation state should pack both ee.x and j1.pos as a vector + assert "observation.state" in out + assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} + assert out["observation.state"]["dtype"] == "float32" + + # Cameras from initial_features appear as videos + for cam in ("front", "side"): + key = f"observation.images.{cam}" + assert key in out + assert out[key]["dtype"] == "video" + assert out[key]["shape"] == initial[cam] + assert out[key]["names"] == ["height", "width", "channels"] + + +def test_aggregate_both_action_types(): + rp = DataProcessorPipeline([AddActionEEAndJointFeatures()]) + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: {}}, + use_videos=True, + patterns=["action.ee", "action.j1", "action.j2.pos"], + ) + + assert "action" in out + expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} + assert set(out["action"]["names"]) == expected + assert out["action"]["shape"] == (len(expected),) + + +def test_aggregate_images_when_use_videos_false(): + rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, + use_videos=False, # expect "image" dtype + patterns=None, + ) + + key = "observation.images.back" + key_front = "observation.images.front" + assert key not in out + assert key_front not in out + + +def test_aggregate_images_when_use_videos_true(): + rp = DataProcessorPipeline([AddObservationStateFeatures(add_front_image=True)]) + initial = {"back": (480, 640, 3)} + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, + use_videos=True, + patterns=None, + ) + + key = "observation.images.front" + key_back = "observation.images.back" + assert key in out + assert key_back in out + assert out[key]["dtype"] == "video" + assert out[key_back]["dtype"] == "video" + assert out[key_back]["shape"] == initial["back"] + + +def test_initial_camera_not_overridden_by_step_image(): + # Step explicitly sets a different front image shape; initial has another shape. + # aggregate_pipeline_dataset_features should keep the step's value (setdefault behavior on initial cams). + rp = DataProcessorPipeline( + [AddObservationStateFeatures(add_front_image=True, front_image_shape=(240, 320, 3))] + ) + initial = {"front": (480, 640, 3)} # should NOT override the step-provided (240, 320, 3) + + out = aggregate_pipeline_dataset_features( + pipeline=rp, + initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, + use_videos=True, + patterns=["observation.images.front"], + ) + + key = "observation.images.front" + assert key in out + assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_pipeline_from_pretrained_helpers.py b/tests/processor/test_pipeline_from_pretrained_helpers.py new file mode 100644 index 000000000..89d45cbad --- /dev/null +++ b/tests/processor/test_pipeline_from_pretrained_helpers.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for DataProcessorPipeline.from_pretrained helper methods. + +These tests focus on the individual private methods that were extracted from +the main from_pretrained method to improve modularity and testability. +""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError + +# Simplified Config Loading Tests + + +def test_load_config_directory(): + """Test loading config from directory.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a config file + config_file = tmp_path / "processor.json" + test_config = {"name": "TestProcessor", "steps": []} + config_file.write_text(json.dumps(test_config)) + + # Load from directory + loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + assert loaded_config == test_config + assert base_path == tmp_path + + +def test_load_config_single_file(): + """Test loading config from a single file path.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create a config file + config_file = tmp_path / "processor.json" + test_config = {"name": "TestProcessor", "steps": []} + config_file.write_text(json.dumps(test_config)) + + # Load using file path directly + loaded_config, base_path = DataProcessorPipeline._load_config( + str(config_file), "any_filename_ignored", {} + ) + + assert loaded_config == test_config + assert base_path == tmp_path + + +def test_load_config_directory_file_not_found(): + """Test directory loading when config file doesn't exist.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Directory exists but no processor.json + with pytest.raises(FileNotFoundError, match="not found in directory"): + DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + +def test_load_config_directory_with_migration_detection(): + """Test that missing config triggers migration detection.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create old-style config to trigger migration + (tmp_path / "config.json").write_text(json.dumps({"type": "act"})) + + # Try to load processor.json (doesn't exist), should trigger migration + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + + +def test_load_config_nonexistent_path_tries_hub(): + """Test that nonexistent paths try Hub (simplified logic).""" + # This path doesn't exist locally, should try Hub + with pytest.raises(FileNotFoundError, match="on the HuggingFace Hub"): + DataProcessorPipeline._load_config("nonexistent/path", "processor.json", {}) + + +# Config Validation Tests + + +def test_validate_loaded_config_valid_config(): + """Test validation with valid processor config.""" + valid_config = {"name": "TestProcessor", "steps": []} + + # Should not raise any exception + DataProcessorPipeline._validate_loaded_config("any-path", valid_config, "processor.json") + + +def test_validate_loaded_config_invalid_config(): + """Test validation with invalid processor config.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create non-processor config to trigger migration + (tmp_path / "config.json").write_text(json.dumps({"type": "act"})) + + invalid_config = {"type": "act", "hidden_dim": 256} + + with pytest.raises(ProcessorMigrationError): + DataProcessorPipeline._validate_loaded_config(str(tmp_path), invalid_config, "config.json") + + +def test_validate_loaded_config_invalid_config_no_migration(): + """Test validation with invalid config when no migration is detected.""" + # Non-directory path (Hub repo) - no migration detection + invalid_config = {"type": "act", "hidden_dim": 256} + + with pytest.raises(ValueError, match="not a valid processor configuration"): + DataProcessorPipeline._validate_loaded_config("user/repo", invalid_config, "config.json") + + +# Step Class Resolution Tests + + +def test_resolve_step_class_registry_name(): + """Test resolution using registry name.""" + from lerobot.processor.pipeline import ProcessorStep, ProcessorStepRegistry + + # Register a test step + @ProcessorStepRegistry.register("test_step") + class TestStep(ProcessorStep): + def __call__(self, transition): + return transition + + def transform_features(self, features): + return features + + try: + step_entry = {"registry_name": "test_step"} + step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry) + + assert step_class is TestStep + assert step_key == "test_step" + finally: + ProcessorStepRegistry.unregister("test_step") + + +def test_resolve_step_class_registry_name_not_found(): + """Test resolution with non-existent registry name.""" + step_entry = {"registry_name": "nonexistent_step"} + + with pytest.raises(ImportError, match="Failed to load processor step from registry"): + DataProcessorPipeline._resolve_step_class(step_entry) + + +def test_resolve_step_class_import_path(): + """Test resolution using full import path.""" + # Use a valid existing class (this should work) + step_entry = {"class": "lerobot.processor.pipeline.ProcessorStep"} + + # This should succeed - ProcessorStep can be imported, just not instantiated + step_class, step_key = DataProcessorPipeline._resolve_step_class(step_entry) + + from lerobot.processor.pipeline import ProcessorStep + + assert step_class is ProcessorStep + assert step_key == "ProcessorStep" + + +def test_resolve_step_class_invalid_import_path(): + """Test resolution with invalid import path.""" + step_entry = {"class": "nonexistent.module.ClassName"} + + with pytest.raises(ImportError, match="Failed to load processor step"): + DataProcessorPipeline._resolve_step_class(step_entry) + + +# Override Validation Tests + + +def test_validate_overrides_used_all_used(): + """Test validation when all overrides are used.""" + # Empty set means all overrides were used + remaining_overrides = set() + config = {"steps": [{"class": "SomeStep"}]} + + # Should not raise + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + +def test_validate_overrides_used_some_unused(): + """Test validation when some overrides are unused.""" + remaining_overrides = {"NonExistentStep", "AnotherMissingStep"} + config = { + "steps": [ + {"registry_name": "normalize_step"}, + {"class": "some.module.TransformStep"}, + ] + } + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + +def test_validate_overrides_used_helpful_error_message(): + """Test that error message includes available step keys.""" + remaining_overrides = {"WrongStep"} + config = { + "steps": [ + {"registry_name": "correct_step"}, + {"class": "module.path.CorrectClass"}, + ] + } + + with pytest.raises(KeyError) as exc_info: + DataProcessorPipeline._validate_overrides_used(remaining_overrides, config) + + error_msg = str(exc_info.value) + assert "Available step keys" in error_msg + assert "correct_step" in error_msg + assert "CorrectClass" in error_msg + + +# Integration Tests for Simplified Logic + + +def test_simplified_three_way_loading(): + """Test that the simplified 3-way loading logic works correctly.""" + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Test 1: Directory loading + config_file = tmp_path / "processor.json" + test_config = {"name": "DirectoryTest", "steps": []} + config_file.write_text(json.dumps(test_config)) + + loaded_config, base_path = DataProcessorPipeline._load_config(str(tmp_path), "processor.json", {}) + assert loaded_config["name"] == "DirectoryTest" + assert base_path == tmp_path + + # Test 2: Single file loading + loaded_config, base_path = DataProcessorPipeline._load_config( + str(config_file), "ignored_filename", {} + ) + assert loaded_config["name"] == "DirectoryTest" + assert base_path == tmp_path diff --git a/tests/processor/test_policy_robot_bridge.py b/tests/processor/test_policy_robot_bridge.py new file mode 100644 index 000000000..f3bbd9a74 --- /dev/null +++ b/tests/processor/test_policy_robot_bridge.py @@ -0,0 +1,525 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType +from lerobot.processor import ( + DataProcessorPipeline, + PolicyActionToRobotActionProcessorStep, + ProcessorStepRegistry, + RobotActionToPolicyActionProcessorStep, +) +from lerobot.processor.converters import identity_transition +from tests.conftest import assert_contract_is_typed + + +def test_robot_to_policy_basic_action_conversion(): + """Test basic robot action to policy action conversion.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + } + + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + assert policy_action.shape == (3,) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) + + +def test_robot_to_policy_action_conversion_preserves_order(): + """Test that motor names order is preserved in conversion.""" + motor_names = ["gripper", "arm", "wrist"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "arm.pos": 10.0, + "gripper.pos": 5.0, + "wrist.pos": 15.0, + } + + policy_action = processor.action(robot_action) + + expected = torch.tensor([5.0, 10.0, 15.0]) + torch.testing.assert_close(policy_action, expected) + + +def test_robot_to_policy_action_conversion_with_floats_and_tensors(): + """Test conversion with mixed float and tensor values.""" + motor_names = ["joint1", "joint2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": torch.tensor(1.5), + "joint2.pos": 2.5, # Regular float + } + + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + torch.testing.assert_close(policy_action, torch.tensor([1.5, 2.5])) + + +def test_robot_to_policy_action_length_mismatch_error(): + """Test error when robot action length doesn't match motor names.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + # Too few actions + robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0} + + with pytest.raises(ValueError, match="Action must have 3 elements, got 2"): + processor.action(robot_action) + + robot_action = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + "extra.pos": 4.0, + } + + with pytest.raises(ValueError, match="Action must have 3 elements, got 4"): + processor.action(robot_action) + + +def test_robot_to_policy_missing_motor_key_error(): + """Test error when robot action is missing expected motor keys.""" + motor_names = ["joint1", "joint2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "joint1.pos": 1.0, + "wrong_key.pos": 2.0, + } + + with pytest.raises(KeyError): + processor.action(robot_action) + + +def test_robot_to_policy_transform_features(): + """Test feature transformation for robot to policy action processor.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + features = { + PipelineFeatureType.ACTION: { + "joint1.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "joint2.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "joint3.pos": {"type": FeatureType.ACTION, "shape": (1,)}, + "other_data": {"type": FeatureType.ENV, "shape": (1,)}, + } + } + + transformed = processor.transform_features(features) + + assert "action" in transformed[PipelineFeatureType.ACTION] + action_feature = transformed[PipelineFeatureType.ACTION]["action"] + assert action_feature.type == FeatureType.ACTION + assert action_feature.shape == (3,) + + assert "joint1.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint2.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint3.pos" in transformed[PipelineFeatureType.ACTION] + + assert "other_data" in transformed[PipelineFeatureType.ACTION] + + +def test_robot_to_policy_get_config(): + """Test configuration serialization.""" + motor_names = ["motor1", "motor2"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + config = processor.get_config() + assert config == {"motor_names": motor_names} + + +def test_robot_to_policy_state_dict(): + """Test state dict operations.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["joint1"]) + + state = processor.state_dict() + assert state == {} + + processor.load_state_dict({}) + + +def test_robot_to_policy_single_motor(): + """Test with single motor.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["single_joint"]) + + robot_action = {"single_joint.pos": 42.0} + policy_action = processor.action(robot_action) + + assert policy_action.shape == (1,) + torch.testing.assert_close(policy_action, torch.tensor([42.0])) + + +def test_policy_to_robot_basic_action_conversion(): + """Test basic policy action to robot action conversion.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([1.0, 2.0, 3.0]) + robot_action = processor.action(policy_action) + + assert isinstance(robot_action, dict) + assert len(robot_action) == 3 + + expected = { + "joint1.pos": 1.0, + "joint2.pos": 2.0, + "joint3.pos": 3.0, + } + + for key, expected_value in expected.items(): + assert key in robot_action + actual_value = robot_action[key] + if isinstance(actual_value, torch.Tensor): + actual_value = actual_value.item() + assert actual_value == pytest.approx(expected_value) + + +def test_policy_to_robot_action_conversion_preserves_order(): + """Test that motor names order corresponds to tensor indices.""" + motor_names = ["gripper", "arm", "wrist"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([5.0, 10.0, 15.0]) + robot_action = processor.action(policy_action) + + assert robot_action["gripper.pos"] == pytest.approx(5.0) + assert robot_action["arm.pos"] == pytest.approx(10.0) + assert robot_action["wrist.pos"] == pytest.approx(15.0) + + +def test_policy_to_robot_action_conversion_with_numpy_input(): + """Test conversion with numpy array input.""" + import numpy as np + + motor_names = ["joint1", "joint2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = np.array([1.5, 2.5]) + robot_action = processor.action(policy_action) + + assert robot_action["joint1.pos"] == pytest.approx(1.5) + assert robot_action["joint2.pos"] == pytest.approx(2.5) + + +def test_policy_to_robot_action_length_mismatch_error(): + """Test error when policy action length doesn't match motor names.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + policy_action = torch.tensor([1.0, 2.0]) + + with pytest.raises(ValueError, match="Action must have 3 elements, got 2"): + processor.action(policy_action) + + policy_action = torch.tensor([1.0, 2.0, 3.0, 4.0]) + + with pytest.raises(ValueError, match="Action must have 3 elements, got 4"): + processor.action(policy_action) + + +def test_policy_to_robot_transform_features(): + """Test feature transformation for policy to robot action processor.""" + motor_names = ["joint1", "joint2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + features = { + PipelineFeatureType.ACTION: { + "action": {"type": FeatureType.ACTION, "shape": (2,)}, + "other_data": {"type": FeatureType.ENV, "shape": (1,)}, + } + } + + transformed = processor.transform_features(features) + + assert "joint1.pos" in transformed[PipelineFeatureType.ACTION] + assert "joint2.pos" in transformed[PipelineFeatureType.ACTION] + + for motor in motor_names: + motor_feature = transformed[PipelineFeatureType.ACTION][f"{motor}.pos"] + assert motor_feature.type == FeatureType.ACTION + assert motor_feature.shape == (1,) + + assert "action" in transformed[PipelineFeatureType.ACTION] + + assert "other_data" in transformed[PipelineFeatureType.ACTION] + + +def test_policy_to_robot_get_config(): + """Test configuration serialization.""" + motor_names = ["motor1", "motor2"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + config = processor.get_config() + assert config == {"motor_names": motor_names} + + +def test_policy_to_robot_state_dict(): + """Test state dict operations.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["joint1"]) + + state = processor.state_dict() + assert state == {} + + processor.load_state_dict({}) + + +def test_policy_to_robot_single_motor(): + """Test with single motor.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["single_joint"]) + + policy_action = torch.tensor([42.0]) + robot_action = processor.action(policy_action) + + assert len(robot_action) == 1 + assert robot_action["single_joint.pos"] == pytest.approx(42.0) + + +def test_robot_to_policy_registry(): + """Test RobotActionToPolicyActionProcessorStep registry.""" + assert "robot_action_to_policy_action_processor" in ProcessorStepRegistry.list() + + retrieved_class = ProcessorStepRegistry.get("robot_action_to_policy_action_processor") + assert retrieved_class is RobotActionToPolicyActionProcessorStep + + instance = retrieved_class(motor_names=["test"]) + assert isinstance(instance, RobotActionToPolicyActionProcessorStep) + assert instance.motor_names == ["test"] + + +def test_policy_to_robot_registry(): + """Test PolicyActionToRobotActionProcessorStep registry.""" + assert "policy_action_to_robot_action_processor" in ProcessorStepRegistry.list() + + retrieved_class = ProcessorStepRegistry.get("policy_action_to_robot_action_processor") + assert retrieved_class is PolicyActionToRobotActionProcessorStep + + instance = retrieved_class(motor_names=["test"]) + assert isinstance(instance, PolicyActionToRobotActionProcessorStep) + assert instance.motor_names == ["test"] + + +def test_save_and_load_robot_to_policy(): + """Test saving and loading RobotActionToPolicyActionProcessorStep.""" + motor_names = ["joint1", "joint2", "joint3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + pipeline = DataProcessorPipeline([processor], name="TestRobotToPolicy") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check config file exists + config_path = Path(tmp_dir) / "testrobottopolicy.json" + assert config_path.exists() + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + "testrobottopolicy.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert loaded_pipeline.name == "TestRobotToPolicy" + assert len(loaded_pipeline) == 1 + + # Check loaded processor + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RobotActionToPolicyActionProcessorStep) + assert loaded_processor.motor_names == motor_names + + # Test functionality after loading + robot_action = {"joint1.pos": 1.0, "joint2.pos": 2.0, "joint3.pos": 3.0} + policy_action = loaded_processor.action(robot_action) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) + + +def test_save_and_load_policy_to_robot(): + """Test saving and loading PolicyActionToRobotActionProcessorStep.""" + motor_names = ["motor_a", "motor_b"] + processor = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + pipeline = DataProcessorPipeline([processor], name="TestPolicyToRobot") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Load pipeline + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + "testpolicytorobot.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, PolicyActionToRobotActionProcessorStep) + assert loaded_processor.motor_names == motor_names + + policy_action = torch.tensor([10.0, 20.0]) + robot_action = loaded_processor.action(policy_action) + assert robot_action["motor_a.pos"] == pytest.approx(10.0) + assert robot_action["motor_b.pos"] == pytest.approx(20.0) + + +# Integration and chaining tests + + +def test_round_trip_conversion(): + """Test that robot->policy->robot conversion preserves values.""" + motor_names = ["joint1", "joint2", "joint3"] + robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + original_robot_action = { + "joint1.pos": 1.5, + "joint2.pos": -2.3, + "joint3.pos": 0.7, + } + + policy_action = robot_to_policy.action(original_robot_action) + final_robot_action = policy_to_robot.action(policy_action) + + for key in original_robot_action: + original_val = original_robot_action[key] + final_val = final_robot_action[key] + if isinstance(final_val, torch.Tensor): + final_val = final_val.item() + assert final_val == pytest.approx(original_val, abs=1e-6) + + +def test_chained_processors_in_pipeline(): + """Test both processors chained in a pipeline.""" + motor_names = ["joint1", "joint2"] + robot_to_policy = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + policy_to_robot = PolicyActionToRobotActionProcessorStep(motor_names=motor_names) + + pipeline = DataProcessorPipeline( + [robot_to_policy, policy_to_robot], + to_transition=identity_transition, + to_output=identity_transition, + ) + + assert len(pipeline.steps) == 2 + assert isinstance(pipeline.steps[0], RobotActionToPolicyActionProcessorStep) + assert isinstance(pipeline.steps[1], PolicyActionToRobotActionProcessorStep) + + +def test_robot_to_policy_features_contract(policy_feature_factory): + """Test feature transformation maintains proper typing contract.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=["j1", "j2"]) + features = { + PipelineFeatureType.ACTION: { + "j1.pos": policy_feature_factory(FeatureType.ACTION, (1,)), + "j2.pos": policy_feature_factory(FeatureType.ACTION, (1,)), + "other": policy_feature_factory(FeatureType.ENV, (3,)), + } + } + + out = processor.transform_features(features.copy()) + + assert_contract_is_typed(out) + + assert "action" in out[PipelineFeatureType.ACTION] + action_feature = out[PipelineFeatureType.ACTION]["action"] + assert action_feature.type == FeatureType.ACTION + assert action_feature.shape == (2,) + + +def test_policy_to_robot_features_contract(policy_feature_factory): + """Test feature transformation maintains proper typing contract.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"]) + features = { + PipelineFeatureType.ACTION: { + "action": policy_feature_factory(FeatureType.ACTION, (3,)), + "other": policy_feature_factory(FeatureType.ENV, (1,)), + } + } + + out = processor.transform_features(features.copy()) + + assert_contract_is_typed(out) + + for motor in ["m1", "m2", "m3"]: + key = f"{motor}.pos" + assert key in out[PipelineFeatureType.ACTION] + motor_feature = out[PipelineFeatureType.ACTION][key] + assert motor_feature.type == FeatureType.ACTION + assert motor_feature.shape == (1,) + + +def test_empty_motor_names_list(): + """Test behavior with empty motor names list.""" + processor = RobotActionToPolicyActionProcessorStep(motor_names=[]) + + robot_action = {} + policy_action = processor.action(robot_action) + + assert isinstance(policy_action, torch.Tensor) + assert policy_action.shape == (0,) + + +def test_empty_motor_names_list_policy_to_robot(): + """Test PolicyActionToRobotActionProcessorStep with empty motor names.""" + processor = PolicyActionToRobotActionProcessorStep(motor_names=[]) + + policy_action = torch.tensor([]) + robot_action = processor.action(policy_action) + + assert isinstance(robot_action, dict) + assert len(robot_action) == 0 + + +def test_very_long_motor_names(): + """Test with many motor names.""" + motor_names = [f"joint_{i}" for i in range(100)] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = {f"joint_{i}.pos": float(i) for i in range(100)} + policy_action = processor.action(robot_action) + + assert policy_action.shape == (100,) + expected = torch.tensor([float(i) for i in range(100)]) + torch.testing.assert_close(policy_action, expected) + + +def test_special_characters_in_motor_names(): + """Test with special characters in motor names.""" + motor_names = ["motor-1", "motor_2", "motor.3"] + processor = RobotActionToPolicyActionProcessorStep(motor_names=motor_names) + + robot_action = { + "motor-1.pos": 1.0, + "motor_2.pos": 2.0, + "motor.3.pos": 3.0, + } + + policy_action = processor.action(robot_action) + torch.testing.assert_close(policy_action, torch.tensor([1.0, 2.0, 3.0])) diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 229d57f9f..5f2b48576 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -19,33 +19,25 @@ from pathlib import Path import numpy as np import torch -from lerobot.configs.types import FeatureType -from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from lerobot.configs.types import FeatureType, PipelineFeatureType +from lerobot.processor import ( + DataProcessorPipeline, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TransitionKey, +) +from lerobot.processor.converters import create_transition, identity_transition +from lerobot.processor.rename_processor import rename_stats from tests.conftest import assert_contract_is_typed -def create_transition( - observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None -): - """Helper to create an EnvTransition dictionary.""" - return { - TransitionKey.OBSERVATION: observation, - TransitionKey.ACTION: action, - TransitionKey.REWARD: reward, - TransitionKey.DONE: done, - TransitionKey.TRUNCATED: truncated, - TransitionKey.INFO: info, - TransitionKey.COMPLEMENTARY_DATA: complementary_data, - } - - def test_basic_renaming(): """Test basic key renaming functionality.""" rename_map = { "old_key1": "new_key1", "old_key2": "new_key2", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "old_key1": torch.tensor([1.0, 2.0]), @@ -73,7 +65,7 @@ def test_basic_renaming(): def test_empty_rename_map(): """Test processor with empty rename map (should pass through unchanged).""" - processor = RenameProcessor(rename_map={}) + processor = RenameObservationsProcessorStep(rename_map={}) observation = { "key1": torch.tensor([1.0]), @@ -92,9 +84,9 @@ def test_empty_rename_map(): def test_none_observation(): """Test processor with None observation.""" - processor = RenameProcessor(rename_map={"old": "new"}) + processor = RenameObservationsProcessorStep(rename_map={"old": "new"}) - transition = create_transition() + transition = create_transition(observation={}) result = processor(transition) # Should return transition unchanged @@ -107,7 +99,7 @@ def test_overlapping_rename(): "a": "b", "b": "c", # This creates a potential conflict } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "a": 1, @@ -132,7 +124,7 @@ def test_partial_rename(): "observation.state": "observation.proprio_state", "pixels": "observation.image", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "observation.state": torch.randn(10), @@ -162,15 +154,15 @@ def test_get_config(): "old1": "new1", "old2": "new2", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) config = processor.get_config() assert config == {"rename_map": rename_map} def test_state_dict(): - """Test state dict (should be empty for RenameProcessor).""" - processor = RenameProcessor(rename_map={"old": "new"}) + """Test state dict (should be empty for RenameProcessorStep).""" + processor = RenameObservationsProcessorStep(rename_map={"old": "new"}) state = processor.state_dict() assert state == {} @@ -185,9 +177,11 @@ def test_integration_with_robot_processor(): "agent_pos": "observation.state", "pixels": "observation.image", } - rename_processor = RenameProcessor(rename_map=rename_map) + rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) - pipeline = RobotProcessor([rename_processor]) + pipeline = DataProcessorPipeline( + [rename_processor], to_transition=identity_transition, to_output=identity_transition + ) observation = { "agent_pos": np.array([1.0, 2.0, 3.0]), @@ -219,30 +213,37 @@ def test_save_and_load_pretrained(): "old_state": "observation.state", "old_image": "observation.image", } - processor = RenameProcessor(rename_map=rename_map) - pipeline = RobotProcessor([processor], name="TestRenameProcessor") + processor = RenameObservationsProcessorStep(rename_map=rename_map) + pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") with tempfile.TemporaryDirectory() as tmp_dir: # Save pipeline pipeline.save_pretrained(tmp_dir) # Check files were created - config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor" + config_path = ( + Path(tmp_dir) / "testrenameprocessorstep.json" + ) # Based on name="TestRenameProcessorStep" assert config_path.exists() - # No state files should be created for RenameProcessor + # No state files should be created for RenameProcessorStep state_files = list(Path(tmp_dir).glob("*.safetensors")) assert len(state_files) == 0 # Load pipeline - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, + config_filename="testrenameprocessorstep.json", + to_transition=identity_transition, + to_output=identity_transition, + ) - assert loaded_pipeline.name == "TestRenameProcessor" + assert loaded_pipeline.name == "TestRenameProcessorStep" assert len(loaded_pipeline) == 1 # Check that loaded processor works correctly loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) + assert isinstance(loaded_processor, RenameObservationsProcessorStep) assert loaded_processor.rename_map == rename_map # Test functionality after loading @@ -259,24 +260,26 @@ def test_save_and_load_pretrained(): def test_registry_functionality(): - """Test that RenameProcessor is properly registered.""" + """Test that RenameProcessorStep is properly registered.""" # Check that it's registered - assert "rename_processor" in ProcessorStepRegistry.list() + assert "rename_observations_processor" in ProcessorStepRegistry.list() # Get from registry - retrieved_class = ProcessorStepRegistry.get("rename_processor") - assert retrieved_class is RenameProcessor + retrieved_class = ProcessorStepRegistry.get("rename_observations_processor") + assert retrieved_class is RenameObservationsProcessorStep # Create instance from registry instance = retrieved_class(rename_map={"old": "new"}) - assert isinstance(instance, RenameProcessor) + assert isinstance(instance, RenameObservationsProcessorStep) assert instance.rename_map == {"old": "new"} def test_registry_based_save_load(): """Test save/load using registry name instead of module path.""" - processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) - pipeline = RobotProcessor([processor]) + processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"}) + pipeline = DataProcessorPipeline( + [processor], to_transition=identity_transition, to_output=identity_transition + ) with tempfile.TemporaryDirectory() as tmp_dir: # Save and load @@ -285,24 +288,26 @@ def test_registry_based_save_load(): # Verify config uses registry name import json - with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor" + with open(Path(tmp_dir) / "dataprocessorpipeline.json") as f: # Default name is "RobotProcessor" config = json.load(f) assert "registry_name" in config["steps"][0] - assert config["steps"][0]["registry_name"] == "rename_processor" + assert config["steps"][0]["registry_name"] == "rename_observations_processor" assert "class" not in config["steps"][0] # Should use registry, not module path # Load should work - loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_pipeline = DataProcessorPipeline.from_pretrained( + tmp_dir, config_filename="dataprocessorpipeline.json" + ) loaded_processor = loaded_pipeline.steps[0] - assert isinstance(loaded_processor, RenameProcessor) + assert isinstance(loaded_processor, RenameObservationsProcessorStep) assert loaded_processor.rename_map == {"key1": "renamed_key1"} def test_chained_rename_processors(): - """Test multiple RenameProcessors in a pipeline.""" + """Test multiple RenameProcessorSteps in a pipeline.""" # First processor: rename raw keys to intermediate format - processor1 = RenameProcessor( + processor1 = RenameObservationsProcessorStep( rename_map={ "pos": "agent_position", "img": "camera_image", @@ -310,14 +315,16 @@ def test_chained_rename_processors(): ) # Second processor: rename to final format - processor2 = RenameProcessor( + processor2 = RenameObservationsProcessorStep( rename_map={ "agent_position": "observation.state", "camera_image": "observation.image", } ) - pipeline = RobotProcessor([processor1, processor2]) + pipeline = DataProcessorPipeline( + [processor1, processor2], to_transition=identity_transition, to_output=identity_transition + ) observation = { "pos": np.array([1.0, 2.0]), @@ -353,7 +360,7 @@ def test_nested_observation_rename(): "observation.images.right": "observation.camera.right_view", "observation.proprio": "observation.proprioception", } - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { "observation.images.left": torch.randn(3, 64, 64), @@ -383,7 +390,7 @@ def test_nested_observation_rename(): def test_value_types_preserved(): """Test that various value types are preserved during renaming.""" rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} - processor = RenameProcessor(rename_map=rename_map) + processor = RenameObservationsProcessorStep(rename_map=rename_map) tensor_value = torch.randn(3, 3) array_value = np.random.rand(2, 2) @@ -410,58 +417,87 @@ def test_value_types_preserved(): assert processed_obs["old_list"] == [1, 2, 3] -def test_feature_contract_basic_renaming(policy_feature_factory): - processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) +def test_features_basic_renaming(policy_feature_factory): + processor = RenameObservationsProcessorStep(rename_map={"a": "x", "b": "y"}) features = { - "a": policy_feature_factory(FeatureType.STATE, (2,)), - "b": policy_feature_factory(FeatureType.ACTION, (3,)), - "c": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "a": policy_feature_factory(FeatureType.VISUAL, (2,)), + "b": policy_feature_factory(FeatureType.VISUAL, (3,)), + "c": policy_feature_factory(FeatureType.VISUAL, (1,)), + }, } - out = processor.feature_contract(features.copy()) + out = processor.transform_features(features.copy()) # Values preserved and typed - assert out["x"] == features["a"] - assert out["y"] == features["b"] - assert out["c"] == features["c"] + assert out[PipelineFeatureType.OBSERVATION]["x"] == features[PipelineFeatureType.OBSERVATION]["a"] + assert out[PipelineFeatureType.OBSERVATION]["y"] == features[PipelineFeatureType.OBSERVATION]["b"] + assert out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["c"] assert_contract_is_typed(out) # Input not mutated - assert set(features) == {"a", "b", "c"} + assert set(features[PipelineFeatureType.OBSERVATION]) == {"a", "b", "c"} -def test_feature_contract_overlapping_keys(policy_feature_factory): +def test_features_overlapping_keys(policy_feature_factory): # Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c' - processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) + processor = RenameObservationsProcessorStep(rename_map={"a": "b", "b": "c"}) features = { - "a": policy_feature_factory(FeatureType.STATE, (1,)), - "b": policy_feature_factory(FeatureType.STATE, (2,)), + PipelineFeatureType.OBSERVATION: { + "a": policy_feature_factory(FeatureType.VISUAL, (1,)), + "b": policy_feature_factory(FeatureType.VISUAL, (2,)), + }, } - out = processor.feature_contract(features) + out = processor.transform_features(features) - assert set(out) == {"b", "c"} - assert out["b"] == features["a"] # 'a' renamed to'b' - assert out["c"] == features["b"] # 'b' renamed to 'c' + assert set(out[PipelineFeatureType.OBSERVATION]) == {"b", "c"} + assert ( + out[PipelineFeatureType.OBSERVATION]["b"] == features[PipelineFeatureType.OBSERVATION]["a"] + ) # 'a' renamed to'b' + assert ( + out[PipelineFeatureType.OBSERVATION]["c"] == features[PipelineFeatureType.OBSERVATION]["b"] + ) # 'b' renamed to 'c' assert_contract_is_typed(out) -def test_feature_contract_chained_processors(policy_feature_factory): +def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level - processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) - processor2 = RenameProcessor( + processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) + processor2 = RenameObservationsProcessorStep( rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} ) - pipeline = RobotProcessor([processor1, processor2]) + pipeline = DataProcessorPipeline([processor1, processor2]) spec = { - "pos": policy_feature_factory(FeatureType.STATE, (7,)), - "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), - "extra": policy_feature_factory(FeatureType.ENV, (1,)), + PipelineFeatureType.OBSERVATION: { + "pos": policy_feature_factory(FeatureType.VISUAL, (7,)), + "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "extra": policy_feature_factory(FeatureType.VISUAL, (1,)), + }, } - out = pipeline.feature_contract(initial_features=spec) + out = pipeline.transform_features(initial_features=spec) - assert set(out) == {"observation.state", "observation.image", "extra"} - assert out["observation.state"] == spec["pos"] - assert out["observation.image"] == spec["img"] - assert out["extra"] == spec["extra"] + assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} + assert ( + out[PipelineFeatureType.OBSERVATION]["observation.state"] + == spec[PipelineFeatureType.OBSERVATION]["pos"] + ) + assert ( + out[PipelineFeatureType.OBSERVATION]["observation.image"] + == spec[PipelineFeatureType.OBSERVATION]["img"] + ) + assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert_contract_is_typed(out) + + +def test_rename_stats_basic(): + orig = { + "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + "action": {"mean": np.array([0.0])}, + } + mapping = {"observation.state": "observation.robot_state"} + renamed = rename_stats(orig, mapping) + assert "observation.robot_state" in renamed and "observation.state" not in renamed + # Ensure deep copy: mutate original and verify renamed unaffected + orig["observation.state"]["mean"][0] = 42.0 + assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py new file mode 100644 index 000000000..7cbcb1882 --- /dev/null +++ b/tests/processor/test_sac_processor.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for SAC policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default SAC configuration for testing.""" + config = SACConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, + ACTION: {"min": torch.full((5,), -1.0), "max": torch.ones(5)}, + } + + +def test_make_sac_processor_basic(): + """Test basic creation of SAC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_sac_processor_normalization_modes(): + """Test that SAC processor correctly handles different normalization modes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization + action = torch.rand(5) * 2 - 1 # Range [-1, 1] + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is normalized and batched + # State should be mean-std normalized + # Action should be min-max normalized to [-1, 1] + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized (but still batched) + assert postprocessed.shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_cuda(): + """Test SAC processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_accelerate_scenario(): + """Test SAC processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = {OBS_STATE: torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_sac_processor_multi_gpu(): + """Test SAC processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = {OBS_STATE: torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_sac_processor_without_stats(): + """Test SAC processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_sac_processor_save_and_load(): + """Test saving and loading SAC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = {OBS_STATE: torch.randn(10)} + action = torch.randn(5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_mixed_precision(): + """Test SAC processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} + action = torch.randn(5, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_sac_processor_batch_data(): + """Test SAC processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Test with batched data + batch_size = 32 + observation = {OBS_STATE: torch.randn(batch_size, 10)} + action = torch.randn(batch_size, 5) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through preprocessor + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 10) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 5) + + +def test_sac_processor_edge_cases(): + """Test SAC processor with edge cases.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_sac_pre_post_processors( + config, + stats, + ) + + # Test with observation that has no state key but still exists + observation = {"observation.dummy": torch.randn(1)} # Some dummy observation to pass validation + action = torch.randn(5) + batch = {TransitionKey.ACTION.value: action, **observation} + processed = preprocessor(batch) + # observation.state wasn't in original, so it won't be in processed + assert OBS_STATE not in processed + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + # Test with zero action (representing "null" action) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5)) + batch = transition_to_batch(transition) + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 10) + # Action should be present and batched, even if it's zeros + assert processed[TransitionKey.ACTION.value].shape == (1, 5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_sac_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_sac_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + norm_step = step # Now type checker knows this is NormalizerProcessorStep + modified_steps.append( + NormalizerProcessorStep( + features=norm_step.features, + norm_map=norm_step.norm_map, + stats=norm_step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data + observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} # Start with float32 + action = torch.randn(5, dtype=torch.float32) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py new file mode 100644 index 000000000..ce162c10d --- /dev/null +++ b/tests/processor/test_smolvla_processor.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for SmolVLA policy processor.""" + +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.policies.smolvla.processor_smolvla import ( + SmolVLANewLineProcessor, + make_smolvla_pre_post_processors, +) +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + EnvTransition, + NormalizerProcessorStep, + ProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +class MockTokenizerProcessorStep(ProcessorStep): + """Mock tokenizer processor step for testing.""" + + def __init__(self, *args, **kwargs): + # Accept any arguments to mimic the real TokenizerProcessorStep interface + pass + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Pass through transition unchanged + return transition + + def transform_features(self, features): + # Pass through features unchanged + return features + + +def create_default_config(): + """Create a default SmolVLA configuration for testing.""" + config = SmolVLAConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + config.vlm_model_name = "HuggingFaceTB/SmolVLM-Instruct" + config.pad_language_to = "max_length" + config.tokenizer_max_length = 100 + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)}, + } + + +def test_make_smolvla_processor_basic(): + """Test basic creation of SmolVLA processor.""" + config = create_default_config() + stats = create_default_stats() + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 6 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor) + # Step 3 would be TokenizerProcessorStep but it's mocked + assert isinstance(preprocessor.steps[4], DeviceProcessorStep) + assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_smolvla_newline_processor_single_task(): + """Test SmolVLANewLineProcessor with single task string.""" + processor = SmolVLANewLineProcessor() + + # Test with task that doesn't have newline + transition = create_transition(complementary_data={"task": "test task"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + # Test with task that already has newline + transition = create_transition(complementary_data={"task": "test task\n"}) + result = processor(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" + + +def test_smolvla_newline_processor_list_of_tasks(): + """Test SmolVLANewLineProcessor with list of task strings.""" + processor = SmolVLANewLineProcessor() + + # Test with list of tasks + tasks = ["task1", "task2\n", "task3"] + transition = create_transition(complementary_data={"task": tasks}) + result = processor(transition) + expected = ["task1\n", "task2\n", "task3\n"] + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected + + +def test_smolvla_newline_processor_empty_transition(): + """Test SmolVLANewLineProcessor with empty transition.""" + processor = SmolVLANewLineProcessor() + + # Test with no complementary_data + transition = create_transition() + result = processor(transition) + assert result == transition + + # Test with complementary_data but no task + transition = create_transition(complementary_data={"other": "data"}) + result = processor(transition) + assert result == transition + + # Test with None task + transition = create_transition(complementary_data={"task": None}) + result = processor(transition) + assert result == transition + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_cuda(): + """Test SmolVLA processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action, complementary_data={"task": "test task"}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_accelerate_scenario(): + """Test SmolVLA processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_smolvla_processor_multi_gpu(): + """Test SmolVLA processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + # Mock the tokenizer processor to act as pass-through + class MockTokenizerProcessorStep(ProcessorStep): + def __init__(self, *args, **kwargs): + pass + + def __call__(self, transition): + return transition + + def state_dict(self): + return {} + + def load_state_dict(self, state): + pass + + def reset(self): + pass + + def get_config(self): + return {"tokenizer_name": "HuggingFaceTB/SmolVLM-Instruct"} + + def transform_features(self, features): + return features + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_smolvla_processor_without_stats(): + """Test SmolVLA processor creation without dataset statistics.""" + config = create_default_config() + + # Mock the tokenizer processor + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, postprocessor = make_smolvla_pre_post_processors( + config, + dataset_stats=None, + ) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + +def test_smolvla_newline_processor_state_dict(): + """Test SmolVLANewLineProcessor state dict methods.""" + processor = SmolVLANewLineProcessor() + + # Test state_dict (should be empty) + state = processor.state_dict() + assert state == {} + + # Test load_state_dict (should do nothing) + processor.load_state_dict({}) + + # Test reset (should do nothing) + processor.reset() + + # Test get_config + config = processor.get_config() + assert config == {} + + +def test_smolvla_newline_processor_transform_features(): + """Test SmolVLANewLineProcessor transform_features method.""" + processor = SmolVLANewLineProcessor() + + # Test transform_features + features = { + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + } + result = processor.transform_features(features) + assert result == features # Should return unchanged + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_smolvla_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + with patch( + "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep + ): + preprocessor, _ = make_smolvla_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5) + normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition( + observation, action, complementary_data={"task": "test bfloat16 adaptation"} + ) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py new file mode 100644 index 000000000..20979fd6d --- /dev/null +++ b/tests/processor/test_tdmpc_processor.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for TDMPC policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default TDMPC configuration for testing.""" + config = TDMPCConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(12,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(12), "std": torch.ones(12)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, + } + + +def test_make_tdmpc_processor_basic(): + """Test basic creation of TDMPC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_tdmpc_processor_normalization(): + """Test that TDMPC processor correctly normalizes and unnormalizes data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Create test data + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is processed and batched + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + # Process action through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is unnormalized (but still batched) + assert postprocessed.shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_cuda(): + """Test TDMPC processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_accelerate_scenario(): + """Test TDMPC processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(12).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_tdmpc_processor_multi_gpu(): + """Test TDMPC processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(12).to(device), + OBS_IMAGE: torch.randn(3, 224, 224).to(device), + } + action = torch.randn(6).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_tdmpc_processor_without_stats(): + """Test TDMPC processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_tdmpc_processor_save_and_load(): + """Test saving and loading TDMPC processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(12), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 6) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_mixed_precision(): + """Test TDMPC processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(12, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_tdmpc_processor_batch_data(): + """Test TDMPC processor with batched data.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Test with batched data + batch_size = 64 + observation = { + OBS_STATE: torch.randn(batch_size, 12), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 12) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 6) + + +def test_tdmpc_processor_edge_cases(): + """Test TDMPC processor with edge cases.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Test with only state observation (no image) + observation = {OBS_STATE: torch.randn(12)} + action = torch.randn(6) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 12) + assert OBS_IMAGE not in processed + + # Test with only image observation (no state) + observation = {OBS_IMAGE: torch.randn(3, 224, 224)} + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert OBS_STATE not in processed + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_tdmpc_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_tdmpc_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(12, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(6, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py new file mode 100644 index 000000000..b3b0c9bfc --- /dev/null +++ b/tests/processor/test_tokenizer_processor.py @@ -0,0 +1,1029 @@ +""" +Tests for the TokenizerProcessorStep class. +""" + +import tempfile +from unittest.mock import patch + +import pytest +import torch + +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature +from lerobot.constants import OBS_LANGUAGE +from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey +from lerobot.processor.converters import create_transition, identity_transition +from tests.utils import require_package + + +class MockTokenizer: + """Mock tokenizer for testing that mimics transformers tokenizer interface.""" + + def __init__(self, vocab_size: int = 1000): + self.vocab_size = vocab_size + + def __call__( + self, + text: str | list[str], + max_length: int = 512, + truncation: bool = True, + padding: str = "max_length", + padding_side: str = "right", + return_tensors: str = "pt", + **kwargs, + ) -> dict[str, torch.Tensor]: + """Mock tokenization that returns deterministic tokens based on text.""" + if isinstance(text, str): + texts = [text] + else: + texts = text + + batch_size = len(texts) + + # Create mock input_ids and attention_mask + input_ids = torch.zeros(batch_size, max_length, dtype=torch.long) + attention_mask = torch.zeros(batch_size, max_length, dtype=torch.long) + + for i, txt in enumerate(texts): + # Simple mock: use hash of text to generate deterministic tokens + text_hash = hash(txt) % self.vocab_size + seq_len = min(len(txt.split()), max_length) + + # Fill input_ids with simple pattern based on text + for j in range(seq_len): + input_ids[i, j] = (text_hash + j) % self.vocab_size + + # Set attention mask for non-padded positions + attention_mask[i, :seq_len] = 1 + + result = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + # Return single sequence for single input to match transformers behavior + if len(texts) == 1: + result = {k: v.squeeze(0) for k, v in result.items()} + + return result + + +@pytest.fixture +def mock_tokenizer(): + """Provide a mock tokenizer for testing.""" + return MockTokenizer(vocab_size=100) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_basic_tokenization(mock_auto_tokenizer): + """Test basic string tokenization functionality.""" + # Mock AutoTokenizer.from_pretrained to return our mock tokenizer + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +@require_package("transformers") +def test_basic_tokenization_with_tokenizer_object(): + """Test basic string tokenization functionality using tokenizer object directly.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "pick up the red cube" + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check token structure + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert isinstance(tokens, torch.Tensor) + assert isinstance(attention_mask, torch.Tensor) + assert tokens.shape == (10,) + assert attention_mask.shape == (10,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_list_of_strings_tokenization(mock_auto_tokenizer): + """Test tokenization of a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ["pick up cube", "place on table"]}, + ) + + result = processor(transition) + + # Check that original task is preserved + assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == ["pick up cube", "place on table"] + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert attention_mask.shape == (2, 8) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_keys(mock_auto_tokenizer): + """Test using custom task_key.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "move forward"}, + ) + + result = processor(transition) + + # Check that tokens are stored in observation regardless of task_key + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (5,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_complementary_data(mock_auto_tokenizer): + """Test handling of None complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data=None) + + # create_transition converts None complementary_data to empty dict, so task key is missing + with pytest.raises(KeyError, match="task"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_missing_task_key(mock_auto_tokenizer): + """Test handling when task key is missing.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data={"other_field": "some value"}) + + with pytest.raises(KeyError, match="task"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_none_task_value(mock_auto_tokenizer): + """Test handling when task value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition(observation={}, complementary_data={"task": None}) + + with pytest.raises(ValueError, match="Task extracted from Complementary data is None"): + processor(transition) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_unsupported_task_type(mock_auto_tokenizer): + """Test handling of unsupported task types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Test with integer task - get_task returns None, observation raises ValueError + transition = create_transition(observation={}, complementary_data={"task": 123}) + + with pytest.raises(ValueError, match="Task cannot be None"): + processor(transition) + + # Test with mixed list - get_task returns None, observation raises ValueError + transition = create_transition(observation={}, complementary_data={"task": ["text", 123, "more text"]}) + + with pytest.raises(ValueError, match="Task cannot be None"): + processor(transition) + + +@require_package("transformers") +def test_no_tokenizer_error(): + """Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided.""" + with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"): + TokenizerProcessorStep() + + +@require_package("transformers") +def test_invalid_tokenizer_name_error(): + """Test that error is raised when invalid tokenizer_name is provided.""" + with patch("lerobot.processor.tokenizer_processor.AutoTokenizer") as mock_auto_tokenizer: + # Mock import error + mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found") + + with pytest.raises(Exception, match="Model not found"): + TokenizerProcessorStep(tokenizer_name="invalid-tokenizer") + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_get_config_with_tokenizer_name(mock_auto_tokenizer): + """Test configuration serialization when using tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + expected = { + "tokenizer_name": "test-tokenizer", + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + + +@require_package("transformers") +def test_get_config_with_tokenizer_object(): + """Test configuration serialization when using tokenizer object.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + processor = TokenizerProcessorStep( + tokenizer=mock_tokenizer, + max_length=256, + task_key="instruction", + padding="longest", + truncation=False, + ) + + config = processor.get_config() + + # tokenizer_name should not be in config when tokenizer object is used + expected = { + "max_length": 256, + "task_key": "instruction", + "padding_side": "right", + "padding": "longest", + "truncation": False, + } + + assert config == expected + assert "tokenizer_name" not in config + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_state_dict_methods(mock_auto_tokenizer): + """Test state_dict and load_state_dict methods.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Should return empty dict + state = processor.state_dict() + assert state == {} + + # load_state_dict should not raise error + processor.load_state_dict({}) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_reset_method(mock_auto_tokenizer): + """Test reset method.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + # Should not raise error + processor.reset() + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_integration_with_robot_processor(mock_auto_tokenizer): + """Test integration with RobotProcessor.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = DataProcessorPipeline( + [tokenizer_processor], to_transition=identity_transition, to_output=identity_transition + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and tokenization was applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (6,) + assert attention_mask.shape == (6,) + + # Check that other data is preserved + assert torch.equal( + result[TransitionKey.OBSERVATION]["state"], transition[TransitionKey.OBSERVATION]["state"] + ) + assert torch.equal(result[TransitionKey.ACTION], transition[TransitionKey.ACTION]) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer): + """Test saving and loading processor with tokenizer_name.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + original_processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=32, task_key="instruction" + ) + + robot_processor = DataProcessorPipeline( + [original_processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor - tokenizer will be recreated from saved config + loaded_processor = DataProcessorPipeline.from_pretrained( + temp_dir, + config_filename="dataprocessorpipeline.json", + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test that loaded processor works + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +@require_package("transformers") +def test_save_and_load_pretrained_with_tokenizer_object(): + """Test saving and loading processor with tokenizer object using overrides.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + + original_processor = TokenizerProcessorStep( + tokenizer=mock_tokenizer, max_length=32, task_key="instruction" + ) + + robot_processor = DataProcessorPipeline( + [original_processor], to_transition=identity_transition, to_output=identity_transition + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Save processor + robot_processor.save_pretrained(temp_dir) + + # Load processor with tokenizer override (since tokenizer object wasn't saved) + loaded_processor = DataProcessorPipeline.from_pretrained( + temp_dir, + config_filename="dataprocessorpipeline.json", + overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}, + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Test that loaded processor works + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"instruction": "test instruction"}, + ) + + result = loaded_processor(transition) + assert TransitionKey.OBSERVATION in result + assert f"{OBS_LANGUAGE}.tokens" in result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in result[TransitionKey.OBSERVATION] + + +@require_package("transformers") +def test_registry_functionality(): + """Test that the processor is properly registered.""" + from lerobot.processor import ProcessorStepRegistry + + # Check that the processor is registered + assert "tokenizer_processor" in ProcessorStepRegistry.list() + + # Check that we can retrieve it + retrieved_class = ProcessorStepRegistry.get("tokenizer_processor") + assert retrieved_class is TokenizerProcessorStep + + +@require_package("transformers") +def test_features_basic(): + """Test basic feature contract functionality.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) + + input_features = { + PipelineFeatureType.OBSERVATION: { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) + }, + PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + } + + output_features = processor.transform_features(input_features) + + # Check that original features are preserved + assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] + assert "action" in output_features[PipelineFeatureType.ACTION] + + # Check that tokenized features are added + assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION] + + # Check feature properties + tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][ + f"{OBS_LANGUAGE}.attention_mask" + ] + + assert tokens_feature.type == FeatureType.LANGUAGE + assert tokens_feature.shape == (128,) + assert attention_mask_feature.type == FeatureType.LANGUAGE + assert attention_mask_feature.shape == (128,) + + +@require_package("transformers") +def test_features_with_custom_max_length(): + """Test feature contract with custom max_length.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64) + + input_features = {PipelineFeatureType.OBSERVATION: {}} + output_features = processor.transform_features(input_features) + + # Check that features use correct max_length + assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] + assert f"{OBS_LANGUAGE}.attention_mask" in output_features[PipelineFeatureType.OBSERVATION] + + tokens_feature = output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask_feature = output_features[PipelineFeatureType.OBSERVATION][ + f"{OBS_LANGUAGE}.attention_mask" + ] + + assert tokens_feature.shape == (64,) + assert attention_mask_feature.shape == (64,) + + +@require_package("transformers") +def test_features_existing_features(): + """Test feature contract when tokenized features already exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256) + + input_features = { + PipelineFeatureType.OBSERVATION: { + f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + f"{OBS_LANGUAGE}.attention_mask": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)), + } + } + + output_features = processor.transform_features(input_features) + + # Should not overwrite existing features + assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.tokens"].shape == ( + 100, + ) # Original shape preserved + assert output_features[PipelineFeatureType.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"].shape == (100,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_tokenization_parameters(mock_auto_tokenizer): + """Test that tokenization parameters are correctly passed to tokenizer.""" + + # Create a custom mock that tracks calls + class TrackingMockTokenizer: + def __init__(self): + self.last_call_args = None + self.last_call_kwargs = None + + def __call__(self, *args, **kwargs): + self.last_call_args = args + self.last_call_kwargs = kwargs + # Return minimal valid output + return { + "input_ids": torch.zeros(16, dtype=torch.long), + "attention_mask": torch.ones(16, dtype=torch.long), + } + + tracking_tokenizer = TrackingMockTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + processor = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", + max_length=16, + padding="longest", + truncation=False, + padding_side="left", + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + + processor(transition) + + # Check that parameters were passed correctly (task is converted to list) + assert tracking_tokenizer.last_call_args == (["test task"],) + assert tracking_tokenizer.last_call_kwargs["max_length"] == 16 + assert tracking_tokenizer.last_call_kwargs["padding"] == "longest" + assert tracking_tokenizer.last_call_kwargs["padding_side"] == "left" + assert tracking_tokenizer.last_call_kwargs["truncation"] is False + assert tracking_tokenizer.last_call_kwargs["return_tensors"] == "pt" + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_preserves_other_complementary_data(mock_auto_tokenizer): + """Test that other complementary data fields are preserved.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer") + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={ + "task": "test task", + "episode_id": 123, + "timestamp": 456.789, + "other_field": {"nested": "data"}, + }, + ) + + result = processor(transition) + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + + # Check that all original fields are preserved + assert comp_data["task"] == "test task" + assert comp_data["episode_id"] == 123 + assert comp_data["timestamp"] == 456.789 + assert comp_data["other_field"] == {"nested": "data"} + + # Check that tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_deterministic_tokenization(mock_auto_tokenizer): + """Test that tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "consistent test"}, + ) + + result1 = processor(transition) + result2 = processor(transition) + + tokens1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask1 = result1[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + tokens2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask2 = result2[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + # Results should be identical + assert torch.equal(tokens1, tokens2) + assert torch.equal(attention_mask1, attention_mask2) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_empty_string_task(mock_auto_tokenizer): + """Test handling of empty string task.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": ""}, + ) + + result = processor(transition) + + # Should still tokenize (mock tokenizer handles empty strings) + observation = result[TransitionKey.OBSERVATION] + assert f"{OBS_LANGUAGE}.tokens" in observation + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + assert tokens.shape == (8,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_very_long_task(mock_auto_tokenizer): + """Test handling of very long task strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=5, truncation=True) + + long_task = " ".join(["word"] * 100) # Very long task + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": long_task}, + ) + + result = processor(transition) + + # Should be truncated to max_length + observation = result[TransitionKey.OBSERVATION] + tokens = observation[f"{OBS_LANGUAGE}.tokens"] + attention_mask = observation[f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.shape == (5,) + assert attention_mask.shape == (5,) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_custom_padding_side(mock_auto_tokenizer): + """Test using custom padding_side parameter.""" + + # Create a mock tokenizer that tracks padding_side calls + class PaddingSideTrackingTokenizer: + def __init__(self): + self.padding_side_calls = [] + + def __call__( + self, + text, + max_length=512, + truncation=True, + padding="max_length", + padding_side="right", + return_tensors="pt", + **kwargs, + ): + self.padding_side_calls.append(padding_side) + # Return minimal valid output + return { + "input_ids": torch.zeros(max_length, dtype=torch.long), + "attention_mask": torch.ones(max_length, dtype=torch.long), + } + + tracking_tokenizer = PaddingSideTrackingTokenizer() + mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer + + # Test left padding + processor_left = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=10, padding_side="left" + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "test task"}, + ) + processor_left(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "left" + + # Test right padding (default) + processor_right = TokenizerProcessorStep( + tokenizer_name="test-tokenizer", max_length=10, padding_side="right" + ) + + processor_right(transition) + + assert tracking_tokenizer.padding_side_calls[-1] == "right" + + +@require_package("transformers") +def test_device_detection_cpu(): + """Test that tokenized tensors stay on CPU when other tensors are on CPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CPU tensors + observation = {"observation.state": torch.randn(10)} # CPU tensor + action = torch.randn(5) # CPU tensor + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "test task"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on CPU + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cpu" + assert attention_mask.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_device_detection_cuda(): + """Test that tokenized tensors are moved to CUDA when other tensors are on CUDA.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CUDA tensors + observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor + action = torch.randn(5).cuda() # CUDA tensor + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "test task"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on CUDA + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + assert tokens.device.index == 0 # Should be on same device as input + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +@require_package("transformers") +def test_device_detection_multi_gpu(): + """Test that tokenized tensors match device in multi-GPU setup.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Test with tensors on cuda:1 + device = torch.device("cuda:1") + observation = {"observation.state": torch.randn(10).to(device)} + action = torch.randn(5).to(device) + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "multi gpu test"} + ) + + result = processor(transition) + + # Check that tokenized tensors are on cuda:1 + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device == device + assert attention_mask.device == device + + +@require_package("transformers") +def test_device_detection_no_tensors(): + """Test that tokenized tensors stay on CPU when no other tensors exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with no tensors + transition = create_transition( + observation={"metadata": {"key": "value"}}, # No tensors + complementary_data={"task": "no tensor test"}, + ) + + result = processor(transition) + + # Check that tokenized tensors are on CPU (default) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cpu" + assert attention_mask.device.type == "cpu" + + +@require_package("transformers") +def test_device_detection_mixed_devices(): + """Test device detection when tensors are on different devices (uses first found).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + if torch.cuda.is_available(): + # Create transition with mixed devices + observation = { + "observation.cpu": torch.randn(10), # CPU + "observation.cuda": torch.randn(10).cuda(), # CUDA + } + transition = create_transition( + observation=observation, complementary_data={"task": "mixed device test"} + ) + + result = processor(transition) + + # The device detection should use the first tensor found + # (iteration order depends on dict, but result should be consistent) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + # Both should be on the same device + assert tokens.device == attention_mask.device + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_device_detection_from_action(): + """Test that device is detected from action tensor when no observation tensors exist.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with action on CUDA but no observation tensors + observation = {"metadata": {"key": "value"}} # No tensors in observation + action = torch.randn(5).cuda() + transition = create_transition( + observation=observation, action=action, complementary_data={"task": "action device test"} + ) + + result = processor(transition) + + # Check that tokenized tensors match action's device + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + + +@require_package("transformers") +def test_device_detection_preserves_dtype(): + """Test that device detection doesn't affect dtype of tokenized tensors.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with float tensor (to test dtype isn't affected) + observation = {"observation.state": torch.randn(10, dtype=torch.float16)} + transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) + + result = processor(transition) + + # Check that tokenized tensors have correct dtypes (not affected by input dtype) + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.dtype == torch.long # Should remain long + assert attention_mask.dtype == torch.bool # Should be bool (converted in processor) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_integration_with_device_processor(mock_auto_tokenizer): + """Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline.""" + from lerobot.processor import DeviceProcessorStep + + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + # Create pipeline with TokenizerProcessorStep then DeviceProcessorStep + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + device_processor = DeviceProcessorStep(device="cuda:0") + robot_processor = DataProcessorPipeline( + [tokenizer_processor, device_processor], + to_transition=identity_transition, + to_output=identity_transition, + ) + + # Start with CPU tensors + transition = create_transition( + observation={"observation.state": torch.randn(10)}, # CPU + action=torch.randn(5), # CPU + complementary_data={"task": "pipeline test"}, + ) + + result = robot_processor(transition) + + # All tensors should end up on CUDA (moved by DeviceProcessorStep) + assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.ACTION].device.type == "cuda" + + # Tokenized tensors should also be on CUDA + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + assert tokens.device.type == "cuda" + assert attention_mask.device.type == "cuda" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_simulated_accelerate_scenario(): + """Test scenario simulating Accelerate with data already on GPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Simulate Accelerate scenario: batch already on GPU + device = torch.device("cuda:0") + observation = { + "observation.state": torch.randn(1, 10).to(device), # Batched, on GPU + "observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU + } + action = torch.randn(1, 5).to(device) # Batched, on GPU + + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": ["accelerate test"]}, # List for batched task + ) + + result = processor(transition) + + # Tokenized tensors should match GPU placement + tokens = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.tokens"] + attention_mask = result[TransitionKey.OBSERVATION][f"{OBS_LANGUAGE}.attention_mask"] + + assert tokens.device == device + assert attention_mask.device == device + # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) + assert tokens.shape == (10,) # MockTokenizer behavior for single string in list + assert attention_mask.shape == (10,) diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py new file mode 100644 index 000000000..98e05eae8 --- /dev/null +++ b/tests/processor/test_vqbet_processor.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for VQBeT policy processor.""" + +import tempfile + +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DataProcessorPipeline, + DeviceProcessorStep, + NormalizerProcessorStep, + RenameObservationsProcessorStep, + TransitionKey, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import create_transition, transition_to_batch + + +def create_default_config(): + """Create a default VQBeT configuration for testing.""" + config = VQBeTConfig() + config.input_features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + config.normalization_mapping = { + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.VISUAL: NormalizationMode.IDENTITY, + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + config.device = "cpu" + return config + + +def create_default_stats(): + """Create default dataset statistics for testing.""" + return { + OBS_STATE: {"mean": torch.zeros(8), "std": torch.ones(8)}, + OBS_IMAGE: {}, # No normalization for images + ACTION: {"min": torch.full((7,), -1.0), "max": torch.ones(7)}, + } + + +def test_make_vqbet_processor_basic(): + """Test basic creation of VQBeT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Check processor names + assert preprocessor.name == "policy_preprocessor" + assert postprocessor.name == "policy_postprocessor" + + # Check steps in preprocessor + assert len(preprocessor.steps) == 4 + assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) + assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) + assert isinstance(preprocessor.steps[2], DeviceProcessorStep) + assert isinstance(preprocessor.steps[3], NormalizerProcessorStep) + + # Check steps in postprocessor + assert len(postprocessor.steps) == 2 + assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) + assert isinstance(postprocessor.steps[1], DeviceProcessorStep) + + +def test_vqbet_processor_with_images(): + """Test VQBeT processor with image and state observations.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Create test data with images and states + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is batched + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_cuda(): + """Test VQBeT processor with CUDA device.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Create CPU data + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is on CUDA + assert processed[OBS_STATE].device.type == "cuda" + assert processed[OBS_IMAGE].device.type == "cuda" + assert processed[TransitionKey.ACTION.value].device.type == "cuda" + + # Process through postprocessor + postprocessed = postprocessor(processed[TransitionKey.ACTION.value]) + + # Check that action is back on CPU + assert postprocessed.device.type == "cpu" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_accelerate_scenario(): + """Test VQBeT processor in simulated Accelerate scenario.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Simulate Accelerate: data already on GPU and batched + device = torch.device("cuda:0") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on same GPU + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") +def test_vqbet_processor_multi_gpu(): + """Test VQBeT processor with multi-GPU setup.""" + config = create_default_config() + config.device = "cuda:0" + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Simulate data on different GPU + device = torch.device("cuda:1") + observation = { + OBS_STATE: torch.randn(1, 8).to(device), + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), + } + action = torch.randn(1, 7).to(device) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data stays on cuda:1 + assert processed[OBS_STATE].device == device + assert processed[OBS_IMAGE].device == device + assert processed[TransitionKey.ACTION.value].device == device + + +def test_vqbet_processor_without_stats(): + """Test VQBeT processor creation without dataset statistics.""" + config = create_default_config() + + preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None) + + # Should still create processors + assert preprocessor is not None + assert postprocessor is not None + + # Process should still work + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + assert processed is not None + + +def test_vqbet_processor_save_and_load(): + """Test saving and loading VQBeT processor.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Save preprocessor + preprocessor.save_pretrained(tmpdir) + + # Load preprocessor + loaded_preprocessor = DataProcessorPipeline.from_pretrained( + tmpdir, config_filename="policy_preprocessor.json" + ) + + # Test that loaded processor works + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + processed = loaded_preprocessor(batch) + assert processed[OBS_STATE].shape == (1, 8) + assert processed[OBS_IMAGE].shape == (1, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_mixed_precision(): + """Test VQBeT processor with mixed precision.""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + # Create processor + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Replace DeviceProcessorStep with one that uses float16 + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) + elif isinstance(step, NormalizerProcessorStep): + # Update normalizer to use the same device as the device processor + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float16, # Match the float16 dtype + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Create test data + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that data is converted to float16 + assert processed[OBS_STATE].dtype == torch.float16 + assert processed[OBS_IMAGE].dtype == torch.float16 + assert processed[TransitionKey.ACTION.value].dtype == torch.float16 + + +def test_vqbet_processor_large_batch(): + """Test VQBeT processor with large batch sizes.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Test with large batch + batch_size = 128 + observation = { + OBS_STATE: torch.randn(batch_size, 8), + OBS_IMAGE: torch.randn(batch_size, 3, 224, 224), + } + action = torch.randn(batch_size, 7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through preprocessor + + processed = preprocessor(batch) + + # Check that batch dimension is preserved + assert processed[OBS_STATE].shape == (batch_size, 8) + assert processed[OBS_IMAGE].shape == (batch_size, 3, 224, 224) + assert processed[TransitionKey.ACTION.value].shape == (batch_size, 7) + + +def test_vqbet_processor_sequential_processing(): + """Test VQBeT processor with sequential data processing.""" + config = create_default_config() + stats = create_default_stats() + + preprocessor, postprocessor = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Process multiple samples sequentially + results = [] + for _ in range(5): + observation = { + OBS_STATE: torch.randn(8), + OBS_IMAGE: torch.randn(3, 224, 224), + } + action = torch.randn(7) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + processed = preprocessor(batch) + results.append(processed) + + # Check that all results are consistent + for result in results: + assert result[OBS_STATE].shape == (1, 8) + assert result[OBS_IMAGE].shape == (1, 3, 224, 224) + assert result[TransitionKey.ACTION.value].shape == (1, 7) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_vqbet_processor_bfloat16_device_float32_normalizer(): + """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" + config = create_default_config() + config.device = "cuda" + stats = create_default_stats() + + preprocessor, _ = make_vqbet_pre_post_processors( + config, + stats, + ) + + # Modify the pipeline to use bfloat16 device processor with float32 normalizer + modified_steps = [] + for step in preprocessor.steps: + if isinstance(step, DeviceProcessorStep): + # Device processor converts to bfloat16 + modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) + elif isinstance(step, NormalizerProcessorStep): + # Normalizer stays configured as float32 (will auto-adapt to bfloat16) + modified_steps.append( + NormalizerProcessorStep( + features=step.features, + norm_map=step.norm_map, + stats=step.stats, + device=config.device, + dtype=torch.float32, # Deliberately configured as float32 + ) + ) + else: + modified_steps.append(step) + preprocessor.steps = modified_steps + + # Verify initial normalizer configuration + normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep + assert normalizer_step.dtype == torch.float32 + + # Create test data with both state and visual observations + observation = { + OBS_STATE: torch.randn(8, dtype=torch.float32), + OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), + } + action = torch.randn(7, dtype=torch.float32) + transition = create_transition(observation, action) + + batch = transition_to_batch(transition) + + # Process through full pipeline + processed = preprocessor(batch) + + # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 + assert processed[OBS_STATE].dtype == torch.bfloat16 + assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion + assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 + + # Verify normalizer automatically adapted its internal state + assert normalizer_step.dtype == torch.bfloat16 + # Check state stats (has normalization) + for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): + assert stat_tensor.dtype == torch.bfloat16 + # OBS_IMAGE uses IDENTITY normalization, so no stats to check diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py new file mode 100644 index 000000000..29b7bf70a --- /dev/null +++ b/tests/utils/test_visualization_utils.py @@ -0,0 +1,209 @@ +import importlib +import sys +from types import SimpleNamespace + +import numpy as np +import pytest + +from lerobot.processor import TransitionKey + + +@pytest.fixture +def mock_rerun(monkeypatch): + """ + Provide a mock `rerun` module so tests don't depend on the real library. + Also reload the module-under-test so it binds to this mock `rr`. + """ + calls = [] + + class DummyScalar: + def __init__(self, value): + self.value = float(value) + + class DummyImage: + def __init__(self, arr): + self.arr = arr + + def dummy_log(key, obj, **kwargs): + calls.append((key, obj, kwargs)) + + dummy_rr = SimpleNamespace( + Scalar=DummyScalar, + Image=DummyImage, + log=dummy_log, + init=lambda *a, **k: None, + spawn=lambda *a, **k: None, + ) + + # Inject fake module into sys.modules + monkeypatch.setitem(sys.modules, "rerun", dummy_rr) + + # Now import and reload the module under test, to bind to our rerun mock + import lerobot.utils.visualization_utils as vu + + importlib.reload(vu) + + # Expose both the reloaded module and the call recorder + yield vu, calls + + +def _keys(calls): + """Helper to extract just the keys logged to rr.log""" + return [k for (k, _obj, _kw) in calls] + + +def _obj_for(calls, key): + """Find the first object logged under a given key.""" + for k, obj, _kw in calls: + if k == key: + return obj + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def _kwargs_for(calls, key): + for k, _obj, kw in calls: + if k == key: + return kw + raise KeyError(f"Key {key} not found in calls: {calls}") + + +def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): + vu, calls = mock_rerun + + # Build EnvTransition dict + obs = { + "observation.state.temperature": np.float32(25.0), + # CHW image should be converted to HWC for rr.Image + "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), + } + act = { + "action.throttle": 0.7, + # 1D array should log individual Scalars with suffix _i + "action.vector": np.array([1.0, 2.0], dtype=np.float32), + } + transition = { + TransitionKey.OBSERVATION: obs, + TransitionKey.ACTION: act, + } + + # Extract observation and action data from transition like in the real call sites + obs_data = transition.get(TransitionKey.OBSERVATION, {}) + action_data = transition.get(TransitionKey.ACTION, {}) + vu.log_rerun_data(observation=obs_data, action=action_data) + + # We expect: + # - observation.state.temperature -> Scalar + # - observation.camera -> Image (HWC) with static=True + # - action.throttle -> Scalar + # - action.vector_0, action.vector_1 -> Scalars + expected_keys = { + "observation.state.temperature", + "observation.camera", + "action.throttle", + "action.vector_0", + "action.vector_1", + } + assert set(_keys(calls)) == expected_keys + + # Check scalar types and values + temp_obj = _obj_for(calls, "observation.state.temperature") + assert type(temp_obj).__name__ == "DummyScalar" + assert temp_obj.value == pytest.approx(25.0) + + throttle_obj = _obj_for(calls, "action.throttle") + assert type(throttle_obj).__name__ == "DummyScalar" + assert throttle_obj.value == pytest.approx(0.7) + + v0 = _obj_for(calls, "action.vector_0") + v1 = _obj_for(calls, "action.vector_1") + assert type(v0).__name__ == "DummyScalar" + assert type(v1).__name__ == "DummyScalar" + assert v0.value == pytest.approx(1.0) + assert v1.value == pytest.approx(2.0) + + # Check image handling: CHW -> HWC + img_obj = _obj_for(calls, "observation.camera") + assert type(img_obj).__name__ == "DummyImage" + assert img_obj.arr.shape == (10, 20, 3) # transposed + assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images + + +def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): + vu, calls = mock_rerun + + # First dict without prefixes treated as observation + # Second dict without prefixes treated as action + obs_plain = { + "temp": 1.5, + # Already HWC image => should stay as-is + "img": np.zeros((5, 6, 3), dtype=np.uint8), + "none": None, # should be skipped + } + act_plain = { + "throttle": 0.3, + "vec": np.array([9, 8, 7], dtype=np.float32), + } + + # Extract observation and action data from list like the old function logic did + # First dict was treated as observation, second as action + vu.log_rerun_data(observation=obs_plain, action=act_plain) + + # Expected keys with auto-prefixes + expected = { + "observation.temp", + "observation.img", + "action.throttle", + "action.vec_0", + "action.vec_1", + "action.vec_2", + } + logged = set(_keys(calls)) + assert logged == expected + + # Scalars + t = _obj_for(calls, "observation.temp") + assert type(t).__name__ == "DummyScalar" + assert t.value == pytest.approx(1.5) + + throttle = _obj_for(calls, "action.throttle") + assert type(throttle).__name__ == "DummyScalar" + assert throttle.value == pytest.approx(0.3) + + # Image stays HWC + img = _obj_for(calls, "observation.img") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (5, 6, 3) + assert _kwargs_for(calls, "observation.img").get("static", False) is True + + # Vectors + for i, val in enumerate([9, 8, 7]): + o = _obj_for(calls, f"action.vec_{i}") + assert type(o).__name__ == "DummyScalar" + assert o.value == pytest.approx(val) + + +def test_log_rerun_data_kwargs_only(mock_rerun): + vu, calls = mock_rerun + + vu.log_rerun_data( + observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)}, + action={"action.a": 1.0}, + ) + + keys = set(_keys(calls)) + assert "observation.temp" in keys + assert "observation.gray" in keys + assert "action.a" in keys + + temp = _obj_for(calls, "observation.temp") + assert type(temp).__name__ == "DummyScalar" + assert temp.value == pytest.approx(10.0) + + img = _obj_for(calls, "observation.gray") + assert type(img).__name__ == "DummyImage" + assert img.arr.shape == (8, 8, 1) # remains HWC + assert _kwargs_for(calls, "observation.gray").get("static", False) is True + + a = _obj_for(calls, "action.a") + assert type(a).__name__ == "DummyScalar" + assert a.value == pytest.approx(1.0) From 1bc38be719be180d0f8fbf1232c0fe38162f218b Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:33:34 +0200 Subject: [PATCH 090/158] small tiny nit (#1975) * small tiny nit Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --- examples/5_train_with_streaming.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/5_train_with_streaming.py b/examples/5_train_with_streaming.py index 93d13535f..80fee5883 100644 --- a/examples/5_train_with_streaming.py +++ b/examples/5_train_with_streaming.py @@ -51,9 +51,7 @@ def main(): training_steps = 10 log_freq = 1 - dataset_id = ( - "aractingi/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (: - ) + dataset_id = "lerobot/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (: dataset_metadata = LeRobotDatasetMetadata(dataset_id) features = dataset_to_policy_features(dataset_metadata.features) output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} From 5d1837d87e07440424140c7330d58be87e462ba0 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 18 Sep 2025 21:31:34 +0200 Subject: [PATCH 091/158] fix (docs): image link for phone (#1977) --- docs/source/phone_teleop.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 71d5457fb..bab0ac28e 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -36,7 +36,7 @@ Links: - iOS: Analog input `A3` controls the gripper as velocity input. - Android: Buttons `A` and `B` act like increment/decrement (A opens, B closes). You can tune velocity in the `GripperVelocityToJoint` step. -Phone teleop orientation +Phone teleop orientation ### Step 1: Choose the platform From cc135d3c4a7c31e60252aa214b7ef25a59a15159 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 19 Sep 2025 11:04:13 +0200 Subject: [PATCH 092/158] bump gym-hil version to be pipeline compatible (#1983) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 70755cf9b..cbf555e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,7 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # Policies pi0 = ["lerobot[transformers-dep]"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] From d65668ff3c66f0d14802037991185f9f2e855684 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 19 Sep 2025 15:19:49 +0200 Subject: [PATCH 093/158] Add docs for LeRobot Image transforms (#1972) * Remove unused scripts, add docs for image transforms and add example * fix(examples): move train_policy.py under examples, remove outdated readme parts * remove script thats copied to train folder * remove outdated links to examples and example tests --- README.md | 37 +-- docs/source/lerobot-dataset-v3.mdx | 112 +++++++ examples/2_evaluate_pretrained_policy.py | 139 -------- examples/4_train_policy_with_script.md | 311 ------------------ .../load_lerobot_dataset.py} | 2 +- .../dataset/use_dataset_image_transforms.py | 177 ++++++++++ .../train_policy.py} | 6 +- .../train_with_streaming.py} | 6 +- src/lerobot/robots/stretch3/README.md | 4 - src/lerobot/robots/viperx/README.md | 2 - tests/examples/test_examples.py | 147 --------- 11 files changed, 293 insertions(+), 650 deletions(-) delete mode 100644 examples/2_evaluate_pretrained_policy.py delete mode 100644 examples/4_train_policy_with_script.md rename examples/{1_load_lerobot_dataset.py => dataset/load_lerobot_dataset.py} (99%) create mode 100644 examples/dataset/use_dataset_image_transforms.py rename examples/{3_train_policy.py => training/train_policy.py} (97%) rename examples/{5_train_with_streaming.py => training/train_with_streaming.py} (96%) delete mode 100644 tests/examples/test_examples.py diff --git a/README.md b/README.md index 9fd45a7b7..47b0d4518 100644 --- a/README.md +++ b/README.md @@ -279,42 +279,6 @@ A `LeRobotDataset` is serialised using several widespread file formats for each Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location. -### Evaluate a pretrained policy - -Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. - -We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht): - -```bash -lerobot-eval \ - --policy.path=lerobot/diffusion_pusht \ - --env.type=pusht \ - --eval.batch_size=10 \ - --eval.n_episodes=10 \ - --policy.use_amp=false \ - --policy.device=cuda -```` - -Note: After training your own policy, you can re-evaluate the checkpoints with: - -```bash -lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model -``` - -See `lerobot-eval --help` for more instructions. - -### Train your own policy - -Check out [example 3](https://github.com/huggingface/lerobot/blob/main/examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md) that shows how to use our training script from command line. - -To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`. - -A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs. - -\WandB logs example - -Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions. - #### Reproduce state-of-the-art (SOTA) We provide some pretrained policies on our [hub page](https://huggingface.co/lerobot) that can achieve state-of-the-art performances. @@ -373,3 +337,4 @@ If you want, you can cite this work with: ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) +```` diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index 4f33d9a25..09fb17fad 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -8,6 +8,7 @@ This docs will guide you to: - Record a dataset and push it to the Hub - Load datasets for training with `LeRobotDataset` - Stream datasets without downloading using `StreamingLeRobotDataset` +- Apply image transforms for data augmentation during training - Migrate existing `v2.1` datasets to `v3.0` ## What’s new in `v3` @@ -150,6 +151,117 @@ dataset = StreamingLeRobotDataset(repo_id) # streams directly from the Hub

+## Image transforms + +Image transforms are data augmentations applied to camera frames during training to improve model robustness and generalization. LeRobot supports various transforms including brightness, contrast, saturation, hue, and sharpness adjustments. + +### Using transforms during dataset creation/recording + +Currently, transforms are applied during **training time only**, not during recording. When you create or record a dataset, the raw images are stored without transforms. This allows you to experiment with different augmentations later without re-recording data. + +### Adding transforms to existing datasets (API) + +Use the `image_transforms` parameter when loading a dataset for training: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.transforms import ImageTransforms, ImageTransformsConfig, ImageTransformConfig + +# Option 1: Use default transform configuration (disabled by default) +transforms_config = ImageTransformsConfig( + enable=True, # Enable transforms + max_num_transforms=3, # Apply up to 3 transforms per frame + random_order=False, # Apply in standard order +) +transforms = ImageTransforms(transforms_config) + +dataset = LeRobotDataset( + repo_id="your-username/your-dataset", + image_transforms=transforms +) + +# Option 2: Create custom transform configuration +custom_transforms_config = ImageTransformsConfig( + enable=True, + max_num_transforms=2, + random_order=True, + tfs={ + "brightness": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"brightness": (0.7, 1.3)} # Adjust brightness range + ), + "contrast": ImageTransformConfig( + weight=2.0, # Higher weight = more likely to be selected + type="ColorJitter", + kwargs={"contrast": (0.8, 1.2)} + ), + "sharpness": ImageTransformConfig( + weight=0.5, # Lower weight = less likely to be selected + type="SharpnessJitter", + kwargs={"sharpness": (0.3, 2.0)} + ), + } +) + +dataset = LeRobotDataset( + repo_id="your-username/your-dataset", + image_transforms=ImageTransforms(custom_transforms_config) +) + +# Option 3: Use pure torchvision transforms +from torchvision.transforms import v2 + +torchvision_transforms = v2.Compose([ + v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), + v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), +]) + +dataset = LeRobotDataset( + repo_id="your-username/your-dataset", + image_transforms=torchvision_transforms +) +``` + +### Available transform types + +LeRobot provides several transform types: + +- **`ColorJitter`**: Adjusts brightness, contrast, saturation, and hue +- **`SharpnessJitter`**: Randomly adjusts image sharpness +- **`Identity`**: No transformation (useful for testing) + +You can also use any `torchvision.transforms.v2` transform by passing it directly to the `image_transforms` parameter. + +### Configuration options + +- **`enable`**: Enable/disable transforms (default: `False`) +- **`max_num_transforms`**: Maximum number of transforms applied per frame (default: `3`) +- **`random_order`**: Apply transforms in random order vs. standard order (default: `False`) +- **`weight`**: Sampling probability for each transform (higher = more likely, if sum of weights is not 1, they will be normalized) +- **`kwargs`**: Transform-specific parameters (e.g., brightness range) + +### Visualizing transforms + +Use the visualization script to preview how transforms affect your data: + +```bash +python -m lerobot.scripts.visualize_image_transforms \ + --repo-id=your-username/your-dataset \ + --output-dir=./transform_examples \ + --n-examples=5 +``` + +This saves example images showing the effect of each transform, helping you tune parameters. + +### Best practices + +- **Start conservative**: Begin with small ranges (e.g., brightness 0.9-1.1) and increase gradually +- **Test first**: Use the visualization script to ensure transforms look reasonable +- **Monitor training**: Strong augmentations can hurt performance if too aggressive +- **Match your domain**: If your robot operates in varying lighting, use brightness/contrast transforms +- **Combine wisely**: Using too many transforms simultaneously can make training unstable + ## Migrate `v2.1` → `v3.0` A converter aggregates per‑episode files into larger shards and writes episode offsets/metadata. Convert your dataset using the instructions below. diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py deleted file mode 100644 index c0c7845e8..000000000 --- a/examples/2_evaluate_pretrained_policy.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local -training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. - -It requires the installation of the 'gym_pusht' simulation environment. Install it by running: -```bash -pip install -e ".[pusht]" -``` -""" - -from pathlib import Path - -import gym_pusht # noqa: F401 -import gymnasium as gym -import imageio -import numpy -import torch - -from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy - -# Create a directory to store the video of the evaluation -output_directory = Path("outputs/eval/example_pusht_diffusion") -output_directory.mkdir(parents=True, exist_ok=True) - -# Select your device -device = "cuda" - -# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht): -pretrained_policy_path = "lerobot/diffusion_pusht" -# OR a path to a local outputs/train folder. -# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion") - -policy = DiffusionPolicy.from_pretrained(pretrained_policy_path) - -# Initialize evaluation environment to render two observation types: -# an image of the scene and state/position of the agent. The environment -# also automatically stops running after 300 interactions/steps. -env = gym.make( - "gym_pusht/PushT-v0", - obs_type="pixels_agent_pos", - max_episode_steps=300, -) - -# We can verify that the shapes of the features expected by the policy match the ones from the observations -# produced by the environment -print(policy.config.input_features) -print(env.observation_space) - -# Similarly, we can check that the actions produced by the policy will match the actions expected by the -# environment -print(policy.config.output_features) -print(env.action_space) - -# Reset the policy and environments to prepare for rollout -policy.reset() -numpy_observation, info = env.reset(seed=42) - -# Prepare to collect every rewards and all the frames of the episode, -# from initial state to final state. -rewards = [] -frames = [] - -# Render frame of the initial state -frames.append(env.render()) - -step = 0 -done = False -while not done: - # Prepare observation for the policy running in Pytorch - state = torch.from_numpy(numpy_observation["agent_pos"]) - image = torch.from_numpy(numpy_observation["pixels"]) - - # Convert to float32 with image from channel first in [0,255] - # to channel last in [0,1] - state = state.to(torch.float32) - image = image.to(torch.float32) / 255 - image = image.permute(2, 0, 1) - - # Send data tensors from CPU to GPU - state = state.to(device, non_blocking=True) - image = image.to(device, non_blocking=True) - - # Add extra (empty) batch dimension, required to forward the policy - state = state.unsqueeze(0) - image = image.unsqueeze(0) - - # Create the policy input dictionary - observation = { - "observation.state": state, - "observation.image": image, - } - - # Predict the next action with respect to the current observation - with torch.inference_mode(): - action = policy.select_action(observation) - - # Prepare the action for the environment - numpy_action = action.squeeze(0).to("cpu").numpy() - - # Step through the environment and receive a new observation - numpy_observation, reward, terminated, truncated, info = env.step(numpy_action) - print(f"{step=} {reward=} {terminated=}") - - # Keep track of all the rewards and frames - rewards.append(reward) - frames.append(env.render()) - - # The rollout is considered done when the success state is reached (i.e. terminated is True), - # or the maximum number of iterations is reached (i.e. truncated is True) - done = terminated | truncated | done - step += 1 - -if terminated: - print("Success!") -else: - print("Failure!") - -# Get the speed of environment (i.e. its number of frames per second). -fps = env.metadata["render_fps"] - -# Encode all frames into a mp4 video. -video_path = output_directory / "rollout.mp4" -imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps) - -print(f"Video of the evaluation is available in '{video_path}'.") diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md deleted file mode 100644 index ffa7de66e..000000000 --- a/examples/4_train_policy_with_script.md +++ /dev/null @@ -1,311 +0,0 @@ -This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run. - -> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu. - -## The training script - -LeRobot offers a training script at [`lerobot/scripts/train.py`](../src/lerobot/scripts/train.py). At a high level it does the following: - -- Initialize/load a configuration for the following steps using. -- Instantiates a dataset. -- (Optional) Instantiates a simulation environment corresponding to that dataset. -- Instantiates a policy. -- Runs a standard training loop with forward pass, backward pass, optimization step, and occasional logging, evaluation (of the policy on the environment), and checkpointing. - -## Overview of the configuration system - -In the training script, the main function `train` expects a `TrainPipelineConfig` object: - - -```python -# train.py -@parser.wrap() -def train(cfg: TrainPipelineConfig): -``` - - -You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../src/lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) - -When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) - -Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes: - - -```python -@dataclass -class TrainPipelineConfig: - dataset: DatasetConfig - env: envs.EnvConfig | None = None - policy: PreTrainedConfig | None = None -``` - - -in which `DatasetConfig` for example is defined as such: - - -```python -@dataclass -class DatasetConfig: - repo_id: str - episodes: list[int] | None = None - video_backend: str = "pyav" -``` - - -This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`. -From the command line, we can specify this value by using a very similar syntax `--dataset.repo_id=repo/id`. - -By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file – which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified. - -## Specifying values from the CLI - -Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: - -```bash -lerobot-train \ - --dataset.repo_id=lerobot/pusht \ - --policy.type=diffusion \ - --env.type=pusht -``` - -Let's break this down: - -- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`. -- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/policies](../src/lerobot/policies) -- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/envs/configs.py`](../src/lerobot/envs/configs.py) - -Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: - -```bash -lerobot-train \ - --policy.type=act \ - --dataset.repo_id=lerobot/aloha_sim_insertion_human \ - --env.type=aloha \ - --output_dir=outputs/train/act_aloha_insertion -``` - -> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`. - -We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task. -Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: - -```bash -lerobot-train \ - --policy.type=act \ - --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ - --env.type=aloha \ - --env.task=AlohaTransferCube-v0 \ - --output_dir=outputs/train/act_aloha_transfer -``` - -## Loading from a config file - -Now, let's assume that we want to reproduce the run just above. That run has produced a `train_config.json` file in its checkpoints, which serializes the `TrainPipelineConfig` instance it used: - -```json -{ - "dataset": { - "repo_id": "lerobot/aloha_sim_transfer_cube_human", - "episodes": null, - ... - }, - "env": { - "type": "aloha", - "task": "AlohaTransferCube-v0", - "fps": 50, - ... - }, - "policy": { - "type": "act", - "n_obs_steps": 1, - ... - }, - ... -} -``` - -We can then simply load the config values from this file using: - -```bash -lerobot-train \ - --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ - --output_dir=outputs/train/act_aloha_transfer_2 -``` - -`--config_path` is also a special argument which allows to initialize the config from a local config file. It can point to a directory that contains `train_config.json` or to the config file itself directly. - -Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.: - -```bash -lerobot-train \ - --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ - --output_dir=outputs/train/act_aloha_transfer_2 - --policy.n_action_steps=80 -``` - -> Note: While `--output_dir` is not required in general, in this case we need to specify it since it will otherwise take the value from the `train_config.json` (which is `outputs/train/act_aloha_transfer`). In order to prevent accidental deletion of previous run checkpoints, we raise an error if you're trying to write in an existing directory. This is not the case when resuming a run, which is what you'll learn next. - -`--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running: - -```bash -lerobot-train --config_path=lerobot/diffusion_pusht -``` - -will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) - -## Resume training - -Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to do that here. - -Let's reuse the command from the previous run and add a few more options: - -```bash -lerobot-train \ - --policy.type=act \ - --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ - --env.type=aloha \ - --env.task=AlohaTransferCube-v0 \ - --log_freq=25 \ - --save_freq=100 \ - --output_dir=outputs/train/run_resumption -``` - -Here we've taken care to set up the log frequency and checkpointing frequency to low numbers so we can showcase resumption. You should be able to see some logging and have a first checkpoint within 1 minute (depending on hardware). Wait for the first checkpoint to happen, you should see a line that looks like this in your terminal: - -``` -INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100 -``` - -Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with: - -```bash -lerobot-train \ - --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ - --resume=true -``` - -You should see from the logging that your training picks up from where it left off. - -Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default. -You could double the number of steps of the previous run with: - -```bash -lerobot-train \ - --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ - --resume=true \ - --steps=200000 -``` - -## Outputs of a run - -In the output directory, there will be a folder called `checkpoints` with the following structure: - -```bash -outputs/train/run_resumption/checkpoints -├── 000100 # checkpoint_dir for training step 100 -│ ├── pretrained_model/ -│ │ ├── config.json # policy config -│ │ ├── model.safetensors # policy weights -│ │ └── train_config.json # train config -│ └── training_state/ -│ ├── optimizer_param_groups.json # optimizer param groups -│ ├── optimizer_state.safetensors # optimizer state -│ ├── rng_state.safetensors # rng states -│ ├── scheduler_state.json # scheduler state -│ └── training_step.json # training step -├── 000200 -└── last -> 000200 # symlink to the last available checkpoint -``` - -## Fine-tuning a pre-trained policy - -In addition to the features currently in Draccus, we've added a special `.path` argument for the policy, which allows to load a policy as you would with `PreTrainedPolicy.from_pretrained()`. In that case, `path` can be a local directory that contains a checkpoint or a repo_id pointing to a pretrained policy on the hub. - -For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with: - -```bash -lerobot-train \ - --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ - --dataset.repo_id=lerobot/aloha_sim_insertion_human \ - --env.type=aloha \ - --env.task=AlohaInsertion-v0 -``` - -When doing so, keep in mind that the features of the fine-tuning dataset would have to match the input/output features of the pretrained policy. - -## Typical logs and metrics - -When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you configured your run correctly. The final configuration will also be saved with the checkpoint. - -After that, you will see training log like this one: - -``` -INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774 -``` - -or evaluation log: - -``` -INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266 -``` - -These logs will also be saved in wandb if `wandb.enable` is set to `true`. Here are the meaning of some abbreviations: - -- `smpl`: number of samples seen during training. -- `ep`: number of episodes seen during training. An episode contains multiple samples in a complete manipulation task. -- `epch`: number of time all unique samples are seen (epoch). -- `grdn`: gradient norm. -- `∑rwrd`: compute the sum of rewards in every evaluation episode and then take an average of them. -- `success`: average success rate of eval episodes. Reward and success are usually different except for the sparsing reward setting, where reward=1 only when the task is completed successfully. -- `eval_s`: time to evaluate the policy in the environment, in second. -- `updt_s`: time to update the network parameters, in second. -- `data_s`: time to load a batch of data, in second. - -Some metrics are useful for initial performance profiling. For example, if you find the current GPU utilization is low via the `nvidia-smi` command and `data_s` sometimes is too high, you may need to modify batch size or number of dataloading workers to accelerate dataloading. We also recommend [pytorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing. - -## In short - -We'll summarize here the main use cases to remember from this tutorial. - -#### Train a policy from scratch – CLI - -```bash -lerobot-train \ - --policy.type=act \ # <- select 'act' policy - --env.type=pusht \ # <- select 'pusht' environment - --dataset.repo_id=lerobot/pusht # <- train on this dataset -``` - -#### Train a policy from scratch - config file + CLI - -```bash -lerobot-train \ - --config_path=path/to/pretrained_model \ # <- can also be a repo_id - --policy.n_action_steps=80 # <- you may still override values -``` - -#### Resume/continue a training run - -```bash -lerobot-train \ - --config_path=checkpoint/pretrained_model/ \ - --resume=true \ - --steps=200000 # <- you can change some training parameters -``` - -#### Fine-tuning - -```bash -lerobot-train \ - --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint - --dataset.repo_id=lerobot/aloha_sim_insertion_human \ - --env.type=aloha \ - --env.task=AlohaInsertion-v0 -``` - ---- - -Now that you know the basics of how to train a policy, you might want to know how to apply this knowledge to actual robots, or how to record your own datasets and train policies on your specific task? -If that's the case, head over to the next tutorial [`7_get_started_with_real_robot.md`](./7_get_started_with_real_robot.md). - -Or in the meantime, happy training! 🤗 diff --git a/examples/1_load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py similarity index 99% rename from examples/1_load_lerobot_dataset.py rename to examples/dataset/load_lerobot_dataset.py index ac4a843c7..a96c170cf 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/dataset/load_lerobot_dataset.py @@ -136,7 +136,7 @@ print(f"{dataset[0]['action'].shape=}\n") # (64, c) # PyTorch datasets. dataloader = torch.utils.data.DataLoader( dataset, - num_workers=0, + num_workers=4, batch_size=32, shuffle=True, ) diff --git a/examples/dataset/use_dataset_image_transforms.py b/examples/dataset/use_dataset_image_transforms.py new file mode 100644 index 000000000..c28f2ef0c --- /dev/null +++ b/examples/dataset/use_dataset_image_transforms.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example demonstrates how to use image transforms with LeRobot datasets for data augmentation during training. + +Image transforms are applied to camera frames to improve model robustness and generalization. They are applied +at training time only, not during dataset recording, allowing you to experiment with different augmentations +without re-recording data. +""" + +import torch +from torchvision.transforms import v2 +from torchvision.transforms.functional import to_pil_image + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.transforms import ImageTransformConfig, ImageTransforms, ImageTransformsConfig + + +def save_image(tensor, filename): + """Helper function to save a tensor as an image file.""" + if tensor.dim() == 3: # [C, H, W] + if tensor.max() > 1.0: + tensor = tensor / 255.0 + tensor = torch.clamp(tensor, 0.0, 1.0) + pil_image = to_pil_image(tensor) + pil_image.save(filename) + print(f"Saved: {filename}") + else: + print(f"Skipped {filename}: unexpected tensor shape {tensor.shape}") + + +def example_1_default_transforms(): + """Example 1: Use default transform configuration and save original vs transformed images""" + print("\n Example 1: Default Transform Configuration with Image Saving") + + repo_id = "pepijn223/record_main_0" # Example dataset + + try: + # Load dataset without transforms (original) + dataset_original = LeRobotDataset(repo_id=repo_id) + + # Load dataset with transforms enabled + transforms_config = ImageTransformsConfig( + enable=True, # Enable transforms (disabled by default) + max_num_transforms=2, # Apply up to 2 transforms per frame + random_order=False, # Apply in standard order + ) + dataset_with_transforms = LeRobotDataset( + repo_id=repo_id, image_transforms=ImageTransforms(transforms_config) + ) + + # Save original and transformed images for comparison + if len(dataset_original) > 0: + frame_idx = 0 # Use first frame + original_sample = dataset_original[frame_idx] + transformed_sample = dataset_with_transforms[frame_idx] + + print(f"Saving comparison images (frame {frame_idx}):") + + for cam_key in dataset_original.meta.camera_keys: + if cam_key in original_sample and cam_key in transformed_sample: + cam_name = cam_key.replace(".", "_").replace("/", "_") + + # Save original and transformed images + save_image(original_sample[cam_key], f"{cam_name}_original.png") + save_image(transformed_sample[cam_key], f"{cam_name}_transformed.png") + + except Exception as e: + print(f"Could not load dataset '{repo_id}': {e}") + + +def example_2_custom_transforms(): + """Example 2: Create custom transform configuration and save examples""" + print("\n Example 2: Custom Transform Configuration") + + repo_id = "pepijn223/record_main_0" # Example dataset + + try: + # Create custom transform configuration with strong effects + custom_transforms_config = ImageTransformsConfig( + enable=True, + max_num_transforms=2, # Apply up to 2 transforms per frame + random_order=True, # Apply transforms in random order + tfs={ + "brightness": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"brightness": (0.5, 1.5)}, # Strong brightness range + ), + "contrast": ImageTransformConfig( + weight=1.0, # Higher weight = more likely to be selected + type="ColorJitter", + kwargs={"contrast": (0.6, 1.4)}, # Strong contrast + ), + "sharpness": ImageTransformConfig( + weight=0.5, # Lower weight = less likely to be selected + type="SharpnessJitter", + kwargs={"sharpness": (0.2, 2.0)}, # Strong sharpness variation + ), + }, + ) + + dataset_with_custom_transforms = LeRobotDataset( + repo_id=repo_id, image_transforms=ImageTransforms(custom_transforms_config) + ) + + # Save examples with strong transforms + if len(dataset_with_custom_transforms) > 0: + sample = dataset_with_custom_transforms[0] + print("Saving custom transform examples:") + + for cam_key in dataset_with_custom_transforms.meta.camera_keys: + if cam_key in sample: + cam_name = cam_key.replace(".", "_").replace("/", "_") + save_image(sample[cam_key], f"{cam_name}_custom_transforms.png") + + except Exception as e: + print(f"Could not load dataset '{repo_id}': {e}") + + +def example_3_torchvision_transforms(): + """Example 3: Use pure torchvision transforms and save examples""" + print("\n Example 3: Pure Torchvision Transforms") + + repo_id = "pepijn223/record_main_0" # Example dataset + + try: + # Create torchvision transform pipeline + torchvision_transforms = v2.Compose( + [ + v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), + v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), + v2.RandomRotation(degrees=10), # Small rotation + ] + ) + + dataset_with_torchvision = LeRobotDataset(repo_id=repo_id, image_transforms=torchvision_transforms) + + # Save examples with torchvision transforms + if len(dataset_with_torchvision) > 0: + sample = dataset_with_torchvision[0] + print("Saving torchvision transform examples:") + + for cam_key in dataset_with_torchvision.meta.camera_keys: + if cam_key in sample: + cam_name = cam_key.replace(".", "_").replace("/", "_") + save_image(sample[cam_key], f"{cam_name}_torchvision.png") + + except Exception as e: + print(f"Could not load dataset '{repo_id}': {e}") + + +def main(): + """Run all examples""" + print("LeRobot Dataset Image Transforms Examples") + + example_1_default_transforms() + example_2_custom_transforms() + example_3_torchvision_transforms() + + +if __name__ == "__main__": + main() diff --git a/examples/3_train_policy.py b/examples/training/train_policy.py similarity index 97% rename from examples/3_train_policy.py rename to examples/training/train_policy.py index 7f3fad36c..16f2a4d87 100644 --- a/examples/3_train_policy.py +++ b/examples/training/train_policy.py @@ -12,11 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This script demonstrates how to train Diffusion Policy on the PushT environment. - -Once you have trained a model with this script, you can try to evaluate it on -examples/2_evaluate_pretrained_policy.py -""" +"""This script demonstrates how to train Diffusion Policy on the PushT environment.""" from pathlib import Path diff --git a/examples/5_train_with_streaming.py b/examples/training/train_with_streaming.py similarity index 96% rename from examples/5_train_with_streaming.py rename to examples/training/train_with_streaming.py index 80fee5883..e7edc17f8 100644 --- a/examples/5_train_with_streaming.py +++ b/examples/training/train_with_streaming.py @@ -13,11 +13,7 @@ # limitations under the License. """This script demonstrates how to train a Diffusion Policy on the PushT environment, -using a dataset processed in streaming mode. - -Once you have trained a model with this script, you can try to evaluate it on -examples/2_evaluate_pretrained_policy.py -""" +using a dataset processed in streaming mode.""" from pathlib import Path diff --git a/src/lerobot/robots/stretch3/README.md b/src/lerobot/robots/stretch3/README.md index 724732286..027f12d65 100644 --- a/src/lerobot/robots/stretch3/README.md +++ b/src/lerobot/robots/stretch3/README.md @@ -170,8 +170,4 @@ python lerobot/scripts/control_robot.py \ --control.episode=0 ``` -Follow [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) to train a policy on your data and run inference on your robot. You will need to adapt the code for Stretch. - -> TODO(rcadene, aliberts): Add already setup environment and policy yaml configuration files - If you need help, please reach out on Discord in the channel `#stretch3-mobile-arm`. diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index 5b57d61f5..f6386215a 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -193,6 +193,4 @@ As you can see, it's almost the same command as previously used to record your t ## More -Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation. - If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`. diff --git a/tests/examples/test_examples.py b/tests/examples/test_examples.py deleted file mode 100644 index aabec69a6..000000000 --- a/tests/examples/test_examples.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import subprocess -import sys -from pathlib import Path - -import pytest - -from tests.fixtures.constants import DUMMY_REPO_ID -from tests.utils import require_package - - -def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: - for f, r in finds_and_replaces: - assert f in text - text = text.replace(f, r) - return text - - -# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures -def _run_script(path): - subprocess.run([sys.executable, path], check=True) - - -def _read_file(path): - with open(path) as file: - return file.read() - - -@pytest.mark.skip("TODO Fix and remove subprocess / excec calls") -def test_example_1(tmp_path, lerobot_dataset_factory): - _ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID) - path = "examples/1_load_lerobot_dataset.py" - file_contents = _read_file(path) - file_contents = _find_and_replace( - file_contents, - [ - ('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'), - ( - "LeRobotDataset(repo_id", - f"LeRobotDataset(repo_id, root='{str(tmp_path)}'", - ), - ], - ) - exec(file_contents, {}) - assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists() - - -@pytest.mark.skip("TODO Fix and remove subprocess / excec calls") -@require_package("gym_pusht") -def test_examples_basic2_basic3_advanced1(): - """ - Train a model with example 3, check the outputs. - Evaluate the trained model with example 2, check the outputs. - Calculate the validation loss with advanced example 1, check the outputs. - """ - - ### Test example 3 - file_contents = _read_file("examples/3_train_policy.py") - - # Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. - file_contents = _find_and_replace( - file_contents, - [ - ("training_steps = 5000", "training_steps = 1"), - ("num_workers=4", "num_workers=0"), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("batch_size=64", "batch_size=1"), - ], - ) - - # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. - exec(file_contents, {}) - - for file_name in ["model.safetensors", "config.json"]: - assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() - - ### Test example 2 - file_contents = _read_file("examples/2_evaluate_pretrained_policy.py") - - # Do fewer evals, use CPU, and use the local model. - file_contents = _find_and_replace( - file_contents, - [ - ( - 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', - "", - ), - ( - '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - ), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("step += 1", "break"), - ], - ) - - exec(file_contents, {}) - - assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists() - - ## Test example 4 - file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py") - - # Run on a single example from the last episode, use CPU, and use the local model. - file_contents = _find_and_replace( - file_contents, - [ - ( - 'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', - "", - ), - ( - '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', - ), - ("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"), - ("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"), - ("num_workers=4", "num_workers=0"), - ('device = torch.device("cuda")', 'device = torch.device("cpu")'), - ("batch_size=64", "batch_size=1"), - ], - ) - - # Capture the output of the script - output_buffer = io.StringIO() - sys.stdout = output_buffer - exec(file_contents, {}) - printed_output = output_buffer.getvalue() - # Restore stdout to its original state - sys.stdout = sys.__stdout__ - assert "Average loss on validation set" in printed_output From 62d6169d2f38cd5ad4f944ad093e5623a5ff5d88 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 19 Sep 2025 20:21:23 +0200 Subject: [PATCH 094/158] fix formatting readme (#1987) --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 47b0d4518..39a3e8bcb 100644 --- a/README.md +++ b/README.md @@ -233,7 +233,7 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects: -```` +``` dataset attributes: ├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example: │ ├ observation.images.cam_high (VideoFrame): @@ -269,7 +269,7 @@ dataset attributes: ├ root (Path): local directory where the dataset is stored ├ image_transforms (Callable): optional image transformations to apply to visual modalities └ delta_timestamps (dict): optional delta timestamps for temporal queries -decoding videos (e.g., 'pyav', 'torchcodec') +``` A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely: @@ -337,4 +337,7 @@ If you want, you can cite this work with: ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) -```` + +``` + +``` From ce3670a20e40797cc5fe0d346c322a2c73287142 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 22 Sep 2025 10:19:45 +0200 Subject: [PATCH 095/158] bump datasets to 4.0.0 (#1990) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cbf555e3b..c42ee5080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici dependencies = [ # Hugging Face dependencies - "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency + "datasets>=4.0.0", "diffusers>=0.27.2", "huggingface-hub[hf-transfer,cli]>=0.34.2", From f7283193ea9ae932423e3a1e27524a27fa5c0fe5 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 22 Sep 2025 11:26:30 +0200 Subject: [PATCH 096/158] fix(trainer): overrides device to the target device, for the device processor on the preprocessor (#1993) * fix(trainer): overiddes device to the target defice, for device processor on preprocessor * Update src/lerobot/scripts/train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Adil Zouitine --------- Signed-off-by: Adil Zouitine Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/lerobot/scripts/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 485fc9275..5594d2f9b 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -183,6 +183,9 @@ def train(cfg: TrainPipelineConfig): # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats + if cfg.policy.pretrained_path is not None: + processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}} + preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs ) From 25384727812de60ff6e7a5e705cc016ec5def552 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 22 Sep 2025 15:36:20 +0200 Subject: [PATCH 097/158] feat(sim): Add Libero Env (#1984) --- docker/Dockerfile.internal | 1 + docker/Dockerfile.user | 1 + docs/source/_toctree.yml | 2 + docs/source/libero.mdx | 126 +++++++++++ pyproject.toml | 3 + src/lerobot/envs/configs.py | 54 +++++ src/lerobot/envs/factory.py | 44 +++- src/lerobot/envs/libero.py | 377 ++++++++++++++++++++++++++++++++ src/lerobot/envs/utils.py | 40 ++++ src/lerobot/scripts/eval.py | 251 ++++++++++++++++++++- src/lerobot/scripts/train.py | 26 ++- tests/envs/test_envs.py | 5 +- tests/policies/test_policies.py | 8 +- 13 files changed, 906 insertions(+), 32 deletions(-) create mode 100644 docs/source/libero.mdx create mode 100644 src/lerobot/envs/libero.py diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index 8c77fe497..52becb830 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -39,6 +39,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ software-properties-common build-essential git curl \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ + cmake pkg-config ninja-build \ && add-apt-repository -y ppa:deadsnakes/ppa \ && apt-get update \ && apt-get install -y --no-install-recommends \ diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index bcd067637..59fd3e0b3 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -31,6 +31,7 @@ ENV DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ + cmake pkg-config ninja-build \ && curl -LsSf https://astral.sh/uv/install.sh | sh \ && mv /root/.local/bin/uv /usr/local/bin/uv \ && useradd --create-home --shell /bin/bash user_lerobot \ diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7d6b69fba..7f4c07944 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,6 +29,8 @@ - sections: - local: smolvla title: Finetune SmolVLA + - local: libero + title: Using Libero title: "Policies" - sections: diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx new file mode 100644 index 000000000..488c02ce0 --- /dev/null +++ b/docs/source/libero.mdx @@ -0,0 +1,126 @@ +# LIBERO + +**LIBERO** is a benchmark designed to study **lifelong robot learning**. The idea is that robots won’t just be pretrained once in a factory, they’ll need to keep learning and adapting with their human users over time. This ongoing adaptation is called **lifelong learning in decision making (LLDM)**, and it’s a key step toward building robots that become truly personalized helpers. + +- 📄 [LIBERO paper](https://arxiv.org/abs/2306.03310) +- 💻 [Original LIBERO repo](https://github.com/Lifelong-Robot-Learning/LIBERO) + +To make progress on this challenge, LIBERO provides a set of standardized tasks that focus on **knowledge transfer**: how well a robot can apply what it has already learned to new situations. By evaluating on LIBERO, different algorithms can be compared fairly and researchers can build on each other’s work. + +LIBERO includes **five task suites**: + +- **LIBERO-Spatial (`libero_spatial`)** – tasks that require reasoning about spatial relations. +- **LIBERO-Object (`libero_object`)** – tasks centered on manipulating different objects. +- **LIBERO-Goal (`libero_goal`)** – goal-conditioned tasks where the robot must adapt to changing targets. +- **LIBERO-90 (`libero_90`)** – 90 short-horizon tasks from the LIBERO-100 collection. +- **LIBERO-Long (`libero_10`)** – 10 long-horizon tasks from the LIBERO-100 collection. + +Together, these suites cover **130 tasks**, ranging from simple object manipulations to complex multi-step scenarios. LIBERO is meant to grow over time, and to serve as a shared benchmark where the community can test and improve lifelong learning algorithms. + +![An overview of the LIBERO benchmark](https://libero-project.github.io/assets/img/libero/fig1.png) + +## Evaluating with LIBERO + +At **LeRobot**, we ported [LIBERO](https://github.com/Lifelong-Robot-Learning/LIBERO) into our framework and used it mainly to **evaluate [SmolVLA](https://huggingface.co/docs/lerobot/en/smolvla)**, our lightweight Vision-Language-Action model. + +LIBERO is now part of our **multi-eval supported simulation**, meaning you can benchmark your policies either on a **single suite of tasks** or across **multiple suites at once** with just a flag. + +To Install LIBERO, after following LeRobot official instructions, just do: +`pip install -e ".[libero]"` + +### Single-suite evaluation + +Evaluate a policy on one LIBERO suite: + +```bash +python src/lerobot/scripts/eval.py \ + --policy.path="your-policy-id" \ + --env.type=libero \ + --env.task=libero_object \ + --eval.batch_size=2 \ + --eval.n_episodes=3 +``` + +- `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.). +- `--eval.batch_size` controls how many environments run in parallel. +- `--eval.n_episodes` sets how many episodes to run in total. + +--- + +### Multi-suite evaluation + +Benchmark a policy across multiple suites at once: + +```bash +python src/lerobot/scripts/eval.py \ + --policy.path="your-policy-id" \ + --env.type=libero \ + --env.task=libero_object,libero_spatial \ + --eval.batch_size=1 \ + --eval.n_episodes=2 +``` + +- Pass a comma-separated list to `--env.task` for multi-suite evaluation. + +### Policy inputs and outputs + +When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**: + +- **Observations** + - `observation.state` – proprioceptive features (agent state). + - `observation.images.image` – main camera view (`agentview_image`). + - `observation.images.image2` – wrist camera view (`robot0_eye_in_hand_image`). + + ⚠️ **Note:** LeRobot enforces the `.images.*` prefix for any multi-modal visual features. Always ensure that your policy config `input_features` use the same naming keys, and that your dataset metadata keys follow this convention during evaluation. + If your data contains different keys, you must rename the observations to match what the policy expects, since naming keys are encoded inside the normalization statistics layer. + This will be fixed with the upcoming Pipeline PR. + +- **Actions** + - Continuous control values in a `Box(-1, 1, shape=(7,))` space. + +We also provide a notebook for quick testing: +Training with LIBERO + +## Training with LIBERO + +When training on LIBERO tasks, make sure your dataset parquet and metadata keys follow the LeRobot convention. + +The environment expects: + +- `observation.state` → 8-dim agent state +- `observation.images.image` → main camera (`agentview_image`) +- `observation.images.image2` → wrist camera (`robot0_eye_in_hand_image`) + +⚠️ Cleaning the dataset upfront is **cleaner and more efficient** than remapping keys inside the code. +To avoid potential mismatches and key errors, we provide a **preprocessed LIBERO dataset** that is fully compatible with the current LeRobot codebase and requires no additional manipulation: +👉 [HuggingFaceVLA/libero](https://huggingface.co/datasets/HuggingFaceVLA/libero) + +For reference, here is the **original dataset** published by Physical Intelligence: +👉 [physical-intelligence/libero](https://huggingface.co/datasets/physical-intelligence/libero) + +--- + +### Example training command + +```bash +python src/lerobot/scripts/train.py \ + --policy.type=smolvla \ + --policy.repo_id=${HF_USER}/libero-test \ + --dataset.repo_id=jadechoghari/smol-libero3 \ + --env.type=libero \ + --env.task=libero_10 \ + --output_dir=./outputs/ \ + --steps=100000 \ + --batch_size=4 \ + --eval.batch_size=1 \ + --eval.n_episodes=1 \ + --eval_freq=1000 \ +``` + +--- + +### Note on rendering + +LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation: + +- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud) diff --git a/pyproject.toml b/pyproject.toml index c42ee5080..6db5e1307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,8 @@ video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] aloha = ["gym-aloha>=0.1.1"] pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead xarm = ["gym-xarm>=0.1.1"] +libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] + # All all = [ @@ -156,6 +158,7 @@ all = [ "lerobot[pusht]", "lerobot[xarm]", "lerobot[phone]", + "lerobot[libero]", ] [project.scripts] diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index f71aca70d..8c66b278e 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -30,6 +30,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): fps: int = 30 features: dict[str, PolicyFeature] = field(default_factory=dict) features_map: dict[str, str] = field(default_factory=dict) + max_parallel_tasks: int = 1 + disable_env_checker: bool = True @property def type(self) -> str: @@ -242,3 +244,55 @@ class HILSerlRobotEnvConfig(EnvConfig): @property def gym_kwargs(self) -> dict: return {} + + +@EnvConfig.register_subclass("libero") +@dataclass +class LiberoEnv(EnvConfig): + task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. + fps: int = 30 + episode_length: int = 520 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + camera_name: str = "agentview_image,robot0_eye_in_hand_image" + init_states: bool = True + camera_name_mapping: dict[str, str] | None = (None,) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_STATE, + "pixels/agentview_image": f"{OBS_IMAGES}.image", + "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["pixels/agentview_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) + self.features["pixels/agentview_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature( + type=FeatureType.VISUAL, shape=(360, 360, 3) + ) + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index af8f5eaf5..9b172854c 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,11 +27,15 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) + elif env_type == "libero": + return LiberoEnv(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") -def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None: +def make_env( + cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False +) -> dict[str, dict[int, gym.vector.VectorEnv]]: """Makes a gym vector environment according to the config. Args: @@ -45,13 +49,30 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g ModuleNotFoundError: If the requested env package is not installed Returns: - gym.vector.VectorEnv: The parallelized gym.env instance. + dict[str, dict[int, gym.vector.VectorEnv]]: + A mapping from suite name to indexed vectorized environments. + - For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id. + - For single-task environments: a single suite entry (cfg.type) with task_id=0. + """ if n_envs < 1: - raise ValueError("`n_envs must be at least 1") + raise ValueError("`n_envs` must be at least 1") + + env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv + + if "libero" in cfg.type: + from lerobot.envs.libero import create_libero_envs + + return create_libero_envs( + task=cfg.task, + n_envs=n_envs, + camera_name=cfg.camera_name, + init_states=cfg.init_states, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) package_name = f"gym_{cfg.type}" - try: importlib.import_module(package_name) except ModuleNotFoundError as e: @@ -60,10 +81,11 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g gym_handle = f"{package_name}/{cfg.task}" - # batched version of the env that returns an observation of shape (b, c) - env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv - env = env_cls( - [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)] - ) + def _make_one(): + return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) - return env + vec = env_cls([_make_one for _ in range(n_envs)]) + + # normalize to {suite: {task_id: vec_env}} for consistency + suite_name = cfg.type # e.g., "pusht", "aloha" + return {suite_name: {0: vec}} diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py new file mode 100644 index 000000000..466796975 --- /dev/null +++ b/src/lerobot/envs/libero.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import os +from collections import defaultdict +from collections.abc import Callable, Iterable, Mapping, Sequence +from functools import partial +from pathlib import Path +from typing import Any + +import gymnasium as gym +import numpy as np +import torch +from gymnasium import spaces +from libero.libero import benchmark, get_libero_path +from libero.libero.envs import OffScreenRenderEnv +from robosuite.utils.transform_utils import quat2axisangle + + +def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: + """Normalize camera_name into a non-empty list of strings.""" + if isinstance(camera_name, str): + cams = [c.strip() for c in camera_name.split(",") if c.strip()] + elif isinstance(camera_name, (list, tuple)): + cams = [str(c).strip() for c in camera_name if str(c).strip()] + else: + raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") + if not cams: + raise ValueError("camera_name resolved to an empty list.") + return cams + + +def _get_suite(name: str) -> benchmark.Benchmark: + """Instantiate a LIBERO suite by name with clear validation.""" + bench = benchmark.get_benchmark_dict() + if name not in bench: + raise ValueError(f"Unknown LIBERO suite '{name}'. Available: {', '.join(sorted(bench.keys()))}") + suite = bench[name]() + if not getattr(suite, "tasks", None): + raise ValueError(f"Suite '{name}' has no tasks.") + return suite + + +def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[int]: + """Validate/normalize task ids. If None → all tasks.""" + if task_ids is None: + return list(range(total_tasks)) + ids = sorted({int(t) for t in task_ids}) + for t in ids: + if t < 0 or t >= total_tasks: + raise ValueError(f"task_id {t} out of range [0, {total_tasks - 1}].") + return ids + + +def get_task_init_states(task_suite: Any, i: int) -> np.ndarray: + init_states_path = ( + Path(get_libero_path("init_states")) + / task_suite.tasks[i].problem_folder + / task_suite.tasks[i].init_states_file + ) + init_states = torch.load(init_states_path, weights_only=False) # nosec B614 + return init_states + + +def get_libero_dummy_action(): + """Get dummy/no-op action, used to roll out the simulation while the robot does nothing.""" + return [0, 0, 0, 0, 0, 0, -1] + + +OBS_STATE_DIM = 8 +ACTION_DIM = 7 +AGENT_POS_LOW = -1000.0 +AGENT_POS_HIGH = 1000.0 +ACTION_LOW = -1.0 +ACTION_HIGH = 1.0 +TASK_SUITE_MAX_STEPS: dict[str, int] = { + "libero_spatial": 280, # longest training demo has 193 steps + "libero_object": 280, # longest training demo has 254 steps + "libero_goal": 300, # longest training demo has 270 steps + "libero_10": 520, # longest training demo has 505 steps + "libero_90": 400, # longest training demo has 373 steps +} + + +class LiberoEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task_suite: Any, + task_id: int, + task_suite_name: str, + camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", + obs_type: str = "pixels", + render_mode: str = "rgb_array", + observation_width: int = 256, + observation_height: int = 256, + visualization_width: int = 640, + visualization_height: int = 480, + init_states: bool = True, + episode_index: int = 0, + camera_name_mapping: dict[str, str] | None = None, + num_steps_wait: int = 10, + ): + super().__init__() + self.task_id = task_id + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.init_states = init_states + self.camera_name = _parse_camera_names( + camera_name + ) # agentview_image (main) or robot0_eye_in_hand_image (wrist) + + # Map raw camera names to "image1" and "image2". + # The preprocessing step `preprocess_observation` will then prefix these with `.images.*`, + # following the LeRobot convention (e.g., `observation.images.image`, `observation.images.image2`). + # This ensures the policy consistently receives observations in the + # expected format regardless of the original camera naming. + if camera_name_mapping is None: + camera_name_mapping = { + "agentview_image": "image", + "robot0_eye_in_hand_image": "image2", + } + self.camera_name_mapping = camera_name_mapping + self.num_steps_wait = num_steps_wait + self.episode_index = episode_index + # Load once and keep + self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None + self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + + self._env = self._make_envs_task(task_suite, self.task_id) + default_steps = 500 + self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) + + images = {} + for cam in self.camera_name: + images[self.camera_name_mapping[cam]] = spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + + if self.obs_type == "state": + raise NotImplementedError( + "The 'state' observation type is not supported in LiberoEnv. " + "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')." + ) + + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Dict(images), + "agent_pos": spaces.Box( + low=AGENT_POS_LOW, + high=AGENT_POS_HIGH, + shape=(OBS_STATE_DIM,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box( + low=ACTION_LOW, high=ACTION_HIGH, shape=(ACTION_DIM,), dtype=np.float32 + ) + + def render(self): + raw_obs = self._env.env._get_observations() + image = self._format_raw_obs(raw_obs)["pixels"]["image"] + return image + + def _make_envs_task(self, task_suite: Any, task_id: int = 0): + task = task_suite.get_task(task_id) + self.task = task.name + self.task_description = task.language + task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file) + + env_args = { + "bddl_file_name": task_bddl_file, + "camera_heights": self.observation_height, + "camera_widths": self.observation_width, + } + env = OffScreenRenderEnv(**env_args) + env.reset() + return env + + def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]: + images = {} + for camera_name in self.camera_name: + image = raw_obs[camera_name] + image = image[::-1, ::-1] # rotate 180 degrees + images[self.camera_name_mapping[camera_name]] = image + state = np.concatenate( + ( + raw_obs["robot0_eef_pos"], + quat2axisangle(raw_obs["robot0_eef_quat"]), + raw_obs["robot0_gripper_qpos"], + ) + ) + agent_pos = state + if self.obs_type == "pixels": + return {"pixels": images.copy()} + if self.obs_type == "pixels_agent_pos": + return { + "pixels": images.copy(), + "agent_pos": agent_pos, + } + raise NotImplementedError( + f"The observation type '{self.obs_type}' is not supported in LiberoEnv. " + "Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')." + ) + + def reset(self, seed=None, **kwargs): + super().reset(seed=seed) + self._env.seed(seed) + if self.init_states and self._init_states is not None: + self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.reset() + + # After reset, objects may be unstable (slightly floating, intersecting, etc.). + # Step the simulator with a no-op action for a few frames so everything settles. + # Increasing this value can improve determinism and reproducibility across resets. + for _ in range(self.num_steps_wait): + raw_obs, _, _, _ = self._env.step(get_libero_dummy_action()) + observation = self._format_raw_obs(raw_obs) + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + raw_obs, reward, done, info = self._env.step(action) + + is_success = self._env.check_success() + terminated = done or is_success + info["is_success"] = is_success + + observation = self._format_raw_obs(raw_obs) + if done: + self.reset() + info.update( + { + "task": self.task, + "task_id": self.task_id, + "done": done, + "is_success": is_success, + } + ) + truncated = False + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +def _make_env_fns( + *, + suite, + suite_name: str, + task_id: int, + n_envs: int, + camera_names: list[str], + init_states: bool, + gym_kwargs: Mapping[str, Any], +) -> list[Callable[[], LiberoEnv]]: + """Build n_envs factory callables for a single (suite, task_id).""" + + def _make_env(episode_index: int, **kwargs) -> LiberoEnv: + local_kwargs = dict(kwargs) + return LiberoEnv( + task_suite=suite, + task_id=task_id, + task_suite_name=suite_name, + camera_name=camera_names, + init_states=init_states, + episode_index=episode_index, + **local_kwargs, + ) + + fns: list[Callable[[], LiberoEnv]] = [] + for episode_index in range(n_envs): + fns.append(partial(_make_env, episode_index, **gym_kwargs)) + return fns + + +# ---- Main API ---------------------------------------------------------------- + + +def create_libero_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", + init_states: bool = True, + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized LIBERO environments with a consistent return shape. + + Returns: + dict[suite_name][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single suite or a comma-separated list of suites. + - You may pass `task_ids` (list[int]) inside `gym_kwargs` to restrict tasks per suite. + """ + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_ids_filter = gym_kwargs.pop("task_ids", None) # optional: limit to specific tasks + + camera_names = _parse_camera_names(camera_name) + suite_names = [s.strip() for s in str(task).split(",") if s.strip()] + if not suite_names: + raise ValueError("`task` must contain at least one LIBERO suite name.") + + print( + f"Creating LIBERO envs | suites={suite_names} | n_envs(per task)={n_envs} | init_states={init_states}" + ) + if task_ids_filter is not None: + print(f"Restricting to task_ids={task_ids_filter}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for suite_name in suite_names: + suite = _get_suite(suite_name) + total = len(suite.tasks) + selected = _select_task_ids(total, task_ids_filter) + + if not selected: + raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).") + + for tid in selected: + fns = _make_env_fns( + suite=suite, + suite_name=suite_name, + task_id=tid, + n_envs=n_envs, + camera_names=camera_names, + init_states=init_states, + gym_kwargs=gym_kwargs, + ) + out[suite_name][tid] = env_cls(fns) + print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}") + + # return plain dicts for predictability + return {suite: dict(task_map) for suite, task_map in out.items()} diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index b4f65ee9c..f0aa0b5c6 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import warnings +from collections.abc import Mapping, Sequence +from functools import singledispatch from typing import Any import einops @@ -154,3 +156,41 @@ def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dic num_envs = observation[list(observation.keys())[0]].shape[0] observation["task"] = ["" for _ in range(num_envs)] return observation + + +def _close_single_env(env: Any) -> None: + try: + env.close() + except Exception as exc: + print(f"Exception while closing env {env}: {exc}") + + +@singledispatch +def close_envs(obj: Any) -> None: + """Default: raise if the type is not recognized.""" + raise NotImplementedError(f"close_envs not implemented for type {type(obj).__name__}") + + +@close_envs.register +def _(env: Mapping) -> None: + for v in env.values(): + if isinstance(v, Mapping): + close_envs(v) + elif hasattr(v, "close"): + _close_single_env(v) + + +@close_envs.register +def _(envs: Sequence) -> None: + if isinstance(envs, (str, bytes)): + return + for v in envs: + if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)): + close_envs(v) + elif hasattr(v, "close"): + _close_single_env(v) + + +@close_envs.register +def _(env: gym.Env) -> None: + _close_single_env(env) diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index bf398a0a9..ca900f8df 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -46,17 +46,20 @@ Note that in both examples, the repo/folder should contain at least `config.json You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py """ +import concurrent.futures as cf import json import logging import threading import time +from collections import defaultdict from collections.abc import Callable from contextlib import nullcontext from copy import deepcopy from dataclasses import asdict +from functools import partial from pathlib import Path from pprint import pformat -from typing import Any +from typing import Any, TypedDict import einops import gymnasium as gym @@ -69,7 +72,12 @@ from tqdm import trange from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.envs.factory import make_env -from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation +from lerobot.envs.utils import ( + add_envs_task, + check_env_attributes_and_types, + close_envs, + preprocess_observation, +) from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline @@ -147,7 +155,7 @@ def rollout( leave=False, ) check_env_attributes_and_types(env) - while not np.all(done): + while not np.all(done) and step < max_steps: # Numpy array to tensor and changing dictionary keys to LeRobot policy format. observation = preprocess_observation(observation) if return_observations: @@ -178,7 +186,12 @@ def rollout( successes = [False] * env.num_envs # Keep track of which environments are done so far. + # Mark the episode as done if we reach the maximum step limit. + # This ensures that the rollout always terminates cleanly at `max_steps`, + # and allows logging/saving (e.g., videos) to be triggered consistently. done = terminated | truncated | done + if step + 1 == max_steps: + done = np.ones_like(done, dtype=bool) all_actions.append(torch.from_numpy(action_numpy)) all_rewards.append(torch.from_numpy(reward)) @@ -474,7 +487,7 @@ def eval_main(cfg: EvalPipelineConfig): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") - env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Making policy.") @@ -490,10 +503,9 @@ def eval_main(cfg: EvalPipelineConfig): # The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility. preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}}, ) - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): - info = eval_policy( - env=env, + info = eval_policy_all( + envs=envs, policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, @@ -501,18 +513,237 @@ def eval_main(cfg: EvalPipelineConfig): max_episodes_rendered=10, videos_dir=Path(cfg.output_dir) / "videos", start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, ) - print(info["aggregated"]) + print("Overall Aggregated Metrics:") + print(info["overall"]) + + # Print per-suite stats + for task_group, task_group_info in info.items(): + print(f"\nAggregated Metrics for {task_group}:") + print(task_group_info) + # Close all vec envs + close_envs(envs) # Save info with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: json.dump(info, f, indent=2) - env.close() - logging.info("End of eval") +# ---- typed payload returned by one task eval ---- +class TaskMetrics(TypedDict): + sum_rewards: list[float] + max_rewards: list[float] + successes: list[bool] + video_paths: list[str] + + +ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths") + + +def eval_one( + env: gym.vector.VectorEnv, + *, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], + n_episodes: int, + max_episodes_rendered: int, + videos_dir: Path | None, + return_episode_data: bool, + start_seed: int | None, +) -> TaskMetrics: + """Evaluates one task_id of one suite using the provided vec env.""" + + task_videos_dir = videos_dir + + task_result = eval_policy( + env=env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=task_videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ) + + per_episode = task_result["per_episode"] + return TaskMetrics( + sum_rewards=[ep["sum_reward"] for ep in per_episode], + max_rewards=[ep["max_reward"] for ep in per_episode], + successes=[ep["success"] for ep in per_episode], + video_paths=task_result.get("video_paths", []), + ) + + +def run_one( + task_group: str, + task_id: int, + env, + *, + policy, + preprocessor, + postprocessor, + n_episodes: int, + max_episodes_rendered: int, + videos_dir: Path | None, + return_episode_data: bool, + start_seed: int | None, +): + """ + Run eval_one for a single (task_group, task_id, env). + Returns (task_group, task_id, task_metrics_dict). + This function is intentionally module-level to make it easy to test. + """ + task_videos_dir = None + if videos_dir is not None: + task_videos_dir = videos_dir / f"{task_group}_{task_id}" + task_videos_dir.mkdir(parents=True, exist_ok=True) + + # Call the existing eval_one (assumed to return TaskMetrics-like dict) + metrics = eval_one( + env, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=task_videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ) + # ensure we always provide video_paths key to simplify accumulation + if max_episodes_rendered > 0: + metrics.setdefault("video_paths", []) + return task_group, task_id, metrics + + +def eval_policy_all( + envs: dict[str, dict[int, gym.vector.VectorEnv]], + policy, + preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction], + n_episodes: int, + *, + max_episodes_rendered: int = 0, + videos_dir: Path | None = None, + return_episode_data: bool = False, + start_seed: int | None = None, + max_parallel_tasks: int = 1, +) -> dict: + """ + Evaluate a nested `envs` dict: {task_group: {task_id: vec_env}}. + This implementation flattens tasks, runs them sequentially or via ThreadPoolExecutor, + accumulates per-group and overall statistics, and returns the same aggregate metrics + schema as the single-env evaluator (avg_sum_reward / avg_max_reward / pc_success / timings) + plus per-task infos. + """ + start_t = time.time() + + # Flatten envs into list of (task_group, task_id, env) + tasks = [(tg, tid, vec) for tg, group in envs.items() for tid, vec in group.items()] + + # accumulators: track metrics at both per-group level and across all groups + group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS}) + overall: dict[str, list] = {k: [] for k in ACC_KEYS} + per_task_infos: list[dict] = [] + + # small inline helper to accumulate one task's metrics into accumulators + def _accumulate_to(group: str, metrics: dict): + # metrics expected to contain 'sum_rewards', 'max_rewards', 'successes', optionally 'video_paths' + # but eval_one may store per-episode lists; we assume metrics uses scalars averaged per task as before. + # To be robust, accept scalars or lists. + def _append(key, value): + if value is None: + return + if isinstance(value, list): + group_acc[group][key].extend(value) + overall[key].extend(value) + else: + group_acc[group][key].append(value) + overall[key].append(value) + + _append("sum_rewards", metrics.get("sum_rewards")) + _append("max_rewards", metrics.get("max_rewards")) + _append("successes", metrics.get("successes")) + # video_paths is list-like + paths = metrics.get("video_paths", []) + if paths: + group_acc[group]["video_paths"].extend(paths) + overall["video_paths"].extend(paths) + + # Choose runner (sequential vs threaded) + task_runner = partial( + run_one, + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + ) + + if max_parallel_tasks <= 1: + # sequential path (single accumulator path on the main thread) + # NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks + for task_group, task_id, env in tasks: + tg, tid, metrics = task_runner(task_group, task_id, env) + _accumulate_to(tg, metrics) + per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) + else: + # threaded path: submit all tasks, consume completions on main thread and accumulate there + with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor: + fut2meta = {} + for task_group, task_id, env in tasks: + fut = executor.submit(task_runner, task_group, task_id, env) + fut2meta[fut] = (task_group, task_id) + for fut in cf.as_completed(fut2meta): + tg, tid, metrics = fut.result() + _accumulate_to(tg, metrics) + per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics}) + + # compute aggregated metrics helper (robust to lists/scalars) + def _agg_from_list(xs): + if not xs: + return float("nan") + arr = np.array(xs, dtype=float) + return float(np.nanmean(arr)) + + # compute per-group aggregates + groups_aggregated = {} + for group, acc in group_acc.items(): + groups_aggregated[group] = { + "avg_sum_reward": _agg_from_list(acc["sum_rewards"]), + "avg_max_reward": _agg_from_list(acc["max_rewards"]), + "pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"), + "n_episodes": len(acc["sum_rewards"]), + "video_paths": list(acc["video_paths"]), + } + + # overall aggregates + overall_agg = { + "avg_sum_reward": _agg_from_list(overall["sum_rewards"]), + "avg_max_reward": _agg_from_list(overall["max_rewards"]), + "pc_success": _agg_from_list(overall["successes"]) * 100 if overall["successes"] else float("nan"), + "n_episodes": len(overall["sum_rewards"]), + "eval_s": time.time() - start_t, + "eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])), + "video_paths": list(overall["video_paths"]), + } + + return { + "per_task": per_task_infos, + "per_group": groups_aggregated, + "overall": overall_agg, + } + + def main(): init_logging() eval_main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 5594d2f9b..21da62bbb 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -30,11 +30,12 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle from lerobot.envs.factory import make_env +from lerobot.envs.utils import close_envs from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters -from lerobot.scripts.eval import eval_policy +from lerobot.scripts.eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( @@ -302,8 +303,8 @@ def train(cfg: TrainPipelineConfig): torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(), ): - eval_info = eval_policy( - env=eval_env, + eval_info = eval_policy_all( + envs=eval_env, # dict[suite][task_id] -> vec_env policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, @@ -311,8 +312,16 @@ def train(cfg: TrainPipelineConfig): videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", max_episodes_rendered=4, start_seed=cfg.seed, + max_parallel_tasks=cfg.env.max_parallel_tasks, ) + # overall metrics (suite-agnostic) + aggregated = eval_info["overall"] + # optional: per-suite logging + for suite, suite_info in eval_info.items(): + logging.info("Suite %s aggregated: %s", suite, suite_info) + + # meters/tracker eval_metrics = { "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), "pc_success": AverageMeter("success", ":.1f"), @@ -321,17 +330,16 @@ def train(cfg: TrainPipelineConfig): eval_tracker = MetricsTracker( cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step ) - eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") - eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") - eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") - logging.info(eval_tracker) + eval_tracker.eval_s = aggregated.pop("eval_s") + eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") + eval_tracker.pc_success = aggregated.pop("pc_success") if wandb_logger: wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} wandb_logger.log_dict(wandb_log_dict, step, mode="eval") - wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") + wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval") if eval_env: - eval_env.close() + close_envs(eval_env) logging.info("End of training") if cfg.policy.push_to_hub: diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 140e9dfb9..51ea564e5 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -46,7 +46,10 @@ def test_env(env_name, env_task, obs_type): @require_env def test_factory(env_name): cfg = make_env_config(env_name) - env = make_env(cfg, n_envs=1) + envs = make_env(cfg, n_envs=1) + suite_name = next(iter(envs)) + task_id = next(iter(envs[suite_name])) + env = envs[suite_name][task_id] obs, _ = env.reset() obs = preprocess_observation(obs) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ef09bcd22..28c395bfc 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -159,7 +159,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): assert isinstance(policy, PreTrainedPolicy) # Check that we run select_actions and get the appropriate output. - env = make_env(train_cfg.env, n_envs=2) + envs = make_env(train_cfg.env, n_envs=2) dataloader = torch.utils.data.DataLoader( dataset, @@ -188,6 +188,12 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): # reset the policy and environment policy.reset() + # For testing purposes, we only need a single environment instance. + # So here we unwrap the first suite_name and first task_id to grab + # the actual env object (SyncVectorEnv) that exposes `.reset()`. + suite_name = next(iter(envs)) + task_id = next(iter(envs[suite_name])) + env = envs[suite_name][task_id] observation, _ = env.reset(seed=train_cfg.seed) # apply transform to normalize the observations From 4bad09cd25a2a4adab3bf2889d5950f522ead257 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 22 Sep 2025 16:06:16 +0200 Subject: [PATCH 098/158] feat(ci): add stale GH action bot for stalled issues & PRs (#1996) --- .github/workflows/stale.yml | 68 +++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 .github/workflows/stale.yml diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 000000000..6aa75d5b8 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,68 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles closing stale issues and PRs. +name: Stale +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + # Runs at 02:00 + schedule: + - cron: "0 2 * * *" + +env: + CLOSE_ISSUE_MESSAGE: > + This issue was closed because it has been stalled for 14 days with no activity. + Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. + CLOSE_PR_MESSAGE: > + This PR was closed because it has been stalled for 14 days with no activity. + Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. + WARN_ISSUE_MESSAGE: > + This issue has been automatically marked as stale because it has not had + recent activity (1 year). It will be closed if no further activity occurs. + Thank you for your contributions. + WARN_PR_MESSAGE: > + This PR has been automatically marked as stale because it has not had + recent activity (1 year). It will be closed if no further activity occurs. + Thank you for your contributions. + +jobs: + # This job runs the actions/stale action to close stale issues and PRs. + stale: + name: Close Stale Issues and PRs + runs-on: ubuntu-latest + permissions: + actions: write + contents: write # only for delete-branch option + issues: write + pull-requests: write + steps: + - uses: actions/stale@v10 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + stale-issue-label: stale + stale-pr-label: stale + exempt-issue-labels: never-stale + exempt-pr-labels: never-stale + days-before-issue-stale: 365 # TODO(Steven): Will modify this to 180 after initial cleanup + days-before-issue-close: 14 + days-before-pr-stale: 365 + days-before-pr-close: 14 + delete-branch: true + close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }} + close-pr-message: ${{ env.CLOSE_PR_MESSAGE }} + stale-issue-message: ${{ env.WARN_ISSUE_MESSAGE }} + stale-pr-message: ${{ env.WARN_PR_MESSAGE }} + operations-per-run: 500 From a665a9df83982c42ba5b4908d2af99445a5a57f7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 22 Sep 2025 16:40:31 +0200 Subject: [PATCH 099/158] chore(ci): update time for stale issue/pr (#1997) * chore(ci): update time for stale issue/pr * chore(ci): update comment --- .github/workflows/stale.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 6aa75d5b8..acd1ae53a 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -56,9 +56,9 @@ jobs: stale-pr-label: stale exempt-issue-labels: never-stale exempt-pr-labels: never-stale - days-before-issue-stale: 365 # TODO(Steven): Will modify this to 180 after initial cleanup + days-before-issue-stale: 180 # TODO(Steven): Will modify this to 90 after initial cleanup days-before-issue-close: 14 - days-before-pr-stale: 365 + days-before-pr-stale: 180 days-before-pr-close: 14 delete-branch: true close-issue-message: ${{ env.CLOSE_ISSUE_MESSAGE }} From 664c00b59405504f519946a1e39116fe23b71472 Mon Sep 17 00:00:00 2001 From: Mohit <97352487+complete-dope@users.noreply.github.com> Date: Mon, 22 Sep 2025 20:21:43 +0530 Subject: [PATCH 100/158] Update README.md (#1989) Signed-off-by: Mohit <97352487+complete-dope@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 39a3e8bcb..a3f28f552 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ Our script can also visualize datasets stored on a distant server. See `python - A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model. -A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. +A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) for more details on `delta_timestamps`. Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor. From a68424c3c9d8807b255df975f987aa1d054ba091 Mon Sep 17 00:00:00 2001 From: "Jivin.L" <45867423+JivinDotL@users.noreply.github.com> Date: Tue, 23 Sep 2025 19:38:22 +0800 Subject: [PATCH 101/158] Fix: Resolve PermissionError and UnicodeDecodeError in Python scripts (#1980) * Fix: Resolve PermissionError and UnicodeDecodeError in Python scripts Problem: 1. PermissionError when running eval.py 2. UnicodeDecodeError: 'gbk' when running migrate_policy_normalization.py * To explicitly specify the file encoding and resolve linter warnings. Signed-off-by: Jivin.L <45867423+JivinDotL@users.noreply.github.com> --------- Signed-off-by: Jivin.L <45867423+JivinDotL@users.noreply.github.com> Co-authored-by: Steven Palma --- src/lerobot/configs/policies.py | 9 ++++----- src/lerobot/policies/pretrained.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7532f0612..9a2bb911a 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -196,11 +196,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): config = json.load(f) config.pop("type") - with tempfile.NamedTemporaryFile("w+") as f: + with tempfile.NamedTemporaryFile("w+", delete=False, suffix=".json") as f: json.dump(config, f) config_file = f.name - f.flush() - cli_overrides = policy_kwargs.pop("cli_overrides", []) - with draccus.config_type("json"): - return draccus.parse(orig_config.__class__, config_file, args=cli_overrides) + cli_overrides = policy_kwargs.pop("cli_overrides", []) + with draccus.config_type("json"): + return draccus.parse(orig_config.__class__, config_file, args=cli_overrides) diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 2f69309c1..b770c980b 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -246,7 +246,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): base_model=base_model, ) - template_card = files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text() + template_card = ( + files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8") + ) card = ModelCard.from_template(card_data, template_str=template_card) card.validate() return card From 9d0cf64da611c816a229a1b8e6e301c7dda262f5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 15:51:19 +0200 Subject: [PATCH 102/158] fix(dataset): cast fps to int instead of float (#2001) --- src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 96bdc1897..1327bd820 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -404,7 +404,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb): info["video_files_size_in_mb"] = video_file_size_in_mb info["data_path"] = DEFAULT_DATA_PATH info["video_path"] = DEFAULT_VIDEO_PATH - info["fps"] = float(info["fps"]) + info["fps"] = int(info["fps"]) for key in info["features"]: if info["features"][key]["dtype"] == "video": # already has fps in video_info From d6a32e9742571a2fc96a02e143cfdb1a1be8940d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 16:32:34 +0200 Subject: [PATCH 103/158] chore(rl): move rl related code to its directory at top level (#2002) * chore(rl): move rl related code to its directory at top level * chore(style): apply pre-commit to renamed headers * test(rl): fix rl imports * docs(rl): update rl headers doc --- docs/source/hilserl.mdx | 16 ++++++++-------- docs/source/hilserl_sim.mdx | 8 ++++---- docs/source/il_sim.mdx | 8 ++++---- src/lerobot/{scripts => }/rl/actor.py | 15 ++++++++------- src/lerobot/{scripts => }/rl/crop_dataset_roi.py | 0 src/lerobot/{scripts => }/rl/eval_policy.py | 3 ++- src/lerobot/{scripts => }/rl/gym_manipulator.py | 0 src/lerobot/{scripts => }/rl/learner.py | 11 ++++++----- src/lerobot/{scripts => }/rl/learner_service.py | 0 tests/rl/test_actor.py | 10 +++++----- tests/rl/test_actor_learner.py | 12 ++++++------ tests/rl/test_learner_service.py | 2 +- 12 files changed, 44 insertions(+), 41 deletions(-) rename src/lerobot/{scripts => }/rl/actor.py (99%) rename src/lerobot/{scripts => }/rl/crop_dataset_roi.py (100%) rename src/lerobot/{scripts => }/rl/eval_policy.py (97%) rename src/lerobot/{scripts => }/rl/gym_manipulator.py (100%) rename src/lerobot/{scripts => }/rl/learner.py (99%) rename src/lerobot/{scripts => }/rl/learner_service.py (100%) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index f6bac1ffa..08301556f 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -518,7 +518,7 @@ During the online training, press `space` to take over the policy and `space` ag Start the recording process, an example of the config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json): ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config_so100.json +python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config_so100.json ``` During recording: @@ -549,7 +549,7 @@ Note: If you already know the crop parameters, you can skip this step and just s Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images: ```bash -python -m lerobot.scripts.rl.crop_dataset_roi --repo-id username/pick_lift_cube +python -m lerobot.rl.crop_dataset_roi --repo-id username/pick_lift_cube ``` 1. For each camera view, the script will display the first frame @@ -618,7 +618,7 @@ Before training, you need to collect a dataset with labeled examples. The `recor To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig. ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/reward_classifier_train_config.json +python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/reward_classifier_train_config.json ``` **Key Parameters for Data Collection** @@ -764,7 +764,7 @@ or set the argument in the json config file. Run `gym_manipulator.py` to test the model. ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config.json +python -m lerobot.rl.gym_manipulator --config_path path/to/env_config.json ``` The reward classifier will automatically provide rewards based on the visual input from the robot's cameras. @@ -777,7 +777,7 @@ The reward classifier will automatically provide rewards based on the visual inp 2. **Collect a dataset**: ```bash - python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json + python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json ``` 3. **Train the classifier**: @@ -788,7 +788,7 @@ The reward classifier will automatically provide rewards based on the visual inp 4. **Test the classifier**: ```bash - python -m lerobot.scripts.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json + python -m lerobot.rl.gym_manipulator --config_path src/lerobot/configs/env_config.json ``` ### Training with Actor-Learner @@ -810,7 +810,7 @@ Create a training configuration file (example available [here](https://huggingfa First, start the learner server process: ```bash -python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json +python -m lerobot.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json ``` The learner: @@ -825,7 +825,7 @@ The learner: In a separate terminal, start the actor process with the same configuration: ```bash -python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json +python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json ``` The actor: diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx index 77191fde3..e2dddd9ed 100644 --- a/docs/source/hilserl_sim.mdx +++ b/docs/source/hilserl_sim.mdx @@ -91,7 +91,7 @@ Important parameters: To run the environment, set mode to null: ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json +python -m lerobot.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` ### Recording a Dataset @@ -118,7 +118,7 @@ To collect a dataset, set the mode to `record` whilst defining the repo_id and n ``` ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.json +python -m lerobot.rl.gym_manipulator --config_path path/to/gym_hil_env.json ``` ### Training a Policy @@ -126,13 +126,13 @@ python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/gym_hil_env.j To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/rl/gym_hil/train_config.json) and run the actor and learner servers: ```bash -python -m lerobot.scripts.rl.actor --config_path path/to/train_gym_hil_env.json +python -m lerobot.rl.actor --config_path path/to/train_gym_hil_env.json ``` In a different terminal, run the learner server: ```bash -python -m lerobot.scripts.rl.learner --config_path path/to/train_gym_hil_env.json +python -m lerobot.rl.learner --config_path path/to/train_gym_hil_env.json ``` The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 6a615620b..9b7d7c111 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -61,14 +61,14 @@ Then we can run this command to start: ```bash -python -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json +python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json ``` ```bash -mjpython -m lerobot.scripts.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json +mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json ``` @@ -198,14 +198,14 @@ Then you can run this command to visualize your trained policy ```bash -python -m lerobot.scripts.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json +python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json ``` ```bash -mjpython -m lerobot.scripts.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json +mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json ``` diff --git a/src/lerobot/scripts/rl/actor.py b/src/lerobot/rl/actor.py similarity index 99% rename from src/lerobot/scripts/rl/actor.py rename to src/lerobot/rl/actor.py index baa284c4a..d1e709253 100644 --- a/src/lerobot/scripts/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -24,7 +24,7 @@ Examples of usage: - Start an actor server for real robot training with human-in-the-loop intervention: ```bash -python -m lerobot.scripts.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json +python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json ``` **NOTE**: The actor server requires a running learner server to connect to. Ensure the learner @@ -64,12 +64,6 @@ from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.processor import TransitionKey from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl.gym_manipulator import ( - create_transition, - make_processors, - make_robot_env, - step_env_and_process_transition, -) from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2, services_pb2_grpc @@ -96,6 +90,13 @@ from lerobot.utils.utils import ( init_logging, ) +from .gym_manipulator import ( + create_transition, + make_processors, + make_robot_env, + step_env_and_process_transition, +) + ACTOR_SHUTDOWN_TIMEOUT = 30 # Main entry point diff --git a/src/lerobot/scripts/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py similarity index 100% rename from src/lerobot/scripts/rl/crop_dataset_roi.py rename to src/lerobot/rl/crop_dataset_roi.py diff --git a/src/lerobot/scripts/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py similarity index 97% rename from src/lerobot/scripts/rl/eval_policy.py rename to src/lerobot/rl/eval_policy.py index aa97483b6..7cec66800 100644 --- a/src/lerobot/scripts/rl/eval_policy.py +++ b/src/lerobot/rl/eval_policy.py @@ -25,12 +25,13 @@ from lerobot.robots import ( # noqa: F401 make_robot_from_config, so100_follower, ) -from lerobot.scripts.rl.gym_manipulator import make_robot_env from lerobot.teleoperators import ( gamepad, # noqa: F401 so101_leader, # noqa: F401 ) +from .gym_manipulator import make_robot_env + logging.basicConfig(level=logging.INFO) diff --git a/src/lerobot/scripts/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py similarity index 100% rename from src/lerobot/scripts/rl/gym_manipulator.py rename to src/lerobot/rl/gym_manipulator.py diff --git a/src/lerobot/scripts/rl/learner.py b/src/lerobot/rl/learner.py similarity index 99% rename from src/lerobot/scripts/rl/learner.py rename to src/lerobot/rl/learner.py index 5d9953827..6441ba55f 100644 --- a/src/lerobot/scripts/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -25,7 +25,7 @@ Examples of usage: - Start a learner server for training: ```bash -python -m lerobot.scripts.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json +python -m lerobot.rl.learner --config_path src/lerobot/configs/train_config_hilserl_so100.json ``` **NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server @@ -73,7 +73,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.robots import so100_follower # noqa: F401 -from lerobot.scripts.rl import learner_service from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2_grpc @@ -100,6 +99,8 @@ from lerobot.utils.utils import ( ) from lerobot.utils.wandb_utils import WandBLogger +from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService + LOG_PREFIX = "[LEARNER]" @@ -639,7 +640,7 @@ def start_learner( # TODO: Check if its useful _ = ProcessSignalHandler(False, display_pid=True) - service = learner_service.LearnerService( + service = LearnerService( shutdown_event=shutdown_event, parameters_queue=parameters_queue, seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, @@ -649,7 +650,7 @@ def start_learner( ) server = grpc.server( - ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), + ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE), ("grpc.max_send_message_length", MAX_MESSAGE_SIZE), @@ -670,7 +671,7 @@ def start_learner( shutdown_event.wait() logging.info("[LEARNER] Stopping gRPC server...") - server.stop(learner_service.SHUTDOWN_TIMEOUT) + server.stop(SHUTDOWN_TIMEOUT) logging.info("[LEARNER] gRPC server stopped") diff --git a/src/lerobot/scripts/rl/learner_service.py b/src/lerobot/rl/learner_service.py similarity index 100% rename from src/lerobot/scripts/rl/learner_service.py rename to src/lerobot/rl/learner_service.py diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index f078b4602..aa9913bb2 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -65,7 +65,7 @@ def close_service_stub(channel, server): @require_package("grpc") def test_establish_learner_connection_success(): - from lerobot.scripts.rl.actor import establish_learner_connection + from lerobot.rl.actor import establish_learner_connection """Test successful connection establishment.""" stub, _servicer, channel, server = create_learner_service_stub() @@ -82,7 +82,7 @@ def test_establish_learner_connection_success(): @require_package("grpc") def test_establish_learner_connection_failure(): - from lerobot.scripts.rl.actor import establish_learner_connection + from lerobot.rl.actor import establish_learner_connection """Test connection failure.""" stub, servicer, channel, server = create_learner_service_stub() @@ -101,7 +101,7 @@ def test_establish_learner_connection_failure(): @require_package("grpc") def test_push_transitions_to_transport_queue(): - from lerobot.scripts.rl.actor import push_transitions_to_transport_queue + from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.transport.utils import bytes_to_transitions from tests.transport.test_transport_utils import assert_transitions_equal @@ -137,7 +137,7 @@ def test_push_transitions_to_transport_queue(): @require_package("grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_transitions_stream(): - from lerobot.scripts.rl.actor import transitions_stream + from lerobot.rl.actor import transitions_stream """Test transitions stream functionality.""" shutdown_event = Event() @@ -169,7 +169,7 @@ def test_transitions_stream(): @require_package("grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_interactions_stream(): - from lerobot.scripts.rl.actor import interactions_stream + from lerobot.rl.actor import interactions_stream from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes """Test interactions stream functionality.""" diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index b2a7a5d5f..43a6b0957 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -90,13 +90,13 @@ def cfg(): @require_package("grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_end_to_end_transitions_flow(cfg): - from lerobot.scripts.rl.actor import ( + from lerobot.rl.actor import ( establish_learner_connection, learner_service_client, push_transitions_to_transport_queue, send_transitions, ) - from lerobot.scripts.rl.learner import start_learner + from lerobot.rl.learner import start_learner from lerobot.transport.utils import bytes_to_transitions from tests.transport.test_transport_utils import assert_transitions_equal @@ -152,12 +152,12 @@ def test_end_to_end_transitions_flow(cfg): @require_package("grpc") @pytest.mark.timeout(10) def test_end_to_end_interactions_flow(cfg): - from lerobot.scripts.rl.actor import ( + from lerobot.rl.actor import ( establish_learner_connection, learner_service_client, send_interactions, ) - from lerobot.scripts.rl.learner import start_learner + from lerobot.rl.learner import start_learner from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes """Test complete interactions flow from actor to learner.""" @@ -226,8 +226,8 @@ def test_end_to_end_interactions_flow(cfg): @pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.timeout(10) def test_end_to_end_parameters_flow(cfg, data_size): - from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy - from lerobot.scripts.rl.learner import start_learner + from lerobot.rl.actor import establish_learner_connection, learner_service_client, receive_policy + from lerobot.rl.learner import start_learner from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes """Test complete parameter flow from learner to actor, with small and large data.""" diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py index f5e1e8d48..b0e61165a 100644 --- a/tests/rl/test_learner_service.py +++ b/tests/rl/test_learner_service.py @@ -50,7 +50,7 @@ def create_learner_service_stub( ): import grpc - from lerobot.scripts.rl.learner_service import LearnerService + from lerobot.rl.learner_service import LearnerService from lerobot.transport import services_pb2_grpc # generated from .proto """Fixture to start a LearnerService gRPC server and provide a connected stub.""" From 3068ce3569e122d910e66c0da0b1abc5282d65aa Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 17:43:55 +0200 Subject: [PATCH 104/158] docs(rl): fix path (#2004) --- docs/source/hilserl.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index 08301556f..07f92b824 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -62,7 +62,7 @@ pip install -e ".[hilserl]" ### Understanding Configuration -The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/scripts/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: +The training process begins with proper configuration for the HILSerl environment. The main configuration class is `GymManipulatorConfig` in `lerobot/rl/gym_manipulator.py`, which contains nested `HILSerlRobotEnvConfig` and `DatasetConfig`. The configuration is organized into focused, nested sub-configs: ```python From 1666097fd3588c7f9fa3f975af3a268a10085d0b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 17:55:53 +0200 Subject: [PATCH 105/158] refactor(scripts): update system info script (#2005) * refactor(scripts): update system info script * chore(scripts): rename info script * feat(scripts): add entrypoint for info * chore(ci): update issue report template --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- pyproject.toml | 1 + src/lerobot/scripts/display_sys_info.py | 90 ----------------------- src/lerobot/scripts/lerobot_info.py | 96 +++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 91 deletions(-) delete mode 100644 src/lerobot/scripts/display_sys_info.py create mode 100644 src/lerobot/scripts/lerobot_info.py diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 2fb23051c..7423495de 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -25,7 +25,7 @@ body: id: system-info attributes: label: System Info - description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below + description: Please share your LeRobot configuration by running `lerobot-info` (if installed) or `python -m lerobot.scripts.display_sys_info` (if not installed) and pasting the output below. render: Shell placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration validations: diff --git a/pyproject.toml b/pyproject.toml index 6db5e1307..9ee3c962f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,6 +171,7 @@ lerobot-setup-motors="lerobot.setup_motors:main" lerobot-teleoperate="lerobot.teleoperate:main" lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" +lerobot-info="lerobot.scripts.lerobot_info:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/scripts/display_sys_info.py b/src/lerobot/scripts/display_sys_info.py deleted file mode 100644 index 4d3cc291f..000000000 --- a/src/lerobot/scripts/display_sys_info.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Use this script to get a quick summary of your system config. -It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. -""" - -import platform - -HAS_HF_HUB = True -HAS_HF_DATASETS = True -HAS_NP = True -HAS_TORCH = True -HAS_LEROBOT = True - -try: - import huggingface_hub -except ImportError: - HAS_HF_HUB = False - -try: - import datasets -except ImportError: - HAS_HF_DATASETS = False - -try: - import numpy as np -except ImportError: - HAS_NP = False - -try: - import torch -except ImportError: - HAS_TORCH = False - -try: - import lerobot -except ImportError: - HAS_LEROBOT = False - - -lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A" -hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A" -hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A" -np_version = np.__version__ if HAS_NP else "N/A" - -torch_version = torch.__version__ if HAS_TORCH else "N/A" -torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" -cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" - - -# TODO(aliberts): refactor into an actual command `lerobot env` -def display_sys_info() -> dict: - """Run this to get basic system info to help for tracking issues & bugs.""" - info = { - "`lerobot` version": lerobot_version, - "Platform": platform.platform(), - "Python version": platform.python_version(), - "Huggingface_hub version": hf_hub_version, - "Dataset version": hf_datasets_version, - "Numpy version": np_version, - "PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})", - "Cuda version": cuda_version, - "Using GPU in script?": "", - # "Using distributed or parallel set-up in script?": "", - } - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") - print(format_dict(info)) - return info - - -def format_dict(d: dict) -> str: - return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n" - - -if __name__ == "__main__": - display_sys_info() diff --git a/src/lerobot/scripts/lerobot_info.py b/src/lerobot/scripts/lerobot_info.py new file mode 100644 index 000000000..9b49cad18 --- /dev/null +++ b/src/lerobot/scripts/lerobot_info.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Use this script to get a quick summary of your system config. +It should be able to run without any of LeRobot's dependencies or LeRobot itself installed. + +Example: + +```shell +lerobot-info +``` +""" + +import importlib +import platform + + +def get_package_version(package_name: str) -> str: + """Get the version of a package if it exists, otherwise return 'N/A'.""" + try: + module = importlib.import_module(package_name) + return getattr(module, "__version__", "Installed (version not found)") + except ImportError: + return "N/A" + + +def get_sys_info() -> dict: + """Run this to get basic system info to help for tracking issues & bugs.""" + # General package versions + info = { + "lerobot version": get_package_version("lerobot"), + "Platform": platform.platform(), + "Python version": platform.python_version(), + "Huggingface Hub version": get_package_version("huggingface_hub"), + "Datasets version": get_package_version("datasets"), + "Numpy version": get_package_version("numpy"), + } + + # PyTorch and GPU specific information + torch_version = "N/A" + torch_cuda_available = "N/A" + cuda_version = "N/A" + gpu_model = "N/A" + try: + import torch + + torch_version = torch.__version__ + torch_cuda_available = torch.cuda.is_available() + if torch_cuda_available: + cuda_version = torch.version.cuda + # Gets the name of the first available GPU + gpu_model = torch.cuda.get_device_name(0) + except ImportError: + # If torch is not installed, the default "N/A" values will be used. + pass + + info.update( + { + "PyTorch version": torch_version, + "Is PyTorch built with CUDA support?": torch_cuda_available, + "Cuda version": cuda_version, + "GPU model": gpu_model, + "Using GPU in script?": "", + } + ) + + return info + + +def format_dict_for_markdown(d: dict) -> str: + """Formats a dictionary into a markdown-friendly bulleted list.""" + return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + + +def main(): + system_info = get_sys_info() + print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") + print(format_dict_for_markdown(system_info)) + + +if __name__ == "__main__": + main() From c435d3cebc6395538158bee6d2919e6ee1f930ac Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 18:46:27 +0200 Subject: [PATCH 106/158] feat(script): add entry point for dataset viz (#2006) * chore(scripts): rename script dataset viz * feat(scripts): add entry point for dataset-viz --------- Signed-off-by: Steven Palma --- README.md | 6 +++--- pyproject.toml | 1 + src/lerobot/robots/viperx/README.md | 2 +- .../{visualize_dataset.py => lerobot_dataset_viz.py} | 6 +++--- tests/datasets/test_visualize_dataset.py | 2 +- 5 files changed, 9 insertions(+), 8 deletions(-) rename src/lerobot/scripts/{visualize_dataset.py => lerobot_dataset_viz.py} (98%) diff --git a/README.md b/README.md index a3f28f552..a59f96deb 100644 --- a/README.md +++ b/README.md @@ -202,7 +202,7 @@ Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/ You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: ```bash -python -m lerobot.scripts.visualize_dataset \ +lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 ``` @@ -210,7 +210,7 @@ python -m lerobot.scripts.visualize_dataset \ or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`) ```bash -python -m lerobot.scripts.visualize_dataset \ +lerobot-dataset-viz \ --repo-id lerobot/pusht \ --root ./my_local_data_dir \ --local-files-only 1 \ @@ -221,7 +221,7 @@ It will open `rerun.io` and display the camera streams, robot states and actions https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144 -Our script can also visualize datasets stored on a distant server. See `python -m lerobot.scripts.visualize_dataset --help` for more instructions. +Our script can also visualize datasets stored on a distant server. See `lerobot-dataset-viz --help` for more instructions. ### The `LeRobotDataset` format diff --git a/pyproject.toml b/pyproject.toml index 9ee3c962f..6fa054cde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,6 +171,7 @@ lerobot-setup-motors="lerobot.setup_motors:main" lerobot-teleoperate="lerobot.teleoperate:main" lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" +lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" # ---------------- Tool Configurations ---------------- diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index f6386215a..2e8fc7289 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -118,7 +118,7 @@ echo ${HF_USER}/aloha_test If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with [Rerun](https://github.com/rerun-io/rerun): ```bash -python -m lerobot.scripts.visualize_dataset \ +lerobot-dataset-viz \ --repo-id ${HF_USER}/aloha_test --episode 0 ``` diff --git a/src/lerobot/scripts/visualize_dataset.py b/src/lerobot/scripts/lerobot_dataset_viz.py similarity index 98% rename from src/lerobot/scripts/visualize_dataset.py rename to src/lerobot/scripts/lerobot_dataset_viz.py index dda12594a..2033b36ba 100644 --- a/src/lerobot/scripts/visualize_dataset.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -29,14 +29,14 @@ Examples: - Visualize data stored on a local machine: ``` -local$ python -m lerobot.scripts.visualize_dataset \ +local$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 ``` - Visualize data stored on a distant machine with a local viewer: ``` -distant$ python -m lerobot.scripts.visualize_dataset \ +distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --save 1 \ @@ -50,7 +50,7 @@ local$ rerun lerobot_pusht_episode_0.rrd (You need to forward the websocket port to the distant machine, with `ssh -L 9087:localhost:9087 username@remote-host`) ``` -distant$ python -m lerobot.scripts.visualize_dataset \ +distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --mode distant \ diff --git a/tests/datasets/test_visualize_dataset.py b/tests/datasets/test_visualize_dataset.py index 303342e3c..8e92ec82e 100644 --- a/tests/datasets/test_visualize_dataset.py +++ b/tests/datasets/test_visualize_dataset.py @@ -15,7 +15,7 @@ # limitations under the License. import pytest -from lerobot.scripts.visualize_dataset import visualize_dataset +from lerobot.scripts.lerobot_dataset_viz import visualize_dataset @pytest.mark.skip("TODO: add dummy videos") From c9787bd98aecf9ff00387b085c4839a975d4f5a1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 23 Sep 2025 18:47:36 +0200 Subject: [PATCH 107/158] feat(script): add entry point for image transform viz (#2007) * feat(Scripts): add entry point for img transform viz * chore(style): pre-commit style --- docs/source/lerobot-dataset-v3.mdx | 2 +- pyproject.toml | 1 + ...e_transforms.py => lerobot_imgtransform_viz.py} | 14 +++++++++----- tests/datasets/test_image_transforms.py | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) rename src/lerobot/scripts/{visualize_image_transforms.py => lerobot_imgtransform_viz.py} (97%) diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index 09fb17fad..cf1942fdc 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -246,7 +246,7 @@ You can also use any `torchvision.transforms.v2` transform by passing it directl Use the visualization script to preview how transforms affect your data: ```bash -python -m lerobot.scripts.visualize_image_transforms \ +lerobot-imgtransform-viz \ --repo-id=your-username/your-dataset \ --output-dir=./transform_examples \ --n-examples=5 diff --git a/pyproject.toml b/pyproject.toml index 6fa054cde..9ed3da006 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,7 @@ lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" +lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/scripts/visualize_image_transforms.py b/src/lerobot/scripts/lerobot_imgtransform_viz.py similarity index 97% rename from src/lerobot/scripts/visualize_image_transforms.py rename to src/lerobot/scripts/lerobot_imgtransform_viz.py index 14caf89df..bc13f0508 100644 --- a/src/lerobot/scripts/visualize_image_transforms.py +++ b/src/lerobot/scripts/lerobot_imgtransform_viz.py @@ -20,10 +20,10 @@ Additionally, each individual transform can be visualized separately as well as Example: ```bash -python -m lerobot.scripts.visualize_image_transforms \ - --repo_id=lerobot/pusht \ - --episodes='[0]' \ - --image_transforms.enable=True +lerobot-imgtransform-viz \ + --repo_id=lerobot/pusht \ + --episodes='[0]' \ + --image_transforms.enable=True ``` """ @@ -126,5 +126,9 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples) -if __name__ == "__main__": +def main(): visualize_image_transforms() + + +if __name__ == "__main__": + main() diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 3ab93cb2c..98f957076 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -29,7 +29,7 @@ from lerobot.datasets.transforms import ( SharpnessJitter, make_transform_from_config, ) -from lerobot.scripts.visualize_image_transforms import ( +from lerobot.scripts.lerobot_imgtransform_viz import ( save_all_transforms, save_each_transform, ) From 7cf04a5ec38536a184a0a70c475c91b74a127083 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 11:11:53 +0200 Subject: [PATCH 108/158] chore: move constants to utils (#2016) --- examples/training/train_with_streaming.py | 2 +- src/lerobot/configs/policies.py | 2 +- src/lerobot/datasets/lerobot_dataset.py | 2 +- src/lerobot/datasets/pipeline_features.py | 2 +- src/lerobot/datasets/streaming_dataset.py | 2 +- .../datasets/v30/convert_dataset_v21_to_v30.py | 2 +- src/lerobot/envs/configs.py | 2 +- src/lerobot/optim/optimizers.py | 4 ++-- src/lerobot/optim/schedulers.py | 2 +- src/lerobot/policies/act/modeling_act.py | 2 +- src/lerobot/policies/act/processor_act.py | 2 +- src/lerobot/policies/diffusion/modeling_diffusion.py | 2 +- .../policies/diffusion/processor_diffusion.py | 2 +- src/lerobot/policies/factory.py | 2 +- src/lerobot/policies/pi0/modeling_pi0.py | 2 +- src/lerobot/policies/pi0/processor_pi0.py | 2 +- src/lerobot/policies/pi0fast/modeling_pi0fast.py | 2 +- src/lerobot/policies/pi0fast/processor_pi0fast.py | 2 +- src/lerobot/policies/sac/configuration_sac.py | 2 +- src/lerobot/policies/sac/processor_sac.py | 2 +- .../policies/sac/reward_model/modeling_classifier.py | 2 +- src/lerobot/policies/smolvla/modeling_smolvla.py | 2 +- src/lerobot/policies/smolvla/processor_smolvla.py | 2 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 2 +- src/lerobot/policies/tdmpc/processor_tdmpc.py | 2 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 2 +- src/lerobot/policies/vqbet/processor_vqbet.py | 2 +- src/lerobot/processor/batch_processor.py | 2 +- .../processor/joint_observations_processor.py | 2 +- src/lerobot/processor/observation_processor.py | 2 +- src/lerobot/processor/pipeline.py | 2 +- src/lerobot/processor/tokenizer_processor.py | 2 +- src/lerobot/rl/learner.py | 12 ++++++------ src/lerobot/robots/robot.py | 2 +- src/lerobot/robots/stretch3/robot_stretch3.py | 2 +- src/lerobot/robots/viperx/viperx.py | 2 +- src/lerobot/scripts/server/helpers.py | 2 +- src/lerobot/teleoperators/teleoperator.py | 2 +- src/lerobot/{ => utils}/constants.py | 0 src/lerobot/utils/random_utils.py | 2 +- src/lerobot/utils/train_utils.py | 12 ++++++------ src/lerobot/utils/wandb_utils.py | 2 +- tests/datasets/test_datasets.py | 2 +- tests/fixtures/constants.py | 2 +- tests/optim/test_optimizers.py | 8 ++++---- tests/optim/test_schedulers.py | 2 +- tests/policies/test_policies.py | 2 +- tests/processor/test_act_processor.py | 2 +- tests/processor/test_batch_processor.py | 2 +- tests/processor/test_classifier_processor.py | 2 +- tests/processor/test_device_processor.py | 4 ++-- tests/processor/test_diffusion_processor.py | 2 +- tests/processor/test_observation_processor.py | 2 +- tests/processor/test_pi0_processor.py | 2 +- tests/processor/test_sac_processor.py | 2 +- tests/processor/test_smolvla_processor.py | 2 +- tests/processor/test_tdmpc_processor.py | 2 +- tests/processor/test_tokenizer_processor.py | 2 +- tests/processor/test_vqbet_processor.py | 2 +- tests/utils/test_train_utils.py | 2 +- 60 files changed, 74 insertions(+), 74 deletions(-) rename src/lerobot/{ => utils}/constants.py (100%) diff --git a/examples/training/train_with_streaming.py b/examples/training/train_with_streaming.py index e7edc17f8..185be5b13 100644 --- a/examples/training/train_with_streaming.py +++ b/examples/training/train_with_streaming.py @@ -20,13 +20,13 @@ from pathlib import Path import torch from lerobot.configs.types import FeatureType -from lerobot.constants import ACTION from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import dataset_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors +from lerobot.utils.constants import ACTION def main(): diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 9a2bb911a..06c220cb8 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -27,9 +27,9 @@ from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.hub import HubMixin from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 4ac7a841c..9eebcea4b 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -31,7 +31,6 @@ import torch.utils from huggingface_hub import HfApi, snapshot_download from huggingface_hub.errors import RevisionNotFoundError -from lerobot.constants import HF_LEROBOT_HOME from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.datasets.image_writer import AsyncImageWriter, write_image from lerobot.datasets.utils import ( @@ -79,6 +78,7 @@ from lerobot.datasets.video_utils import ( get_video_duration_in_s, get_video_info, ) +from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index b55ccf8a9..cdf0b7448 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -17,9 +17,9 @@ from collections.abc import Sequence from typing import Any from lerobot.configs.types import PipelineFeatureType -from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import DataProcessorPipeline +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE def create_initial_features( diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index e354c4060..c3c48d90d 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -21,7 +21,6 @@ import numpy as np import torch from datasets import load_dataset -from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.datasets.utils import ( Backtrackable, @@ -38,6 +37,7 @@ from lerobot.datasets.video_utils import ( VideoDecoderCache, decode_video_frames_torchcodec, ) +from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE class StreamingLeRobotDataset(torch.utils.data.IterableDataset): diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 1327bd820..e5a6e3c9a 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -46,7 +46,6 @@ from datasets import Dataset, Features, Image from huggingface_hub import HfApi, snapshot_download from requests import HTTPError -from lerobot.constants import HF_LEROBOT_HOME from lerobot.datasets.compute_stats import aggregate_stats from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.datasets.utils import ( @@ -71,6 +70,7 @@ from lerobot.datasets.utils import ( write_tasks, ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s +from lerobot.utils.constants import HF_LEROBOT_HOME V21 = "v2.1" diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 8c66b278e..4456c51a5 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -19,9 +19,9 @@ from typing import Any import draccus from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.robots import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE @dataclass diff --git a/src/lerobot/optim/optimizers.py b/src/lerobot/optim/optimizers.py index ece4dc157..f2bd0df42 100644 --- a/src/lerobot/optim/optimizers.py +++ b/src/lerobot/optim/optimizers.py @@ -22,11 +22,11 @@ import draccus import torch from safetensors.torch import load_file, save_file -from lerobot.constants import ( +from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json +from lerobot.utils.constants import ( OPTIMIZER_PARAM_GROUPS, OPTIMIZER_STATE, ) -from lerobot.datasets.utils import flatten_dict, unflatten_dict, write_json from lerobot.utils.io_utils import deserialize_json_into_object diff --git a/src/lerobot/optim/schedulers.py b/src/lerobot/optim/schedulers.py index d08018175..55ee62e40 100644 --- a/src/lerobot/optim/schedulers.py +++ b/src/lerobot/optim/schedulers.py @@ -22,8 +22,8 @@ import draccus from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from lerobot.constants import SCHEDULER_STATE from lerobot.datasets.utils import write_json +from lerobot.utils.constants import SCHEDULER_STATE from lerobot.utils.io_utils import deserialize_json_into_object diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e0f3462cc..e4ebec199 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -33,9 +33,9 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d -from lerobot.constants import ACTION, OBS_IMAGES from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION, OBS_IMAGES class ACTPolicy(PreTrainedPolicy): diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index b0d2067e9..727b18cef 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -17,7 +17,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -29,6 +28,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_act_pre_post_processors( diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 747ead334..0bd2e282b 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -33,7 +33,6 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor, nn -from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import ( @@ -42,6 +41,7 @@ from lerobot.policies.utils import ( get_output_shape, populate_queues, ) +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE class DiffusionPolicy(PreTrainedPolicy): diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 4383ec950..a7799be64 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -30,6 +29,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_diffusion_pre_post_processors( diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 06c0c4ba5..60c05240e 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -24,7 +24,6 @@ from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig @@ -46,6 +45,7 @@ from lerobot.processor.converters import ( transition_to_batch, transition_to_policy_action, ) +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def get_policy_class(name: str) -> type[PreTrainedPolicy]: diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 66bd81e61..4d3f4ffa1 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -57,13 +57,13 @@ import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertConfig, PaliGemmaWithExpertModel, ) from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.utils import get_safe_dtype diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index cd9712201..50f5dec83 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -19,7 +19,6 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -35,6 +34,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME @ProcessorStepRegistry.register(name="pi0_new_line_processor") diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 682a372f4..102cfb8fa 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -57,9 +57,9 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe from transformers.cache_utils import HybridCache, StaticCache from transformers.models.auto import CONFIG_MAPPING -from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION, OBS_STATE PRECISION = { "float16": torch.float16, diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index 81314aa37..95b5e541b 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -30,6 +29,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_pi0fast_pre_post_processors( diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index c57eeeb72..a42758b85 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -19,8 +19,8 @@ from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.optim.optimizers import MultiAdamConfig +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def is_image_feature(key: str) -> bool: diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 9e8013d31..cf90e3cb4 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -19,7 +19,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -31,6 +30,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_sac_pre_post_processors( diff --git a/src/lerobot/policies/sac/reward_model/modeling_classifier.py b/src/lerobot/policies/sac/reward_model/modeling_classifier.py index ca501c3a7..dba6a174b 100644 --- a/src/lerobot/policies/sac/reward_model/modeling_classifier.py +++ b/src/lerobot/policies/sac/reward_model/modeling_classifier.py @@ -19,9 +19,9 @@ import logging import torch from torch import Tensor, nn -from lerobot.constants import OBS_IMAGE, REWARD from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.utils.constants import OBS_IMAGE, REWARD class ClassifierOutput: diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 48d4b2315..23fc3ca4f 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -59,13 +59,13 @@ import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn -from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel from lerobot.policies.utils import ( populate_queues, ) +from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE from lerobot.utils.utils import get_safe_dtype diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index ac3cd4626..3fc130aa1 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -19,7 +19,6 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_smolvla_pre_post_processors( diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index e160310b3..f83048862 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -35,10 +35,10 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD class TDMPCPolicy(PreTrainedPolicy): diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 75a7d4f7e..9b6f97e50 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -18,7 +18,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -30,6 +29,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_tdmpc_pre_post_processors( diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index bb6040e90..34e5b1c0d 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -27,11 +27,11 @@ import torch.nn.functional as F # noqa: N812 import torchvision from torch import Tensor, nn -from lerobot.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.vqbet_utils import GPT, ResidualVQ +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE # ruff: noqa: N806 diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index 1c741cd33..1e19ff779 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -19,7 +19,6 @@ from typing import Any import torch -from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -31,6 +30,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME def make_vqbet_pre_post_processors( diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index a563599cd..e1a90421f 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -25,7 +25,7 @@ from dataclasses import dataclass, field from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from .core import EnvTransition, PolicyAction from .pipeline import ( diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/processor/joint_observations_processor.py index ab3c6ecc1..2fbcc7c46 100644 --- a/src/lerobot/processor/joint_observations_processor.py +++ b/src/lerobot/processor/joint_observations_processor.py @@ -20,12 +20,12 @@ from typing import Any import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_STATE from lerobot.processor.pipeline import ( ObservationProcessorStep, ProcessorStepRegistry, ) from lerobot.robots import Robot +from lerobot.utils.constants import OBS_STATE @dataclass diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 71fdbbf0d..2b9402bee 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,7 +21,7 @@ import torch from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from .pipeline import ObservationProcessorStep, ProcessorStepRegistry diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 1c88cd741..e14d8b0b9 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -422,7 +422,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): """ if save_directory is None: # Use default directory in HF_LEROBOT_HOME - from lerobot.constants import HF_LEROBOT_HOME + from lerobot.utils.constants import HF_LEROBOT_HOME sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) save_directory = HF_LEROBOT_HOME / "processors" / sanitized_name diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 23db7b5e3..2ef89c107 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 6441ba55f..8d6831286 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -62,12 +62,6 @@ from torch.optim.optimizer import Optimizer from lerobot.cameras import opencv # noqa: F401 from lerobot.configs import parser from lerobot.configs.train import TrainRLServerPipelineConfig -from lerobot.constants import ( - CHECKPOINTS_DIR, - LAST_CHECKPOINT_LINK, - PRETRAINED_MODEL_DIR, - TRAINING_STATE_DIR, -) from lerobot.datasets.factory import make_dataset from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy @@ -83,6 +77,12 @@ from lerobot.transport.utils import ( state_to_bytes, ) from lerobot.utils.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.utils.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, +) from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index 2a9004380..5e88b915b 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -19,8 +19,8 @@ from typing import Any import draccus -from lerobot.constants import HF_LEROBOT_CALIBRATION, ROBOTS from lerobot.motors import MotorCalibration +from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig diff --git a/src/lerobot/robots/stretch3/robot_stretch3.py b/src/lerobot/robots/stretch3/robot_stretch3.py index b907d6a3f..8a0ff5c6a 100644 --- a/src/lerobot/robots/stretch3/robot_stretch3.py +++ b/src/lerobot/robots/stretch3/robot_stretch3.py @@ -22,8 +22,8 @@ from stretch_body.robot import Robot as StretchAPI from stretch_body.robot_params import RobotParams from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.constants import OBS_IMAGES, OBS_STATE from lerobot.datasets.utils import get_nested_item +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from ..robot import Robot from .configuration_stretch3 import Stretch3RobotConfig diff --git a/src/lerobot/robots/viperx/viperx.py b/src/lerobot/robots/viperx/viperx.py index 881640cd5..006c780e3 100644 --- a/src/lerobot/robots/viperx/viperx.py +++ b/src/lerobot/robots/viperx/viperx.py @@ -18,13 +18,13 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.constants import OBS_STATE from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.utils.constants import OBS_STATE from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/scripts/server/helpers.py index d8051b76e..175cecf6d 100644 --- a/src/lerobot/scripts/server/helpers.py +++ b/src/lerobot/scripts/server/helpers.py @@ -22,12 +22,12 @@ from pathlib import Path import torch from lerobot.configs.types import PolicyFeature -from lerobot.constants import OBS_IMAGES, OBS_STATE from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.utils import init_logging Action = torch.Tensor diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index c360ee7bb..95020a962 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -19,8 +19,8 @@ from typing import Any import draccus -from lerobot.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from lerobot.motors.motors_bus import MotorCalibration +from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig diff --git a/src/lerobot/constants.py b/src/lerobot/utils/constants.py similarity index 100% rename from src/lerobot/constants.py rename to src/lerobot/utils/constants.py diff --git a/src/lerobot/utils/random_utils.py b/src/lerobot/utils/random_utils.py index da3ecf37f..1bb1f0631 100644 --- a/src/lerobot/utils/random_utils.py +++ b/src/lerobot/utils/random_utils.py @@ -23,8 +23,8 @@ import numpy as np import torch from safetensors.torch import load_file, save_file -from lerobot.constants import RNG_STATE from lerobot.datasets.utils import flatten_dict, unflatten_dict +from lerobot.utils.constants import RNG_STATE def serialize_python_rng_state() -> dict[str, torch.Tensor]: diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index be2eb8146..08d1bcc9d 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -21,18 +21,18 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from lerobot.configs.train import TrainPipelineConfig -from lerobot.constants import ( +from lerobot.datasets.utils import load_json, write_json +from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state +from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import PolicyProcessorPipeline +from lerobot.utils.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, TRAINING_STEP, ) -from lerobot.datasets.utils import load_json, write_json -from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state -from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import PolicyProcessorPipeline from lerobot.utils.random_utils import load_rng_state, save_rng_state diff --git a/src/lerobot/utils/wandb_utils.py b/src/lerobot/utils/wandb_utils.py index 91b4ec95c..b13254421 100644 --- a/src/lerobot/utils/wandb_utils.py +++ b/src/lerobot/utils/wandb_utils.py @@ -23,7 +23,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from termcolor import colored from lerobot.configs.train import TrainPipelineConfig -from lerobot.constants import PRETRAINED_MODEL_DIR +from lerobot.utils.constants import PRETRAINED_MODEL_DIR def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 2eca82346..d1d6dbdb2 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -112,7 +112,7 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory): # and test the small resulting function that validates the features def test_dataset_feature_with_forward_slash_raises_error(): # make sure dir does not exist - from lerobot.constants import HF_LEROBOT_HOME + from lerobot.utils.constants import HF_LEROBOT_HOME dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" # make sure does not exist diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 0af499364..973c5b050 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import HF_LEROBOT_HOME LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py index 4152c7f8d..d18565562 100644 --- a/tests/optim/test_optimizers.py +++ b/tests/optim/test_optimizers.py @@ -14,10 +14,6 @@ import pytest import torch -from lerobot.constants import ( - OPTIMIZER_PARAM_GROUPS, - OPTIMIZER_STATE, -) from lerobot.optim.optimizers import ( AdamConfig, AdamWConfig, @@ -26,6 +22,10 @@ from lerobot.optim.optimizers import ( load_optimizer_state, save_optimizer_state, ) +from lerobot.utils.constants import ( + OPTIMIZER_PARAM_GROUPS, + OPTIMIZER_STATE, +) @pytest.mark.parametrize( diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 43851c458..1e566a6ba 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -13,7 +13,6 @@ # limitations under the License. from torch.optim.lr_scheduler import LambdaLR -from lerobot.constants import SCHEDULER_STATE from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, DiffuserSchedulerConfig, @@ -21,6 +20,7 @@ from lerobot.optim.schedulers import ( load_scheduler_state, save_scheduler_state, ) +from lerobot.utils.constants import SCHEDULER_STATE def test_diffuser_scheduler(optimizer): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 28c395bfc..b577e5763 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -27,7 +27,6 @@ from lerobot import available_policies from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config @@ -42,6 +41,7 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index f96f871aa..00a4dbb96 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.processor_act import make_act_pre_post_processors from lerobot.processor import ( @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_STATE def create_default_config(): diff --git a/tests/processor/test_batch_processor.py b/tests/processor/test_batch_processor.py index f7cbafd27..5c94b0657 100644 --- a/tests/processor/test_batch_processor.py +++ b/tests/processor/test_batch_processor.py @@ -21,7 +21,6 @@ import numpy as np import pytest import torch -from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor import ( AddBatchDimensionProcessorStep, DataProcessorPipeline, @@ -29,6 +28,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE def test_state_1d_to_2d(): diff --git a/tests/processor/test_classifier_processor.py b/tests/processor/test_classifier_processor.py index 139e99bd7..e1567bf29 100644 --- a/tests/processor/test_classifier_processor.py +++ b/tests/processor/test_classifier_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import OBS_IMAGE, OBS_STATE from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor from lerobot.processor import ( @@ -32,6 +31,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def create_default_config(): diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index ba00bde4d..10ee313d7 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -284,8 +284,8 @@ def test_features(): def test_integration_with_robot_processor(): """Test integration with RobotProcessor.""" - from lerobot.constants import OBS_STATE from lerobot.processor import AddBatchDimensionProcessorStep + from lerobot.utils.constants import OBS_STATE # Create a pipeline with DeviceProcessorStep device_processor = DeviceProcessorStep(device="cpu") @@ -948,12 +948,12 @@ def test_simulated_accelerate_scenario(): def test_policy_processor_integration(): """Test integration with policy processors - input on GPU, output on CPU.""" from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - from lerobot.constants import ACTION, OBS_STATE from lerobot.processor import ( AddBatchDimensionProcessorStep, NormalizerProcessorStep, UnnormalizerProcessorStep, ) + from lerobot.utils.constants import ACTION, OBS_STATE # Create features and stats features = { diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py index 5d280f9cc..67981c70d 100644 --- a/tests/processor/test_diffusion_processor.py +++ b/tests/processor/test_diffusion_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors from lerobot.processor import ( @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def create_default_config(): diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 57f32482d..6abc9edef 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -19,9 +19,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType -from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.processor import TransitionKey, VanillaObservationProcessorStep from lerobot.processor.converters import create_transition +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py index c481cb18f..24afc648f 100644 --- a/tests/processor/test_pi0_processor.py +++ b/tests/processor/test_pi0_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors from lerobot.processor import ( @@ -35,6 +34,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE class MockTokenizerProcessorStep(ProcessorStep): diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py index 7cbcb1882..a1a4b285d 100644 --- a/tests/processor/test_sac_processor.py +++ b/tests/processor/test_sac_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_STATE from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors from lerobot.processor import ( @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_STATE def create_default_config(): diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py index ce162c10d..227b1dc35 100644 --- a/tests/processor/test_smolvla_processor.py +++ b/tests/processor/test_smolvla_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.smolvla.processor_smolvla import ( SmolVLANewLineProcessor, @@ -38,6 +37,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE class MockTokenizerProcessorStep(ProcessorStep): diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py index 20979fd6d..edbc25ae3 100644 --- a/tests/processor/test_tdmpc_processor.py +++ b/tests/processor/test_tdmpc_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors from lerobot.processor import ( @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def create_default_config(): diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index b3b0c9bfc..9e6c8de2f 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -9,9 +9,9 @@ import pytest import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.constants import OBS_LANGUAGE from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_LANGUAGE from tests.utils import require_package diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py index 98e05eae8..47e41dff4 100644 --- a/tests/processor/test_vqbet_processor.py +++ b/tests/processor/test_vqbet_processor.py @@ -21,7 +21,6 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors from lerobot.processor import ( @@ -34,6 +33,7 @@ from lerobot.processor import ( UnnormalizerProcessorStep, ) from lerobot.processor.converters import create_transition, transition_to_batch +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def create_default_config(): diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 2d963d7ae..0eeaf907c 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -14,7 +14,7 @@ from pathlib import Path from unittest.mock import Mock, patch -from lerobot.constants import ( +from lerobot.utils.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, OPTIMIZER_PARAM_GROUPS, From 1033680a57e1b1a9a6045606b6de8d892b09a666 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 11:14:23 +0200 Subject: [PATCH 109/158] chore: move errors to utils (#2017) Signed-off-by: Steven Palma --- src/lerobot/cameras/opencv/camera_opencv.py | 2 +- src/lerobot/cameras/reachy2_camera/reachy2_camera.py | 2 +- src/lerobot/cameras/realsense/camera_realsense.py | 2 +- src/lerobot/motors/motors_bus.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- src/lerobot/robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 2 +- src/lerobot/robots/so100_follower/so100_follower.py | 2 +- src/lerobot/robots/so101_follower/so101_follower.py | 2 +- src/lerobot/robots/viperx/viperx.py | 2 +- src/lerobot/teleoperators/homunculus/homunculus_arm.py | 2 +- src/lerobot/teleoperators/homunculus/homunculus_glove.py | 2 +- src/lerobot/teleoperators/keyboard/teleop_keyboard.py | 2 +- src/lerobot/teleoperators/koch_leader/koch_leader.py | 2 +- src/lerobot/teleoperators/phone/teleop_phone.py | 2 +- src/lerobot/teleoperators/so100_leader/so100_leader.py | 2 +- src/lerobot/teleoperators/so101_leader/so101_leader.py | 2 +- src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py | 2 +- src/lerobot/teleoperators/widowx/widowx.py | 2 +- src/lerobot/{ => utils}/errors.py | 0 tests/cameras/test_opencv.py | 2 +- tests/cameras/test_reachy2_camera.py | 2 +- tests/cameras/test_realsense.py | 2 +- tests/mocks/mock_robot.py | 2 +- tests/mocks/mock_teleop.py | 2 +- 27 files changed, 26 insertions(+), 26 deletions(-) rename src/lerobot/{ => utils}/errors.py (100%) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 3665a909f..50e55f0c2 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -31,7 +31,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" import cv2 import numpy as np -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..camera import Camera from ..utils import get_cv2_backend, get_cv2_rotation diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 0daeb6bbb..c96789f96 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -31,7 +31,7 @@ import numpy as np from reachy2_sdk.media.camera import CameraView from reachy2_sdk.media.camera_manager import CameraManager -from lerobot.errors import DeviceNotConnectedError +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from .configuration_reachy2_camera import ColorMode, Reachy2CameraConfig diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 12ce89c91..cc816e552 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -29,7 +29,7 @@ try: except Exception as e: logging.info(f"Could not import realsense: {e}") -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 97830fc35..dca7650e0 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -32,7 +32,7 @@ import serial from deepdiff import DeepDiff from tqdm import tqdm -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.utils import enter_pressed, move_cursor_up NameOrID: TypeAlias = str | int diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 0e3a615a9..baa36b560 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -20,12 +20,12 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorNormMode from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 8dc100e06..9e960642b 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -20,12 +20,12 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorNormMode from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from .config_hope_jr import HopeJrHandConfig diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 563325b88..41a57828b 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -20,12 +20,12 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 7004cc0fe..357109cb0 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -23,12 +23,12 @@ from typing import Any import numpy as np from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 9a8001401..9f6367152 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,7 +23,7 @@ from typing import Any import cv2 import numpy as np -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from .config_lekiwi import LeKiwiClientConfig diff --git a/src/lerobot/robots/so100_follower/so100_follower.py b/src/lerobot/robots/so100_follower/so100_follower.py index 1e117e80b..d660ebed4 100644 --- a/src/lerobot/robots/so100_follower/so100_follower.py +++ b/src/lerobot/robots/so100_follower/so100_follower.py @@ -20,12 +20,12 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py index 31b06c2fd..acfd4bd11 100644 --- a/src/lerobot/robots/so101_follower/so101_follower.py +++ b/src/lerobot/robots/so101_follower/so101_follower.py @@ -20,12 +20,12 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/robots/viperx/viperx.py b/src/lerobot/robots/viperx/viperx.py index 006c780e3..31e99ffdb 100644 --- a/src/lerobot/robots/viperx/viperx.py +++ b/src/lerobot/robots/viperx/viperx.py @@ -18,13 +18,13 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) from lerobot.utils.constants import OBS_STATE +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot from ..utils import ensure_safe_goal_position diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 6f5137af9..4eca4b9e2 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -22,8 +22,8 @@ from typing import Deque import serial -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 7b0ced9f6..52fd19def 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -22,10 +22,10 @@ from typing import Deque import serial -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import MotorCalibration from lerobot.motors.motors_bus import MotorNormMode from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 7f489b25a..6f53a17c7 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,7 +21,7 @@ import time from queue import Queue from typing import Any -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from ..utils import TeleopEvents diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py index f703d5b6e..0409f2e57 100644 --- a/src/lerobot/teleoperators/koch_leader/koch_leader.py +++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py @@ -17,13 +17,13 @@ import logging import time -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DriveMode, DynamixelMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_koch_leader import KochLeaderConfig diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index c90729efa..91e613190 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -26,9 +26,9 @@ import hebi import numpy as np from teleop import Teleop -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.teleoperator import Teleoperator +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.rotation import Rotation logger = logging.getLogger(__name__) diff --git a/src/lerobot/teleoperators/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so100_leader/so100_leader.py index a8f6d29b5..edcfe53e6 100644 --- a/src/lerobot/teleoperators/so100_leader/so100_leader.py +++ b/src/lerobot/teleoperators/so100_leader/so100_leader.py @@ -17,12 +17,12 @@ import logging import time -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_so100_leader import SO100LeaderConfig diff --git a/src/lerobot/teleoperators/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so101_leader/so101_leader.py index 15a363e37..be804bf70 100644 --- a/src/lerobot/teleoperators/so101_leader/so101_leader.py +++ b/src/lerobot/teleoperators/so101_leader/so101_leader.py @@ -17,12 +17,12 @@ import logging import time -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_so101_leader import SO101LeaderConfig diff --git a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py index bdcb57d40..94e1ca7cc 100644 --- a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py +++ b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py @@ -20,7 +20,7 @@ import numpy as np from stretch_body.gamepad_teleop import GamePadTeleop from stretch_body.robot_params import RobotParams -from lerobot.errors import DeviceAlreadyConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError from ..teleoperator import Teleoperator from .configuration_stretch3 import Stretch3GamePadConfig diff --git a/src/lerobot/teleoperators/widowx/widowx.py b/src/lerobot/teleoperators/widowx/widowx.py index 6becd767f..1a00bd4d2 100644 --- a/src/lerobot/teleoperators/widowx/widowx.py +++ b/src/lerobot/teleoperators/widowx/widowx.py @@ -17,13 +17,13 @@ import logging import time -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import ( DriveMode, DynamixelMotorsBus, OperatingMode, ) +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_widowx import WidowXConfig diff --git a/src/lerobot/errors.py b/src/lerobot/utils/errors.py similarity index 100% rename from src/lerobot/errors.py rename to src/lerobot/utils/errors.py diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py index a9c060c4f..a3d98a679 100644 --- a/tests/cameras/test_opencv.py +++ b/tests/cameras/test_opencv.py @@ -26,7 +26,7 @@ import pytest from lerobot.cameras.configs import Cv2Rotation from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError # NOTE(Steven): more tests + assertions? TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py index 66c7675a6..0b38e8b0b 100644 --- a/tests/cameras/test_reachy2_camera.py +++ b/tests/cameras/test_reachy2_camera.py @@ -21,7 +21,7 @@ import numpy as np import pytest from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig -from lerobot.errors import DeviceNotConnectedError +from lerobot.utils.errors import DeviceNotConnectedError PARAMS = [ ("teleop", "left"), diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py index 4b3fbae82..fb9912257 100644 --- a/tests/cameras/test_realsense.py +++ b/tests/cameras/test_realsense.py @@ -26,7 +26,7 @@ import numpy as np import pytest from lerobot.cameras.configs import Cv2Rotation -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError pytest.importorskip("pyrealsense2") diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index 8108c7c25..027ee45ed 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -20,8 +20,8 @@ from functools import cached_property from typing import Any from lerobot.cameras import CameraConfig, make_cameras_from_configs -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.robots import Robot, RobotConfig +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @RobotConfig.register_subclass("mock_robot") diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index e37d4a2c5..71b49947c 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -19,8 +19,8 @@ from dataclasses import dataclass from functools import cached_property from typing import Any -from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.teleoperators import Teleoperator, TeleoperatorConfig +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @TeleoperatorConfig.register_subclass("mock_teleop") From bd09b2153f0bcd2cce8d97142f6361a964d06690 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 11:14:48 +0200 Subject: [PATCH 110/158] chore(scripts): move find_cameras to scripts (#2018) --- pyproject.toml | 2 +- .../{find_cameras.py => scripts/lerobot_find_cameras.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/lerobot/{find_cameras.py => scripts/lerobot_find_cameras.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 9ed3da006..3f5ef3f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,7 +163,7 @@ all = [ [project.scripts] lerobot-calibrate="lerobot.calibrate:main" -lerobot-find-cameras="lerobot.find_cameras:main" +lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" lerobot-find-port="lerobot.find_port:main" lerobot-record="lerobot.record:main" lerobot-replay="lerobot.replay:main" diff --git a/src/lerobot/find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py similarity index 100% rename from src/lerobot/find_cameras.py rename to src/lerobot/scripts/lerobot_find_cameras.py From a4178f385b05434b7c60c5190316678d4e2e31a4 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 11:28:56 +0200 Subject: [PATCH 111/158] feat(script): add entry point for find joints limits (#2010) Signed-off-by: Steven Palma --- docs/source/hilserl.mdx | 16 +++++++-------- pyproject.toml | 1 + src/lerobot/rl/actor.py | 2 +- ...limits.py => lerobot_find_joint_limits.py} | 20 +++++++++++-------- 4 files changed, 22 insertions(+), 17 deletions(-) rename src/lerobot/scripts/{find_joint_limits.py => lerobot_find_joint_limits.py} (93%) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index 07f92b824..bc38408e6 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -304,19 +304,19 @@ Before collecting demonstrations, you need to determine the appropriate operatio This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates. -**Using find_joint_limits.py** +**Using lerobot-find-joint-limits** This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training. Bounding the action space will reduce the redundant exploration of the agent and guarantees safety. ```bash -python -m lerobot.scripts.find_joint_limits \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.id=black \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ - --teleop.id=blue +lerobot-find-joint-limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue ``` **Workflow** diff --git a/pyproject.toml b/pyproject.toml index 3f5ef3f87..acd1e8a0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,6 +173,7 @@ lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" +lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" # ---------------- Tool Configurations ---------------- diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index d1e709253..b38858ca6 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -35,7 +35,7 @@ gamepad to take control of the robot during training. Initially intervene freque reduce interventions as the policy improves. **WORKFLOW**: -1. Determine robot workspace bounds using `find_joint_limits.py` +1. Determine robot workspace bounds using `lerobot-find-joint-limits` 2. Record demonstrations with `gym_manipulator.py` in record mode 3. Process the dataset and determine camera crops with `crop_dataset_roi.py` 4. Start the learner server with the training configuration diff --git a/src/lerobot/scripts/find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py similarity index 93% rename from src/lerobot/scripts/find_joint_limits.py rename to src/lerobot/scripts/lerobot_find_joint_limits.py index f7e07514f..07d57a760 100644 --- a/src/lerobot/scripts/find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -20,13 +20,13 @@ Simple script to control a robot from teleoperation. Example: ```shell -python -m lerobot.scripts.server.find_joint_limits \ - --robot.type=so100_follower \ - --robot.port=/dev/tty.usbmodem58760431541 \ - --robot.id=black \ - --teleop.type=so100_leader \ - --teleop.port=/dev/tty.usbmodem58760431551 \ - --teleop.id=blue +lerobot-find-joint-limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue ``` """ @@ -117,5 +117,9 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig): busy_wait(0.01) -if __name__ == "__main__": +def main(): find_joint_and_ee_bounds() + + +if __name__ == "__main__": + main() From 98bcda2d8bb2ecc272a71c8a89948549898f5a84 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 11:38:04 +0200 Subject: [PATCH 112/158] chore(scripts): move find_port to scripts (#2019) --- pyproject.toml | 2 +- src/lerobot/scripts/lerobot_find_cameras.py | 2 +- src/lerobot/{find_port.py => scripts/lerobot_find_port.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename src/lerobot/{find_port.py => scripts/lerobot_find_port.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index acd1e8a0c..9785481ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,7 +164,7 @@ all = [ [project.scripts] lerobot-calibrate="lerobot.calibrate:main" lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" -lerobot-find-port="lerobot.find_port:main" +lerobot-find-port="lerobot.scripts.lerobot_find_port:main" lerobot-record="lerobot.record:main" lerobot-replay="lerobot.replay:main" lerobot-setup-motors="lerobot.setup_motors:main" diff --git a/src/lerobot/scripts/lerobot_find_cameras.py b/src/lerobot/scripts/lerobot_find_cameras.py index ec8f5ff30..e17dca805 100644 --- a/src/lerobot/scripts/lerobot_find_cameras.py +++ b/src/lerobot/scripts/lerobot_find_cameras.py @@ -24,7 +24,7 @@ lerobot-find-cameras ``` """ -# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot.find_cameras realsense` flag to avoid confusion. +# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot-find-cameras realsense` flag to avoid confusion. # NOTE(Steven): macOS cameras sometimes report different FPS at init time, not an issue here as we don't specify FPS when opening the cameras, but the information displayed might not be truthful. import argparse diff --git a/src/lerobot/find_port.py b/src/lerobot/scripts/lerobot_find_port.py similarity index 100% rename from src/lerobot/find_port.py rename to src/lerobot/scripts/lerobot_find_port.py From 42e4b3d09e10b84806c470a47bf38ac5eb834e9a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 12:01:21 +0200 Subject: [PATCH 113/158] chore(scripts): move teleop to scripts (#2023) --- pyproject.toml | 2 +- src/lerobot/{teleoperate.py => scripts/lerobot_teleoperate.py} | 0 tests/test_control_robot.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/lerobot/{teleoperate.py => scripts/lerobot_teleoperate.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 9785481ee..0a4096a40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ lerobot-find-port="lerobot.scripts.lerobot_find_port:main" lerobot-record="lerobot.record:main" lerobot-replay="lerobot.replay:main" lerobot-setup-motors="lerobot.setup_motors:main" -lerobot-teleoperate="lerobot.teleoperate:main" +lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" diff --git a/src/lerobot/teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py similarity index 100% rename from src/lerobot/teleoperate.py rename to src/lerobot/scripts/lerobot_teleoperate.py diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 374f98129..8df71e040 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -19,7 +19,7 @@ from unittest.mock import patch from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay -from lerobot.teleoperate import TeleoperateConfig, teleoperate +from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.mocks.mock_teleop import MockTeleopConfig From 2b59850f15fea55b24b4a6ce3eab932d3100cf4b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 13:38:12 +0200 Subject: [PATCH 114/158] chore(scripts): move record to scripts (#2022) Signed-off-by: Steven Palma --- examples/lekiwi/evaluate.py | 2 +- examples/lekiwi/record.py | 2 +- examples/phone_to_so100/evaluate.py | 2 +- examples/phone_to_so100/record.py | 2 +- examples/so100_to_so100_EE/evaluate.py | 2 +- examples/so100_to_so100_EE/record.py | 2 +- pyproject.toml | 2 +- src/lerobot/{record.py => scripts/lerobot_record.py} | 0 tests/test_control_robot.py | 2 +- 9 files changed, 8 insertions(+), 8 deletions(-) rename src/lerobot/{record.py => scripts/lerobot_record.py} (100%) diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 3dbb10f56..8993a5e14 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -19,8 +19,8 @@ from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors -from lerobot.record import record_loop from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index f5d109d5d..f59093b26 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -17,9 +17,9 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import make_default_processors -from lerobot.record import record_loop from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.control_utils import init_keyboard_listener diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index e76b11350..c7d6eb240 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -34,13 +34,13 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 768041d63..6681017a0 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -26,7 +26,6 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( EEBoundsAndSafety, @@ -36,6 +35,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index fd10bf865..f47a216d6 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -34,13 +34,13 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index abb8fb99d..60c96835f 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -27,7 +27,6 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.record import record_loop from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig from lerobot.robots.so100_follower.robot_kinematic_processor import ( EEBoundsAndSafety, @@ -35,6 +34,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener diff --git a/pyproject.toml b/pyproject.toml index 0a4096a40..69c0fa2b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,7 +165,7 @@ all = [ lerobot-calibrate="lerobot.calibrate:main" lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" lerobot-find-port="lerobot.scripts.lerobot_find_port:main" -lerobot-record="lerobot.record:main" +lerobot-record="lerobot.scripts.lerobot_record:main" lerobot-replay="lerobot.replay:main" lerobot-setup-motors="lerobot.setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" diff --git a/src/lerobot/record.py b/src/lerobot/scripts/lerobot_record.py similarity index 100% rename from src/lerobot/record.py rename to src/lerobot/scripts/lerobot_record.py diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 8df71e040..239f6a0e3 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -17,8 +17,8 @@ from unittest.mock import patch from lerobot.calibrate import CalibrateConfig, calibrate -from lerobot.record import DatasetRecordConfig, RecordConfig, record from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay +from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig From acbc14f60a56d5138f57721c15d71cee45baa4e7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 14:06:48 +0200 Subject: [PATCH 115/158] chore(scripts): move calibrate to scripts (#2024) Signed-off-by: Steven Palma --- pyproject.toml | 2 +- src/lerobot/{calibrate.py => scripts/lerobot_calibrate.py} | 0 tests/test_control_robot.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/lerobot/{calibrate.py => scripts/lerobot_calibrate.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 69c0fa2b8..893d637f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,7 +162,7 @@ all = [ ] [project.scripts] -lerobot-calibrate="lerobot.calibrate:main" +lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main" lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" lerobot-find-port="lerobot.scripts.lerobot_find_port:main" lerobot-record="lerobot.scripts.lerobot_record:main" diff --git a/src/lerobot/calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py similarity index 100% rename from src/lerobot/calibrate.py rename to src/lerobot/scripts/lerobot_calibrate.py diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 239f6a0e3..a1dd33286 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -16,8 +16,8 @@ from unittest.mock import patch -from lerobot.calibrate import CalibrateConfig, calibrate from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay +from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID From 13010647bc5579dcf30753ec26b92db2cd91207f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 14:06:58 +0200 Subject: [PATCH 116/158] chore(scripts): move setup_motors to scripts (#2020) Signed-off-by: Steven Palma --- pyproject.toml | 2 +- .../{setup_motors.py => scripts/lerobot_setup_motors.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/lerobot/{setup_motors.py => scripts/lerobot_setup_motors.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index 893d637f6..f2be357e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,7 @@ lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" lerobot-find-port="lerobot.scripts.lerobot_find_port:main" lerobot-record="lerobot.scripts.lerobot_record:main" lerobot-replay="lerobot.replay:main" -lerobot-setup-motors="lerobot.setup_motors:main" +lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-eval="lerobot.scripts.eval:main" lerobot-train="lerobot.scripts.train:main" diff --git a/src/lerobot/setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py similarity index 100% rename from src/lerobot/setup_motors.py rename to src/lerobot/scripts/lerobot_setup_motors.py From 7359e18eb65e33649951c5f7c959cc8fc8cf2ab2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 14:48:23 +0200 Subject: [PATCH 117/158] chore(scripts): move replay to scripts (#2021) Signed-off-by: Steven Palma --- pyproject.toml | 2 +- src/lerobot/{replay.py => scripts/lerobot_replay.py} | 0 tests/test_control_robot.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename src/lerobot/{replay.py => scripts/lerobot_replay.py} (100%) diff --git a/pyproject.toml b/pyproject.toml index f2be357e9..dbc25805d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,7 +166,7 @@ lerobot-calibrate="lerobot.scripts.lerobot_calibrate:main" lerobot-find-cameras="lerobot.scripts.lerobot_find_cameras:main" lerobot-find-port="lerobot.scripts.lerobot_find_port:main" lerobot-record="lerobot.scripts.lerobot_record:main" -lerobot-replay="lerobot.replay:main" +lerobot-replay="lerobot.scripts.lerobot_replay:main" lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-eval="lerobot.scripts.eval:main" diff --git a/src/lerobot/replay.py b/src/lerobot/scripts/lerobot_replay.py similarity index 100% rename from src/lerobot/replay.py rename to src/lerobot/scripts/lerobot_replay.py diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index a1dd33286..ace0aea49 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -16,9 +16,9 @@ from unittest.mock import patch -from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.scripts.lerobot_calibrate import CalibrateConfig, calibrate from lerobot.scripts.lerobot_record import DatasetRecordConfig, RecordConfig, record +from lerobot.scripts.lerobot_replay import DatasetReplayConfig, ReplayConfig, replay from lerobot.scripts.lerobot_teleoperate import TeleoperateConfig, teleoperate from tests.fixtures.constants import DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig From 1cba47da20482f81016e162925063180da8dbbf6 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 14:49:37 +0200 Subject: [PATCH 118/158] chore(async): move async related code to its directory at top level (#2003) * chore(async): move async related code to its directory at top level * chore(style): apply pre-commit to renamed headers * test(async): fix async imports * docs(async): update async headers doc --- docs/source/async.mdx | 16 ++++++++-------- .../server => async_inference}/configs.py | 3 ++- .../server => async_inference}/constants.py | 0 .../server => async_inference}/helpers.py | 0 .../policy_server.py | 19 ++++++++++--------- .../robot_client.py | 19 ++++++++++--------- tests/async_inference/test_e2e.py | 8 ++++---- tests/async_inference/test_helpers.py | 4 ++-- tests/async_inference/test_policy_server.py | 8 ++++---- tests/async_inference/test_robot_client.py | 8 ++++---- 10 files changed, 44 insertions(+), 41 deletions(-) rename src/lerobot/{scripts/server => async_inference}/configs.py (99%) rename src/lerobot/{scripts/server => async_inference}/constants.py (100%) rename src/lerobot/{scripts/server => async_inference}/helpers.py (100%) rename src/lerobot/{scripts/server => async_inference}/policy_server.py (98%) rename src/lerobot/{scripts/server => async_inference}/robot_client.py (98%) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 397c513cf..c66cdb143 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -31,7 +31,7 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif You can spin up a policy server running: ```shell -python src/lerobot/scripts/server/policy_server.py \ +python src/lerobot/async_inference/policy_server.py \ --host=127.0.0.1 \ --port=8080 \ ``` @@ -39,7 +39,7 @@ python src/lerobot/scripts/server/policy_server.py \ This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with: ```shell -python src/lerobot/scripts/server/robot_client.py \ +python src/lerobot/async_inference/robot_client.py \ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server --robot.type=so100_follower \ # ROBOT: your robot type --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port @@ -122,8 +122,8 @@ python -m lerobot.scripts.server.policy_server \ ```python -from lerobot.scripts.server.configs import PolicyServerConfig -from lerobot.scripts.server.policy_server import serve +from lerobot.async_inference.configs import PolicyServerConfig +from lerobot.async_inference.policy_server import serve config = PolicyServerConfig( host="localhost", @@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio ```bash -python src/lerobot/scripts/server/robot_client.py \ +python src/lerobot/async_inference/robot_client.py \ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server --robot.type=so100_follower \ # ROBOT: your robot type --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port @@ -171,9 +171,9 @@ python src/lerobot/scripts/server/robot_client.py \ import threading from lerobot.robots.so100_follower import SO100FollowerConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.scripts.server.configs import RobotClientConfig -from lerobot.scripts.server.robot_client import RobotClient -from lerobot.scripts.server.helpers import visualize_action_queue_size +from lerobot.async_inference.configs import RobotClientConfig +from lerobot.async_inference.robot_client import RobotClient +from lerobot.async_inference.helpers import visualize_action_queue_size # 1. Create the robot instance """Check out the cameras available in your setup by running `python lerobot/find_cameras.py`""" diff --git a/src/lerobot/scripts/server/configs.py b/src/lerobot/async_inference/configs.py similarity index 99% rename from src/lerobot/scripts/server/configs.py rename to src/lerobot/async_inference/configs.py index 5be46485e..24f889df1 100644 --- a/src/lerobot/scripts/server/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -18,7 +18,8 @@ from dataclasses import dataclass, field import torch from lerobot.robots.config import RobotConfig -from lerobot.scripts.server.constants import ( + +from .constants import ( DEFAULT_FPS, DEFAULT_INFERENCE_LATENCY, DEFAULT_OBS_QUEUE_TIMEOUT, diff --git a/src/lerobot/scripts/server/constants.py b/src/lerobot/async_inference/constants.py similarity index 100% rename from src/lerobot/scripts/server/constants.py rename to src/lerobot/async_inference/constants.py diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/async_inference/helpers.py similarity index 100% rename from src/lerobot/scripts/server/helpers.py rename to src/lerobot/async_inference/helpers.py diff --git a/src/lerobot/scripts/server/policy_server.py b/src/lerobot/async_inference/policy_server.py similarity index 98% rename from src/lerobot/scripts/server/policy_server.py rename to src/lerobot/async_inference/policy_server.py index 0ed446d3a..125727060 100644 --- a/src/lerobot/scripts/server/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -15,7 +15,7 @@ """ Example: ```shell -python src/lerobot/scripts/server/policy_server.py \ +python src/lerobot/async_inference/policy_server.py \ --host=127.0.0.1 \ --port=8080 \ --fps=30 \ @@ -38,9 +38,15 @@ import grpc import torch from lerobot.policies.factory import get_policy_class -from lerobot.scripts.server.configs import PolicyServerConfig -from lerobot.scripts.server.constants import SUPPORTED_POLICIES -from lerobot.scripts.server.helpers import ( +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import receive_bytes_in_chunks + +from .configs import PolicyServerConfig +from .constants import SUPPORTED_POLICIES +from .helpers import ( FPSTracker, Observation, RemotePolicyConfig, @@ -50,11 +56,6 @@ from lerobot.scripts.server.helpers import ( observations_similar, raw_observation_to_observation, ) -from lerobot.transport import ( - services_pb2, # type: ignore - services_pb2_grpc, # type: ignore -) -from lerobot.transport.utils import receive_bytes_in_chunks class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): diff --git a/src/lerobot/scripts/server/robot_client.py b/src/lerobot/async_inference/robot_client.py similarity index 98% rename from src/lerobot/scripts/server/robot_client.py rename to src/lerobot/async_inference/robot_client.py index 939d5cea8..c969bc605 100644 --- a/src/lerobot/scripts/server/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -15,7 +15,7 @@ """ Example command: ```shell -python src/lerobot/scripts/server/robot_client.py \ +python src/lerobot/async_inference/robot_client.py \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ @@ -57,9 +57,15 @@ from lerobot.robots import ( # noqa: F401 so100_follower, so101_follower, ) -from lerobot.scripts.server.configs import RobotClientConfig -from lerobot.scripts.server.constants import SUPPORTED_ROBOTS -from lerobot.scripts.server.helpers import ( +from lerobot.transport import ( + services_pb2, # type: ignore + services_pb2_grpc, # type: ignore +) +from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks + +from .configs import RobotClientConfig +from .constants import SUPPORTED_ROBOTS +from .helpers import ( Action, FPSTracker, Observation, @@ -72,11 +78,6 @@ from lerobot.scripts.server.helpers import ( validate_robot_cameras_for_policy, visualize_action_queue_size, ) -from lerobot.transport import ( - services_pb2, # type: ignore - services_pb2_grpc, # type: ignore -) -from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks class RobotClient: diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 1c0400e66..2689f0618 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -48,11 +48,11 @@ def test_async_inference_e2e(monkeypatch): # Import grpc-dependent modules inside the test function import grpc + from lerobot.async_inference.configs import PolicyServerConfig, RobotClientConfig + from lerobot.async_inference.helpers import map_robot_keys_to_lerobot_features + from lerobot.async_inference.policy_server import PolicyServer + from lerobot.async_inference.robot_client import RobotClient from lerobot.robots.utils import make_robot_from_config - from lerobot.scripts.server.configs import PolicyServerConfig, RobotClientConfig - from lerobot.scripts.server.helpers import map_robot_keys_to_lerobot_features - from lerobot.scripts.server.policy_server import PolicyServer - from lerobot.scripts.server.robot_client import RobotClient from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index e0b797371..f1c7636e2 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -19,8 +19,7 @@ import time import numpy as np import torch -from lerobot.configs.types import FeatureType, PolicyFeature -from lerobot.scripts.server.helpers import ( +from lerobot.async_inference.helpers import ( FPSTracker, TimedAction, TimedObservation, @@ -30,6 +29,7 @@ from lerobot.scripts.server.helpers import ( raw_observation_to_observation, resize_robot_observation_image, ) +from lerobot.configs.types import FeatureType, PolicyFeature # --------------------------------------------------------------------- # FPSTracker diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index 5c795e7ec..c5c52460f 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -65,8 +65,8 @@ class MockPolicy: def policy_server(): """Fresh `PolicyServer` instance with a stubbed-out policy model.""" # Import only when the test actually runs (after decorator check) - from lerobot.scripts.server.configs import PolicyServerConfig - from lerobot.scripts.server.policy_server import PolicyServer + from lerobot.async_inference.configs import PolicyServerConfig + from lerobot.async_inference.policy_server import PolicyServer test_config = PolicyServerConfig(host="localhost", port=9999) server = PolicyServer(test_config) @@ -95,7 +95,7 @@ def policy_server(): def _make_obs(state: torch.Tensor, timestep: int = 0, must_go: bool = False): """Create a TimedObservation with a given state vector.""" # Import only when needed - from lerobot.scripts.server.helpers import TimedObservation + from lerobot.async_inference.helpers import TimedObservation return TimedObservation( observation={ @@ -191,7 +191,7 @@ def test_obs_sanity_checks(policy_server): def test_predict_action_chunk(monkeypatch, policy_server): """End-to-end test of `_predict_action_chunk` with a stubbed _get_action_chunk.""" # Import only when needed - from lerobot.scripts.server.policy_server import PolicyServer + from lerobot.async_inference.policy_server import PolicyServer # Force server to act-style policy; patch method to return deterministic tensor policy_server.policy_type = "act" diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index 51db2c3a7..dfdb8ce42 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -38,8 +38,8 @@ def robot_client(): """Fresh `RobotClient` instance for each test case (no threads started). Uses DummyRobot.""" # Import only when the test actually runs (after decorator check) - from lerobot.scripts.server.configs import RobotClientConfig - from lerobot.scripts.server.robot_client import RobotClient + from lerobot.async_inference.configs import RobotClientConfig + from lerobot.async_inference.robot_client import RobotClient from tests.mocks.mock_robot import MockRobotConfig test_config = MockRobotConfig() @@ -73,7 +73,7 @@ def robot_client(): def _make_actions(start_ts: float, start_t: int, count: int): """Generate `count` consecutive TimedAction objects starting at timestep `start_t`.""" - from lerobot.scripts.server.helpers import TimedAction + from lerobot.async_inference.helpers import TimedAction fps = 30 # emulates most common frame-rate actions = [] @@ -124,7 +124,7 @@ def test_aggregate_action_queues_combines_actions_in_overlap( ): """`_aggregate_action_queues` must combine actions on overlapping timesteps according to the provided aggregate_fn, here tested with multiple coefficients.""" - from lerobot.scripts.server.helpers import TimedAction + from lerobot.async_inference.helpers import TimedAction robot_client.chunks_received = 0 From cdd2bf1c4e060e9ef3d43701a2346e36949041ce Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 15:46:44 +0200 Subject: [PATCH 119/158] chore(ci): update stale message (#2027) --- .github/workflows/stale.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index acd1ae53a..af91c9f58 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -31,11 +31,11 @@ env: Feel free to reopen if is still relevant, or to ping a collaborator if you have any questions. WARN_ISSUE_MESSAGE: > This issue has been automatically marked as stale because it has not had - recent activity (1 year). It will be closed if no further activity occurs. + recent activity (6 months). It will be closed if no further activity occurs. Thank you for your contributions. WARN_PR_MESSAGE: > This PR has been automatically marked as stale because it has not had - recent activity (1 year). It will be closed if no further activity occurs. + recent activity (6 months). It will be closed if no further activity occurs. Thank you for your contributions. jobs: From 163df97c0cd15fe4e15bae05fa637664b71e1209 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 16:17:39 +0200 Subject: [PATCH 120/158] fix(docs): update outdated links (#2026) --- docs/source/koch.mdx | 2 +- docs/source/lekiwi.mdx | 2 +- docs/source/smolvla.mdx | 4 ++-- docs/source/so100.mdx | 2 +- docs/source/so101.mdx | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx index 3e94899a8..813b9bd67 100644 --- a/docs/source/koch.mdx +++ b/docs/source/koch.mdx @@ -277,7 +277,7 @@ leader.disconnect() -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots) > [!TIP] > If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 14c06e444..875394d71 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -323,7 +323,7 @@ To replay an episode run the API example below, make sure to change `remote_ip`, python examples/lekiwi/replay.py ``` -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./il_robots) ## Evaluate your policy diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index 89c475a90..a28e7cb44 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -29,7 +29,7 @@ SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed ## Collect a dataset SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup. -We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset) +We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](./il_robots) @@ -93,7 +93,7 @@ lerobot-train --help ## Evaluate the finetuned model and run it in real-time -Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset). +Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./il_robots). Once you are logged in, you can run inference in your setup by doing: ```bash diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx index 8578e1e8d..3c73ae801 100644 --- a/docs/source/so100.mdx +++ b/docs/source/so100.mdx @@ -634,7 +634,7 @@ leader.disconnect()
-Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots) > [!TIP] > If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index b9fb9cab4..00ec3eb74 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -430,7 +430,7 @@ leader.disconnect() -Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./il_robots) > [!TIP] > If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). From af1760f1757aa1436716422264acaeac11e7e320 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 16:46:38 +0200 Subject: [PATCH 121/158] chore(utils): move benchmark and buffer to their respective modules (#2028) --- {src/lerobot/utils => benchmarks/video}/benchmark.py | 0 benchmarks/video/run_video_benchmark.py | 2 +- src/lerobot/{utils => rl}/buffer.py | 0 src/lerobot/rl/learner.py | 2 +- tests/utils/test_replay_buffer.py | 2 +- 5 files changed, 3 insertions(+), 3 deletions(-) rename {src/lerobot/utils => benchmarks/video}/benchmark.py (100%) rename src/lerobot/{utils => rl}/buffer.py (100%) diff --git a/src/lerobot/utils/benchmark.py b/benchmarks/video/benchmark.py similarity index 100% rename from src/lerobot/utils/benchmark.py rename to benchmarks/video/benchmark.py diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index 5472551f5..f041a9066 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -35,12 +35,12 @@ import torch from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity from tqdm import tqdm +from benchmarks.video.benchmark import TimeBenchmark from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.video_utils import ( decode_video_frames_torchvision, encode_video_frames, ) -from lerobot.utils.benchmark import TimeBenchmark BASE_ENCODING = OrderedDict( [ diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/rl/buffer.py similarity index 100% rename from src/lerobot/utils/buffer.py rename to src/lerobot/rl/buffer.py diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 8d6831286..6fd9fb86e 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -66,6 +66,7 @@ from lerobot.datasets.factory import make_dataset from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy +from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions from lerobot.robots import so100_follower # noqa: F401 from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -76,7 +77,6 @@ from lerobot.transport.utils import ( bytes_to_transitions, state_to_bytes, ) -from lerobot.utils.buffer import ReplayBuffer, concatenate_batch_transitions from lerobot.utils.constants import ( CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 8781c5c0d..b5254f393 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -21,7 +21,7 @@ import pytest import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized from tests.fixtures.constants import DUMMY_REPO_ID From ec63225dc150a713cfd576cdfbcac4bd069114a4 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 16:47:37 +0200 Subject: [PATCH 122/158] chore(utils): move encoding utils and process to their respective modules (#2029) Signed-off-by: Steven Palma --- src/lerobot/motors/dynamixel/dynamixel.py | 2 +- src/lerobot/{utils => motors}/encoding_utils.py | 0 src/lerobot/motors/feetech/feetech.py | 2 +- src/lerobot/rl/actor.py | 2 +- src/lerobot/rl/learner.py | 2 +- src/lerobot/{utils => rl}/process.py | 0 tests/motors/test_dynamixel.py | 2 +- tests/motors/test_feetech.py | 2 +- tests/utils/test_encoding_utils.py | 2 +- tests/utils/test_process.py | 2 +- 10 files changed, 8 insertions(+), 8 deletions(-) rename src/lerobot/{utils => motors}/encoding_utils.py (100%) rename src/lerobot/{utils => rl}/process.py (100%) diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index 1113ec0f7..e1d4e0963 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -22,7 +22,7 @@ import logging from copy import deepcopy from enum import Enum -from lerobot.utils.encoding_utils import decode_twos_complement, encode_twos_complement +from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address from .tables import ( diff --git a/src/lerobot/utils/encoding_utils.py b/src/lerobot/motors/encoding_utils.py similarity index 100% rename from src/lerobot/utils/encoding_utils.py rename to src/lerobot/motors/encoding_utils.py diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 88d45ba39..2ea57af12 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -17,7 +17,7 @@ from copy import deepcopy from enum import Enum from pprint import pformat -from lerobot.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude +from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address from .tables import ( diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index b38858ca6..2606481d3 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -63,6 +63,7 @@ from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.processor import TransitionKey +from lerobot.rl.process import ProcessSignalHandler from lerobot.robots import so100_follower # noqa: F401 from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -75,7 +76,6 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) -from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.queue import get_last_item_from_queue from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import busy_wait diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 6fd9fb86e..1ff343760 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -67,6 +67,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.rl.process import ProcessSignalHandler from lerobot.robots import so100_follower # noqa: F401 from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -83,7 +84,6 @@ from lerobot.utils.constants import ( PRETRAINED_MODEL_DIR, TRAINING_STATE_DIR, ) -from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( get_step_checkpoint_dir, diff --git a/src/lerobot/utils/process.py b/src/lerobot/rl/process.py similarity index 100% rename from src/lerobot/utils/process.py rename to src/lerobot/rl/process.py diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py index e0dbe713a..8b02d4330 100644 --- a/tests/motors/test_dynamixel.py +++ b/tests/motors/test_dynamixel.py @@ -24,7 +24,7 @@ import pytest from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus from lerobot.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE -from lerobot.utils.encoding_utils import encode_twos_complement +from lerobot.motors.encoding_utils import encode_twos_complement try: import dynamixel_sdk as dxl diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py index 31e4a9018..673276e05 100644 --- a/tests/motors/test_feetech.py +++ b/tests/motors/test_feetech.py @@ -22,9 +22,9 @@ from unittest.mock import MagicMock, patch import pytest from lerobot.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.motors.encoding_utils import encode_sign_magnitude from lerobot.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus from lerobot.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE -from lerobot.utils.encoding_utils import encode_sign_magnitude try: import scservo_sdk as scs diff --git a/tests/utils/test_encoding_utils.py b/tests/utils/test_encoding_utils.py index 813942862..8a0231221 100644 --- a/tests/utils/test_encoding_utils.py +++ b/tests/utils/test_encoding_utils.py @@ -16,7 +16,7 @@ import pytest -from lerobot.utils.encoding_utils import ( +from lerobot.motors.encoding_utils import ( decode_sign_magnitude, decode_twos_complement, encode_sign_magnitude, diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index 61e6e2c73..e2b00cae9 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -22,7 +22,7 @@ from unittest.mock import patch import pytest -from lerobot.utils.process import ProcessSignalHandler +from lerobot.rl.process import ProcessSignalHandler # Fixture to reset shutdown_event_counter and original signal handlers before and after each test From 853cc70194381794d5f810848f64f28643034daf Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 17:10:27 +0200 Subject: [PATCH 123/158] chore(utils): remove unused utils legacy functions + rename init_rerun (#2031) --- docs/source/il_robots.mdx | 8 ++--- examples/lekiwi/evaluate.py | 4 +-- examples/lekiwi/record.py | 4 +-- examples/lekiwi/teleoperate.py | 4 +-- examples/phone_to_so100/evaluate.py | 4 +-- examples/phone_to_so100/record.py | 4 +-- examples/phone_to_so100/teleoperate.py | 4 +-- examples/so100_to_so100_EE/evaluate.py | 4 +-- examples/so100_to_so100_EE/record.py | 4 +-- examples/so100_to_so100_EE/teleoperate.py | 4 +-- src/lerobot/scripts/lerobot_record.py | 4 +-- src/lerobot/scripts/lerobot_teleoperate.py | 4 +-- src/lerobot/utils/robot_utils.py | 14 -------- src/lerobot/utils/train_utils.py | 6 ---- src/lerobot/utils/utils.py | 39 +--------------------- src/lerobot/utils/visualization_utils.py | 2 +- 16 files changed, 28 insertions(+), 85 deletions(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 19b62167e..91df14028 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -200,7 +200,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun from lerobot.record import record_loop NUM_EPISODES = 5 @@ -237,7 +237,7 @@ dataset = LeRobotDataset.create( # Initialize the keyboard listener and rerun visualization _, events = init_keyboard_listener() -_init_rerun(session_name="recording") +init_rerun(session_name="recording") # Connect the robot and teleoperator robot.connect() @@ -517,7 +517,7 @@ from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerCon from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun from lerobot.record import record_loop from lerobot.policies.factory import make_processor @@ -557,7 +557,7 @@ dataset = LeRobotDataset.create( # Initialize the keyboard listener and rerun visualization _, events = init_keyboard_listener() -_init_rerun(session_name="recording") +init_rerun(session_name="recording") # Connect the robot robot.connect() diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 8993a5e14..32a5e0a2b 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -23,7 +23,7 @@ from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 2 FPS = 30 @@ -73,7 +73,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="lekiwi_evaluate") +init_rerun(session_name="lekiwi_evaluate") if not robot.is_connected: raise ValueError("Robot is not connected!") diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index f59093b26..30f34e718 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -24,7 +24,7 @@ from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 2 FPS = 30 @@ -69,7 +69,7 @@ keyboard.connect() # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="lekiwi_record") +init_rerun(session_name="lekiwi_record") if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: raise ValueError("Robot or teleop is not connected!") diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index cde4000df..6b430df48 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -20,7 +20,7 @@ from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.robot_utils import busy_wait -from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data FPS = 30 @@ -41,7 +41,7 @@ leader_arm.connect() keyboard.connect() # Init rerun viewer -_init_rerun(session_name="lekiwi_teleop") +init_rerun(session_name="lekiwi_teleop") if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: raise ValueError("Robot or teleop is not connected!") diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index c7d6eb240..0d53f1177 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -43,7 +43,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 5 FPS = 30 @@ -137,7 +137,7 @@ robot.connect() # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="phone_so100_evaluate") +init_rerun(session_name="phone_so100_evaluate") if not robot.is_connected: raise ValueError("Robot is not connected!") diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 6681017a0..bb2e2f5f7 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -41,7 +41,7 @@ from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAct from lerobot.teleoperators.phone.teleop_phone import Phone from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 2 FPS = 30 @@ -143,7 +143,7 @@ phone.connect() # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="phone_so100_record") +init_rerun(session_name="phone_so100_record") if not robot.is_connected or not phone.is_connected: raise ValueError("Robot or teleop is not connected!") diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index eb5ed3526..6c49a8453 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -33,7 +33,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone from lerobot.utils.robot_utils import busy_wait -from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data FPS = 30 @@ -87,7 +87,7 @@ robot.connect() teleop_device.connect() # Init rerun viewer -_init_rerun(session_name="phone_so100_teleop") +init_rerun(session_name="phone_so100_teleop") if not robot.is_connected or not teleop_device.is_connected: raise ValueError("Robot or teleop is not connected!") diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index f47a216d6..53a385442 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -43,7 +43,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 5 FPS = 30 @@ -138,7 +138,7 @@ robot.connect() # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="so100_so100_evaluate") +init_rerun(session_name="so100_so100_evaluate") if not robot.is_connected: raise ValueError("Robot is not connected!") diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index 60c96835f..6c38553e2 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -39,7 +39,7 @@ from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderCo from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say -from lerobot.utils.visualization_utils import _init_rerun +from lerobot.utils.visualization_utils import init_rerun NUM_EPISODES = 2 FPS = 30 @@ -143,7 +143,7 @@ follower.connect() # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() -_init_rerun(session_name="recording_phone") +init_rerun(session_name="recording_phone") if not leader.is_connected or not follower.is_connected: raise ValueError("Robot or teleop is not connected!") diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index ab54e7236..aa9755788 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -33,7 +33,7 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader from lerobot.utils.robot_utils import busy_wait -from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data FPS = 30 @@ -95,7 +95,7 @@ follower.connect() leader.connect() # Init rerun viewer -_init_rerun(session_name="so100_so100_EE_teleop") +init_rerun(session_name="so100_so100_EE_teleop") print("Starting teleop loop...") while True: diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d09b017e4..dd4984fab 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -122,7 +122,7 @@ from lerobot.utils.utils import ( init_logging, log_say, ) -from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data @dataclass @@ -378,7 +378,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - _init_rerun(session_name="recording") + init_rerun(session_name="recording") robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 62c243e95..ab9a6361d 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -90,7 +90,7 @@ from lerobot.teleoperators import ( # noqa: F401 ) from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import init_logging, move_cursor_up -from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data +from lerobot.utils.visualization_utils import init_rerun, log_rerun_data @dataclass @@ -185,7 +185,7 @@ def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - _init_rerun(session_name="teleoperation") + init_rerun(session_name="teleoperation") teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index 8069b3662..42abcdda4 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -27,17 +27,3 @@ def busy_wait(seconds): # On Linux time.sleep is accurate if seconds > 0: time.sleep(seconds) - - -def safe_disconnect(func): - # TODO(aliberts): Allow to pass custom exceptions - # (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError) - def wrapper(robot, *args, **kwargs): - try: - return func(robot, *args, **kwargs) - except Exception as e: - if robot.is_connected: - robot.disconnect() - raise e - - return wrapper diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 08d1bcc9d..3ebe31971 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -13,10 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging from pathlib import Path -from termcolor import colored from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -36,10 +34,6 @@ from lerobot.utils.constants import ( from lerobot.utils.random_utils import load_rng_state, save_rng_state -def log_output_dir(out_dir): - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") - - def get_step_identifier(step: int, total_steps: int) -> str: num_digits = max(6, len(str(total_steps))) return f"{step:0{num_digits}d}" diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 107606fda..523a5e4d2 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -15,14 +15,13 @@ # limitations under the License. import logging import os -import os.path as osp import platform import select import subprocess import sys import time from copy import copy, deepcopy -from datetime import datetime, timezone +from datetime import datetime from pathlib import Path from statistics import mean @@ -30,12 +29,6 @@ import numpy as np import torch -def none_or_int(value): - if value == "None": - return None - return int(value) - - def inside_slurm(): """Check whether the python process was launched through slurm""" # TODO(rcadene): return False for interactive mode `--pty bash` @@ -165,36 +158,6 @@ def format_big_number(num, precision=0): return num -def _relative_path_between(path1: Path, path2: Path) -> Path: - """Returns path1 relative to path2.""" - path1 = path1.absolute() - path2 = path2.absolute() - try: - return path1.relative_to(path2) - except ValueError: # most likely because path1 is not a subpath of path2 - common_parts = Path(osp.commonpath([path1, path2])).parts - return Path( - "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) - ) - - -def print_cuda_memory_usage(): - """Use this function to locate and debug memory leak.""" - import gc - - gc.collect() - # Also clear the cache if you want to fully release the memory - torch.cuda.empty_cache() - print(f"Current GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB") - print(f"Maximum GPU Memory Allocated: {torch.cuda.max_memory_allocated(0) / 1024**2:.2f} MB") - print(f"Current GPU Memory Reserved: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB") - print(f"Maximum GPU Memory Reserved: {torch.cuda.max_memory_reserved(0) / 1024**2:.2f} MB") - - -def capture_timestamp_utc(): - return datetime.now(timezone.utc) - - def say(text: str, blocking: bool = False): system = platform.system() diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index e6acc87de..7fc881f26 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -20,7 +20,7 @@ import numpy as np import rerun as rr -def _init_rerun(session_name: str = "lerobot_control_loop") -> None: +def init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size From 170c09e7f63660934d8e784df202d04bea2e77bc Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 17:10:52 +0200 Subject: [PATCH 124/158] chore(utils): move queue utils and wandb_utils to their respective modules (#2030) * chore(utils): move queue utils and wandb_utils to their respective modules * fix(rl): remove double imports --------- Signed-off-by: Steven Palma --- src/lerobot/rl/actor.py | 2 +- src/lerobot/rl/learner.py | 4 ++-- src/lerobot/rl/learner_service.py | 2 +- src/lerobot/{utils => rl}/queue.py | 0 src/lerobot/{utils => rl}/wandb_utils.py | 0 src/lerobot/scripts/train.py | 2 +- tests/{utils => rl}/test_queue.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename src/lerobot/{utils => rl}/queue.py (100%) rename src/lerobot/{utils => rl}/wandb_utils.py (100%) rename tests/{utils => rl}/test_queue.py (98%) diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 2606481d3..3c025a05d 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -64,6 +64,7 @@ from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler +from lerobot.rl.queue import get_last_item_from_queue from lerobot.robots import so100_follower # noqa: F401 from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -76,7 +77,6 @@ from lerobot.transport.utils import ( send_bytes_in_chunks, transitions_to_bytes, ) -from lerobot.utils.queue import get_last_item_from_queue from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import busy_wait from lerobot.utils.transition import ( diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 1ff343760..0faa460ef 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -68,6 +68,7 @@ from lerobot.policies.factory import make_policy from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions from lerobot.rl.process import ProcessSignalHandler +from lerobot.rl.wandb_utils import WandBLogger from lerobot.robots import so100_follower # noqa: F401 from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents @@ -97,7 +98,6 @@ from lerobot.utils.utils import ( get_safe_torch_device, init_logging, ) -from lerobot.utils.wandb_utils import WandBLogger from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService @@ -153,7 +153,7 @@ def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): # Setup WandB logging if enabled if cfg.wandb.enable and cfg.wandb.project: - from lerobot.utils.wandb_utils import WandBLogger + from lerobot.rl.wandb_utils import WandBLogger wandb_logger = WandBLogger(cfg) else: diff --git a/src/lerobot/rl/learner_service.py b/src/lerobot/rl/learner_service.py index b07c296e6..7ef38119b 100644 --- a/src/lerobot/rl/learner_service.py +++ b/src/lerobot/rl/learner_service.py @@ -19,9 +19,9 @@ import logging import time from multiprocessing import Event, Queue +from lerobot.rl.queue import get_last_item_from_queue from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks -from lerobot.utils.queue import get_last_item_from_queue MAX_WORKERS = 3 # Stream parameters, send transitions and interactions SHUTDOWN_TIMEOUT = 10 diff --git a/src/lerobot/utils/queue.py b/src/lerobot/rl/queue.py similarity index 100% rename from src/lerobot/utils/queue.py rename to src/lerobot/rl/queue.py diff --git a/src/lerobot/utils/wandb_utils.py b/src/lerobot/rl/wandb_utils.py similarity index 100% rename from src/lerobot/utils/wandb_utils.py rename to src/lerobot/rl/wandb_utils.py diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 21da62bbb..df33f1dbe 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -35,6 +35,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters +from lerobot.rl.wandb_utils import WandBLogger from lerobot.scripts.eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed @@ -51,7 +52,6 @@ from lerobot.utils.utils import ( has_method, init_logging, ) -from lerobot.utils.wandb_utils import WandBLogger def update_policy( diff --git a/tests/utils/test_queue.py b/tests/rl/test_queue.py similarity index 98% rename from tests/utils/test_queue.py rename to tests/rl/test_queue.py index 6e42acdb7..b6716fbd6 100644 --- a/tests/utils/test_queue.py +++ b/tests/rl/test_queue.py @@ -20,7 +20,7 @@ from queue import Queue from torch.multiprocessing import Queue as TorchMPQueue -from lerobot.utils.queue import get_last_item_from_queue +from lerobot.rl.queue import get_last_item_from_queue def test_get_last_item_single_item(): From a87d4c9a749da72d782fe4f37fdb9d498979e1d1 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Wed, 24 Sep 2025 17:30:32 +0200 Subject: [PATCH 125/158] (docs): small change in dataset name (#2032) * small change Signed-off-by: Jade Choghari * update Signed-off-by: Jade Choghari --------- Signed-off-by: Jade Choghari --- docs/source/libero.mdx | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 488c02ce0..17e12d45e 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -33,7 +33,7 @@ To Install LIBERO, after following LeRobot official instructions, just do: Evaluate a policy on one LIBERO suite: ```bash -python src/lerobot/scripts/eval.py \ +lerobot-eval \ --policy.path="your-policy-id" \ --env.type=libero \ --env.task=libero_object \ @@ -52,7 +52,7 @@ python src/lerobot/scripts/eval.py \ Benchmark a policy across multiple suites at once: ```bash -python src/lerobot/scripts/eval.py \ +lerobot-eval \ --policy.path="your-policy-id" \ --env.type=libero \ --env.task=libero_object,libero_spatial \ @@ -103,10 +103,10 @@ For reference, here is the **original dataset** published by Physical Intelligen ### Example training command ```bash -python src/lerobot/scripts/train.py \ +lerobot-train \ --policy.type=smolvla \ --policy.repo_id=${HF_USER}/libero-test \ - --dataset.repo_id=jadechoghari/smol-libero3 \ + --dataset.repo_id=HuggingFaceVLA/libero \ --env.type=libero \ --env.task=libero_10 \ --output_dir=./outputs/ \ From ddba994d73e6315e78c76173cd4fa90d471fc662 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Sep 2025 18:29:58 +0200 Subject: [PATCH 126/158] chore(scripts): rename eval and train scripts (#2033) --- pyproject.toml | 4 ++-- src/lerobot/scripts/{eval.py => lerobot_eval.py} | 0 src/lerobot/scripts/{train.py => lerobot_train.py} | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename src/lerobot/scripts/{eval.py => lerobot_eval.py} (100%) rename src/lerobot/scripts/{train.py => lerobot_train.py} (99%) diff --git a/pyproject.toml b/pyproject.toml index dbc25805d..d2f1e502a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,8 +169,8 @@ lerobot-record="lerobot.scripts.lerobot_record:main" lerobot-replay="lerobot.scripts.lerobot_replay:main" lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" -lerobot-eval="lerobot.scripts.eval:main" -lerobot-train="lerobot.scripts.train:main" +lerobot-eval="lerobot.scripts.lerobot_eval:main" +lerobot-train="lerobot.scripts.lerobot_train:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/lerobot_eval.py similarity index 100% rename from src/lerobot/scripts/eval.py rename to src/lerobot/scripts/lerobot_eval.py diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/lerobot_train.py similarity index 99% rename from src/lerobot/scripts/train.py rename to src/lerobot/scripts/lerobot_train.py index df33f1dbe..5ef8c7263 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,7 +36,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.utils import get_device_from_parameters from lerobot.rl.wandb_utils import WandBLogger -from lerobot.scripts.eval import eval_policy_all +from lerobot.scripts.lerobot_eval import eval_policy_all from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed from lerobot.utils.train_utils import ( From 43d878a102a5dd5f4f75d220505a68671d3e0c84 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 25 Sep 2025 15:36:47 +0200 Subject: [PATCH 127/158] chore: replace hard-coded obs values with constants throughout all the source code (#2037) * chore: replace hard-coded OBS values with constants throughout all the source code * chore(tests): replace hard-coded OBS values with constants throughout all the test code --- benchmarks/video/run_video_benchmark.py | 3 +- examples/lekiwi/evaluate.py | 3 +- examples/lekiwi/record.py | 3 +- src/lerobot/async_inference/helpers.py | 6 +- src/lerobot/datasets/factory.py | 3 +- src/lerobot/datasets/pipeline_features.py | 18 +- src/lerobot/datasets/utils.py | 7 +- src/lerobot/envs/utils.py | 9 +- src/lerobot/policies/act/modeling_act.py | 22 +- .../policies/diffusion/modeling_diffusion.py | 18 +- src/lerobot/policies/pi0/configuration_pi0.py | 3 +- .../conversion_scripts/compare_with_jax.py | 15 +- .../policies/pi0fast/configuration_pi0fast.py | 3 +- src/lerobot/policies/sac/modeling_sac.py | 13 +- .../reward_model/configuration_classifier.py | 3 +- .../policies/smolvla/configuration_smolvla.py | 3 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 24 +- src/lerobot/policies/vqbet/modeling_vqbet.py | 14 +- src/lerobot/processor/converters.py | 4 +- .../processor/observation_processor.py | 6 +- src/lerobot/rl/buffer.py | 3 +- src/lerobot/rl/gym_manipulator.py | 7 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 3 +- src/lerobot/scripts/lerobot_dataset_viz.py | 5 +- src/lerobot/scripts/lerobot_eval.py | 7 +- src/lerobot/scripts/lerobot_record.py | 3 +- src/lerobot/utils/constants.py | 18 +- src/lerobot/utils/visualization_utils.py | 4 +- .../policies/save_policy_to_safetensors.py | 3 +- tests/async_inference/test_helpers.py | 55 +-- tests/async_inference/test_policy_server.py | 5 +- tests/datasets/test_compute_stats.py | 33 +- tests/datasets/test_dataset_utils.py | 9 +- tests/datasets/test_datasets.py | 9 +- .../hilserl/test_modeling_classifier.py | 9 +- tests/policies/test_policies.py | 14 +- tests/policies/test_sac_config.py | 13 +- tests/policies/test_sac_policy.py | 21 +- tests/processor/test_act_processor.py | 2 +- tests/processor/test_batch_conversion.py | 79 ++-- tests/processor/test_converters.py | 17 +- tests/processor/test_device_processor.py | 99 ++-- tests/processor/test_migration_detection.py | 3 +- tests/processor/test_normalize_processor.py | 437 +++++++++--------- tests/processor/test_observation_processor.py | 46 +- tests/processor/test_pipeline.py | 37 +- tests/processor/test_rename_processor.py | 73 ++- tests/processor/test_tokenizer_processor.py | 24 +- tests/rl/test_actor.py | 5 +- tests/rl/test_actor_learner.py | 5 +- tests/utils/test_replay_buffer.py | 73 ++- tests/utils/test_visualization_utils.py | 7 +- 52 files changed, 659 insertions(+), 649 deletions(-) diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index f041a9066..9f34b2273 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -41,6 +41,7 @@ from lerobot.datasets.video_utils import ( decode_video_frames_torchvision, encode_video_frames, ) +from lerobot.utils.constants import OBS_IMAGE BASE_ENCODING = OrderedDict( [ @@ -117,7 +118,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 32a5e0a2b..174486eb8 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -21,6 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -42,7 +43,7 @@ policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 30f34e718..471cb3668 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -22,6 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -48,7 +49,7 @@ teleop_action_processor, robot_action_processor, robot_observation_processor = m # Configure the dataset features action_features = hw_to_dataset_features(robot.action_features, "action") -obs_features = hw_to_dataset_features(robot.observation_features, "observation") +obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} # Create the dataset diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 175cecf6d..75d81a0f3 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -27,7 +27,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 from lerobot.robots.robot import Robot -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import init_logging Action = torch.Tensor @@ -66,7 +66,7 @@ def validate_robot_cameras_for_policy( def map_robot_keys_to_lerobot_features(robot: Robot) -> dict[str, dict]: - return hw_to_dataset_features(robot.observation_features, "observation", use_video=False) + return hw_to_dataset_features(robot.observation_features, OBS_STR, use_video=False) def is_image_key(k: str) -> bool: @@ -141,7 +141,7 @@ def make_lerobot_observation( lerobot_features: dict[str, dict], ) -> LeRobotObservation: """Make a lerobot observation from a raw observation.""" - return build_dataset_frame(lerobot_features, robot_obs, prefix="observation") + return build_dataset_frame(lerobot_features, robot_obs, prefix=OBS_STR) def prepare_raw_observation( diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index a71e978bc..2bac84aed 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,6 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms +from lerobot.utils.constants import OBS_PREFIX IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -58,7 +59,7 @@ def resolve_delta_timestamps( delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == "action" and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] - if key.startswith("observation.") and cfg.observation_delta_indices is not None: + if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] if len(delta_timestamps) == 0: diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index cdf0b7448..13555dd31 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -19,7 +19,7 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features from lerobot.processor import DataProcessorPipeline -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR def create_initial_features( @@ -92,8 +92,8 @@ def aggregate_pipeline_dataset_features( # Intermediate storage for categorized and filtered features. processed_features: dict[str, dict[str, Any]] = { - "action": {}, - "observation": {}, + ACTION: {}, + OBS_STR: {}, } images_token = OBS_IMAGES.split(".")[-1] @@ -125,17 +125,15 @@ def aggregate_pipeline_dataset_features( # 3. Add the feature to the appropriate group with a clean name. name = strip_prefix(key, PREFIXES_TO_STRIP) if is_action: - processed_features["action"][name] = value + processed_features[ACTION][name] = value else: - processed_features["observation"][name] = value + processed_features[OBS_STR][name] = value # Convert the processed features into the final dataset format. dataset_features = {} - if processed_features["action"]: + if processed_features[ACTION]: dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) - if processed_features["observation"]: - dataset_features.update( - hw_to_dataset_features(processed_features["observation"], "observation", use_videos) - ) + if processed_features[OBS_STR]: + dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos)) return dataset_features diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 922fc4e3f..96ae2eca6 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -43,6 +43,7 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk @@ -652,7 +653,7 @@ def hw_to_dataset_features( "names": list(joint_fts), } - if joint_fts and prefix == "observation": + if joint_fts and prefix == OBS_STR: features[f"{prefix}.state"] = { "dtype": "float32", "shape": (len(joint_fts),), @@ -728,9 +729,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) shape = (shape[2], shape[0], shape[1]) - elif key == "observation.environment_state": + elif key == OBS_ENV_STATE: type = FeatureType.ENV - elif key.startswith("observation"): + elif key.startswith(OBS_STR): type = FeatureType.STATE elif key.startswith("action"): type = FeatureType.ACTION diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index f0aa0b5c6..023ceea67 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -26,6 +26,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE from lerobot.utils.utils import get_channel_first_image_shape @@ -41,9 +42,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations = {} if "pixels" in observations: if isinstance(observations["pixels"], dict): - imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()} + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in observations["pixels"].items()} else: - imgs = {"observation.image": observations["pixels"]} + imgs = {OBS_IMAGE: observations["pixels"]} for imgkey, img in imgs.items(): # TODO(aliberts, rcadene): use transforms.ToTensor()? @@ -72,13 +73,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if env_state.dim() == 1: env_state = env_state.unsqueeze(0) - return_observations["observation.environment_state"] = env_state + return_observations[OBS_ENV_STATE] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing agent_pos = torch.from_numpy(observations["agent_pos"]).float() if agent_pos.dim() == 1: agent_pos = agent_pos.unsqueeze(0) - return_observations["observation.state"] = agent_pos + return_observations[OBS_STATE] = agent_pos return return_observations diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e4ebec199..f8261bb7f 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE class ACTPolicy(PreTrainedPolicy): @@ -398,10 +398,10 @@ class ACT(nn.Module): "actions must be provided when using the variational objective in training mode." ) - if "observation.images" in batch: - batch_size = batch["observation.images"][0].shape[0] + if OBS_IMAGES in batch: + batch_size = batch[OBS_IMAGES][0].shape[0] else: - batch_size = batch["observation.environment_state"].shape[0] + batch_size = batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. if self.config.use_vae and "action" in batch and self.training: @@ -410,7 +410,7 @@ class ACT(nn.Module): self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.config.robot_state_feature: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) @@ -430,7 +430,7 @@ class ACT(nn.Module): cls_joint_is_pad = torch.full( (batch_size, 2 if self.config.robot_state_feature else 1), False, - device=batch["observation.state"].device, + device=batch[OBS_STATE].device, ) key_padding_mask = torch.cat( [cls_joint_is_pad, batch["action_is_pad"]], axis=1 @@ -454,7 +454,7 @@ class ACT(nn.Module): mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device + batch[OBS_STATE].device ) # Prepare transformer encoder inputs. @@ -462,18 +462,16 @@ class ACT(nn.Module): encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) # Robot state token. if self.config.robot_state_feature: - encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[OBS_STATE])) # Environment state token. if self.config.env_state_feature: - encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) - ) + encoder_in_tokens.append(self.encoder_env_state_input_proj(batch[OBS_ENV_STATE])) if self.config.image_features: # For a list of images, the H and W may vary but H*W is constant. # NOTE: If modifying this section, verify on MPS devices that # gradients remain stable (no explosions or NaNs). - for img in batch["observation.images"]: + for img in batch[OBS_IMAGES]: cam_features = self.backbone(img)["feature_map"] cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) cam_features = self.encoder_img_feat_input_proj(cam_features) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 0bd2e282b..af1327ba2 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -81,13 +81,13 @@ class DiffusionPolicy(PreTrainedPolicy): def reset(self): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { - "observation.state": deque(maxlen=self.config.n_obs_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), "action": deque(maxlen=self.config.n_action_steps), } if self.config.image_features: - self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: @@ -234,7 +234,7 @@ class DiffusionModel(nn.Module): if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + images_per_camera = einops.rearrange(batch[OBS_IMAGES], "b s n ... -> n (b s) ...") img_features_list = torch.cat( [ encoder(images) @@ -249,7 +249,7 @@ class DiffusionModel(nn.Module): else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...") ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). @@ -275,7 +275,7 @@ class DiffusionModel(nn.Module): "observation.environment_state": (B, n_obs_steps, environment_dim) } """ - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Encode image features and concatenate them all together along with the state vector. @@ -306,9 +306,9 @@ class DiffusionModel(nn.Module): } """ # Input validation. - assert set(batch).issuperset({"observation.state", "action", "action_is_pad"}) - assert "observation.images" in batch or "observation.environment_state" in batch - n_obs_steps = batch["observation.state"].shape[1] + assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"}) + assert OBS_IMAGES in batch or OBS_ENV_STATE in batch + n_obs_steps = batch[OBS_STATE].shape[1] horizon = batch["action"].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index c9728e418..bd5bbf7ee 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0") @@ -113,7 +114,7 @@ class PI0Config(PreTrainedConfig): # raise ValueError("You must provide at least one image or the environment state among the inputs.") for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py index c0c2e4816..fe9865697 100644 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE def display(tensor: torch.Tensor): @@ -60,26 +61,26 @@ def main(): # Override stats dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) - dataset_meta.stats["observation.state"]["mean"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor( norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 ) - dataset_meta.stats["observation.state"]["std"] = torch.tensor( + dataset_meta.stats[OBS_STATE]["std"] = torch.tensor( norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 ) # Create LeRobot batch from Jax batch = {} for cam_key, uint_chw_array in example["images"].items(): - batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 - batch["observation.state"] = torch.from_numpy(example["state"]) + batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch[OBS_STATE] = torch.from_numpy(example["state"]) batch["action"] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] if model_name == "pi0_aloha_towel": - del batch["observation.images.cam_low"] + del batch[f"{OBS_IMAGES}.cam_low"] elif model_name == "pi0_aloha_sim": - batch["observation.images.top"] = batch["observation.images.cam_high"] - del batch["observation.images.cam_high"] + batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"] + del batch[f"{OBS_IMAGES}.cam_high"] # Batchify for key in batch: diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py index b72bcd735..705b61ea8 100644 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ b/src/lerobot/policies/pi0fast/configuration_pi0fast.py @@ -6,6 +6,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0fast") @@ -99,7 +100,7 @@ class PI0FASTConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index fcaf02a4b..a6ed79d4e 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -31,6 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters +from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -513,17 +514,17 @@ class SACObservationEncoder(nn.Module): ) def _init_state_layers(self) -> None: - self.has_env = "observation.environment_state" in self.config.input_features - self.has_state = "observation.state" in self.config.input_features + self.has_env = OBS_ENV_STATE in self.config.input_features + self.has_state = OBS_STATE in self.config.input_features if self.has_env: - dim = self.config.input_features["observation.environment_state"].shape[0] + dim = self.config.input_features[OBS_ENV_STATE].shape[0] self.env_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), nn.Tanh(), ) if self.has_state: - dim = self.config.input_features["observation.state"].shape[0] + dim = self.config.input_features[OBS_STATE].shape[0] self.state_encoder = nn.Sequential( nn.Linear(dim, self.config.latent_dim), nn.LayerNorm(self.config.latent_dim), @@ -549,9 +550,9 @@ class SACObservationEncoder(nn.Module): cache = self.get_cached_image_features(obs) parts.append(self._encode_images(cache, detach)) if self.has_env: - parts.append(self.env_encoder(obs["observation.environment_state"])) + parts.append(self.env_encoder(obs[OBS_ENV_STATE])) if self.has_state: - parts.append(self.state_encoder(obs["observation.state"])) + parts.append(self.state_encoder(obs[OBS_STATE])) if parts: return torch.cat(parts, dim=-1) diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index fc53283b3..9b76b8037 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -19,6 +19,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import NormalizationMode from lerobot.optim.optimizers import AdamWConfig, OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig +from lerobot.utils.constants import OBS_IMAGE @PreTrainedConfig.register_subclass(name="reward_classifier") @@ -69,7 +70,7 @@ class RewardClassifierConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate feature configurations.""" - has_image = any(key.startswith("observation.image") for key in self.input_features) + has_image = any(key.startswith(OBS_IMAGE) for key in self.input_features) if not has_image: raise ValueError( "You must provide an image observation (key starting with 'observation.image') in the input features" diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index 571900c4a..eedf477a5 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import ( CosineDecayWithWarmupSchedulerConfig, ) +from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("smolvla") @@ -117,7 +118,7 @@ class SmolVLAConfig(PreTrainedConfig): def validate_features(self) -> None: for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, 480, 640), diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index f83048862..4b5e8b7bd 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -38,7 +38,7 @@ from torch import Tensor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues -from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, OBS_STR, REWARD class TDMPCPolicy(PreTrainedPolicy): @@ -91,13 +91,13 @@ class TDMPCPolicy(PreTrainedPolicy): called on `env.reset()` """ self._queues = { - "observation.state": deque(maxlen=1), + OBS_STATE: deque(maxlen=1), "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: - self._queues["observation.image"] = deque(maxlen=1) + self._queues[OBS_IMAGE] = deque(maxlen=1) if self.config.env_state_feature: - self._queues["observation.environment_state"] = deque(maxlen=1) + self._queues[OBS_ENV_STATE] = deque(maxlen=1) # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start # CEM for the next step. self._prev_mean: torch.Tensor | None = None @@ -325,7 +325,7 @@ class TDMPCPolicy(PreTrainedPolicy): action = batch[ACTION] # (t, b, action_dim) reward = batch[REWARD] # (t, b) - observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + observations = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: @@ -387,10 +387,10 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) # `z_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # `z_targets` depends on the next observation. - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -403,7 +403,7 @@ class TDMPCPolicy(PreTrainedPolicy): * F.mse_loss(reward_preds, reward, reduction="none") * ~batch["next.reward_is_pad"] # `reward_preds` depends on the current observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -419,11 +419,11 @@ class TDMPCPolicy(PreTrainedPolicy): reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] # q_targets depends on the reward and the next observations. * ~batch["next.reward_is_pad"] - * ~batch["observation.state_is_pad"][1:] + * ~batch[f"{OBS_STR}.state_is_pad"][1:] ) .sum(0) .mean() @@ -441,7 +441,7 @@ class TDMPCPolicy(PreTrainedPolicy): temporal_loss_coeffs * raw_v_value_loss # `v_targets` depends on the first observation and the actions, as does `v_preds`. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ) .sum(0) @@ -477,7 +477,7 @@ class TDMPCPolicy(PreTrainedPolicy): * mse * temporal_loss_coeffs # `action_preds` depends on the first observation and the actions. - * ~batch["observation.state_is_pad"][0] + * ~batch[f"{OBS_STR}.state_is_pad"][0] * ~batch["action_is_pad"] ).mean() diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index 34e5b1c0d..91d609701 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -133,7 +133,7 @@ class VQBeTPolicy(PreTrainedPolicy): batch.pop(ACTION) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original # NOTE: It's important that this happens after stacking the images into a single key. - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out if ACTION in batch: batch.pop(ACTION) @@ -340,14 +340,12 @@ class VQBeTModel(nn.Module): def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: # Input validation. - assert set(batch).issuperset({"observation.state", "observation.images"}) - batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert set(batch).issuperset({OBS_STATE, OBS_IMAGES}) + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] assert n_obs_steps == self.config.n_obs_steps # Extract image feature (first combine batch and sequence dims). - img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") - ) + img_features = self.rgb_encoder(einops.rearrange(batch[OBS_IMAGES], "b s n ... -> (b s n) ...")) # Separate batch and sequence dims. img_features = einops.rearrange( img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images @@ -359,9 +357,7 @@ class VQBeTModel(nn.Module): img_features ) # (batch, obs_step, number of different cameras, projection dims) input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] - input_tokens.append( - self.state_projector(batch["observation.state"]) - ) # (batch, obs_step, projection dims) + input_tokens.append(self.state_projector(batch[OBS_STATE])) # (batch, obs_step, projection dims) input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) # Interleave tokens by stacking and rearranging. input_tokens = torch.stack(input_tokens, dim=2) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 440f8b1db..2e80cf4bb 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,6 +23,8 @@ from typing import Any import numpy as np import torch +from lerobot.utils.constants import OBS_PREFIX + from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -347,7 +349,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: raise ValueError(f"Action should be a PolicyAction type got {type(action)}") # Extract observation and complementary data keys. - observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation_keys = {k: v for k, v in batch.items() if k.startswith(OBS_PREFIX)} complementary_data = _extract_complementary_data(batch) return create_transition( diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 2b9402bee..486218157 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -21,7 +21,7 @@ import torch from torch import Tensor from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -171,7 +171,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Prefix-based rules (e.g. pixels.cam1 -> OBS_IMAGES.cam1) for old_prefix, new_prefix in prefix_pairs.items(): - prefixed_old = f"observation.{old_prefix}" + prefixed_old = f"{OBS_STR}.{old_prefix}" if key.startswith(prefixed_old): suffix = key[len(prefixed_old) :] new_key = f"{new_prefix}{suffix}" @@ -191,7 +191,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): # Exact-name rules (pixels, environment_state, agent_pos) for old, new in exact_pairs.items(): - if key == old or key == f"observation.{old}": + if key == old or key == f"{OBS_STR}.{old}": new_key = new new_features[src_ft][new_key] = feat handled = True diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index c65801896..fbf36de36 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,6 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_IMAGE from lerobot.utils.transition import Transition @@ -240,7 +241,7 @@ class ReplayBuffer: idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) # Identify image keys that need augmentation - image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + image_keys = [k for k in self.states if k.startswith(OBS_IMAGE)] if self.use_drq else [] # Create batched state and next_state batch_state = {} diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index f91d077f4..393135708 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,6 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -180,7 +181,7 @@ class RobotEnv(gym.Env): # Define observation spaces for images and other states. if current_observation is not None and "pixels" in current_observation: - prefix = "observation.images" + prefix = OBS_IMAGES observation_spaces = { f"{prefix}.{key}": gym.spaces.Box( low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8 @@ -190,7 +191,7 @@ class RobotEnv(gym.Env): if current_observation is not None: agent_pos = current_observation["agent_pos"] - observation_spaces["observation.state"] = gym.spaces.Box( + observation_spaces[OBS_STATE] = gym.spaces.Box( low=0, high=10, shape=agent_pos.shape, @@ -612,7 +613,7 @@ def control_loop( } for key, value in transition[TransitionKey.OBSERVATION].items(): - if key == "observation.state": + if key == OBS_STATE: features[key] = { "dtype": "float32", "shape": value.squeeze(0).shape, diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 9f6367152..392d6d575 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,6 +23,7 @@ from typing import Any import cv2 import numpy as np +from lerobot.utils.constants import OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -203,7 +204,7 @@ class LeKiwiClient(Robot): state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: dict[str, Any] = {**flat_state, "observation.state": state_vec} + obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec} # Decode images current_frames: dict[str, np.ndarray] = {} diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2033b36ba..5c0d31f73 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,6 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import OBS_STATE class EpisodeSampler(torch.utils.data.Sampler): @@ -161,8 +162,8 @@ def visualize_dataset( rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) # display each dimension of observed state space (e.g. agent position in joint space) - if "observation.state" in batch: - for dim_idx, val in enumerate(batch["observation.state"][i]): + if OBS_STATE in batch: + for dim_idx, val in enumerate(batch[OBS_STATE][i]): rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) if "next.done" in batch: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index ca900f8df..310f771a9 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,6 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline +from lerobot.utils.constants import OBS_STR from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -221,7 +222,7 @@ def rollout( stacked_observations = {} for key in all_observations[0]: stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) - ret["observation"] = stacked_observations + ret[OBS_STR] = stacked_observations if hasattr(policy, "use_original_modules"): policy.use_original_modules() @@ -459,8 +460,8 @@ def _compile_episode_data( for k in ep_dict: ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) - for key in rollout_data["observation"]: - ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + for key in rollout_data[OBS_STR]: + ep_dict[key] = rollout_data[OBS_STR][key][ep_ix, :num_frames] ep_dicts.append(ep_dict) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index dd4984fab..f1d026a39 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -109,6 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401 so101_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop +from lerobot.utils.constants import OBS_STR from lerobot.utils.control_utils import ( init_keyboard_listener, is_headless, @@ -303,7 +304,7 @@ def record_loop( obs_processed = robot_observation_processor(obs) if policy is not None or dataset is not None: - observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation") + observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR) # Get action from either policy or teleop if policy is not None and preprocessor is not None and postprocessor is not None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 464969c72..337817908 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -17,19 +17,21 @@ from pathlib import Path from huggingface_hub.constants import HF_HOME -OBS_ENV_STATE = "observation.environment_state" -OBS_STATE = "observation.state" -OBS_IMAGE = "observation.image" -OBS_IMAGES = "observation.images" -OBS_LANGUAGE = "observation.language" +OBS_STR = "observation" +OBS_PREFIX = OBS_STR + "." +OBS_ENV_STATE = OBS_STR + ".environment_state" +OBS_STATE = OBS_STR + ".state" +OBS_IMAGE = OBS_STR + ".image" +OBS_IMAGES = OBS_IMAGE + "s" +OBS_LANGUAGE = OBS_STR + ".language" +OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" +OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" + ACTION = "action" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" -OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" -OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" - ROBOTS = "robots" ROBOT_TYPE = "robot_type" TELEOPERATORS = "teleoperators" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 7fc881f26..ae070b7c4 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -19,6 +19,8 @@ from typing import Any import numpy as np import rerun as rr +from .constants import OBS_PREFIX, OBS_STR + def init_rerun(session_name: str = "lerobot_control_loop") -> None: """Initializes the Rerun SDK for visualizing the control loop.""" @@ -63,7 +65,7 @@ def log_rerun_data( for k, v in observation.items(): if v is None: continue - key = k if str(k).startswith("observation.") else f"observation.{k}" + key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}" if _is_scalar(v): rr.log(key, rr.Scalar(float(v))) diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index b0ffa9a31..e130ae144 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -24,6 +24,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors +from lerobot.utils.constants import OBS_STR from lerobot.utils.random_utils import set_seed @@ -92,7 +93,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): # for backward compatibility if k == "task": continue - if k.startswith("observation"): + if k.startswith(OBS_STR): obs[k] = batch[k] if hasattr(train_cfg.policy, "n_action_steps"): diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index f1c7636e2..acf5870d5 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -30,6 +30,7 @@ from lerobot.async_inference.helpers import ( resize_robot_observation_image, ) from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # --------------------------------------------------------------------- # FPSTracker @@ -115,7 +116,7 @@ def test_timed_action_getters(): def test_timed_observation_getters(): """TimedObservation stores & returns timestamp, dict and timestep.""" ts = time.time() - obs_dict = {"observation.state": torch.ones(6)} + obs_dict = {OBS_STATE: torch.ones(6)} to = TimedObservation(timestamp=ts, observation=obs_dict, timestep=0) assert math.isclose(to.get_timestamp(), ts, rel_tol=0, abs_tol=1e-6) @@ -151,7 +152,7 @@ def test_timed_data_deserialization_data_getters(): # ------------------------------------------------------------------ # TimedObservation # ------------------------------------------------------------------ - obs_dict = {"observation.state": torch.arange(4).float()} + obs_dict = {OBS_STATE: torch.arange(4).float()} to_in = TimedObservation(timestamp=ts, observation=obs_dict, timestep=7, must_go=True) to_bytes = pickle.dumps(to_in) # nosec @@ -161,7 +162,7 @@ def test_timed_data_deserialization_data_getters(): assert to_out.get_timestep() == 7 assert to_out.must_go is True assert to_out.get_observation().keys() == obs_dict.keys() - torch.testing.assert_close(to_out.get_observation()["observation.state"], obs_dict["observation.state"]) + torch.testing.assert_close(to_out.get_observation()[OBS_STATE], obs_dict[OBS_STATE]) # --------------------------------------------------------------------- @@ -187,7 +188,7 @@ def test_observations_similar_true(): """Distance below atol → observations considered similar.""" # Create mock lerobot features for the similarity check lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], @@ -222,17 +223,17 @@ def _create_mock_robot_observation(): def _create_mock_lerobot_features(): """Create mock lerobot features mapping similar to what hw_to_dataset_features returns.""" return { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], }, - "observation.images.phone": { + f"{OBS_IMAGES}.phone": { "dtype": "image", "shape": [480, 640, 3], "names": ["height", "width", "channels"], @@ -243,11 +244,11 @@ def _create_mock_lerobot_features(): def _create_mock_policy_image_features(): """Create mock policy image features with different resolutions.""" return { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 224, 224), # Policy expects smaller resolution ), - "observation.images.phone": PolicyFeature( + f"{OBS_IMAGES}.phone": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 160, 160), # Different resolution for second camera ), @@ -306,21 +307,21 @@ def test_prepare_raw_observation(): prepared = prepare_raw_observation(robot_obs, lerobot_features, policy_image_features) # Check that state is properly extracted and batched - assert "observation.state" in prepared - state = prepared["observation.state"] + assert OBS_STATE in prepared + state = prepared[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.shape == (1, 4) # Batched state # Check that images are processed and resized - assert "observation.images.laptop" in prepared - assert "observation.images.phone" in prepared + assert f"{OBS_IMAGES}.laptop" in prepared + assert f"{OBS_IMAGES}.phone" in prepared - laptop_img = prepared["observation.images.laptop"] - phone_img = prepared["observation.images.phone"] + laptop_img = prepared[f"{OBS_IMAGES}.laptop"] + phone_img = prepared[f"{OBS_IMAGES}.phone"] # Check image shapes match policy requirements - assert laptop_img.shape == policy_image_features["observation.images.laptop"].shape - assert phone_img.shape == policy_image_features["observation.images.phone"].shape + assert laptop_img.shape == policy_image_features[f"{OBS_IMAGES}.laptop"].shape + assert phone_img.shape == policy_image_features[f"{OBS_IMAGES}.phone"].shape # Check that images are tensors assert isinstance(laptop_img, torch.Tensor) @@ -337,19 +338,19 @@ def test_raw_observation_to_observation_basic(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) # Check that all expected keys are present - assert "observation.state" in observation - assert "observation.images.laptop" in observation - assert "observation.images.phone" in observation + assert OBS_STATE in observation + assert f"{OBS_IMAGES}.laptop" in observation + assert f"{OBS_IMAGES}.phone" in observation # Check state processing - state = observation["observation.state"] + state = observation[OBS_STATE] assert isinstance(state, torch.Tensor) assert state.device.type == device assert state.shape == (1, 4) # Batched # Check image processing - laptop_img = observation["observation.images.laptop"] - phone_img = observation["observation.images.phone"] + laptop_img = observation[f"{OBS_IMAGES}.laptop"] + phone_img = observation[f"{OBS_IMAGES}.phone"] # Images should have batch dimension: (B, C, H, W) assert laptop_img.shape == (1, 3, 224, 224) @@ -429,19 +430,19 @@ def test_image_processing_pipeline_preserves_content(): robot_obs = {"shoulder": 1.0, "elbow": 1.0, "wrist": 1.0, "gripper": 1.0, "laptop": original_img} lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [4], "names": ["shoulder", "elbow", "wrist", "gripper"], }, - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "dtype": "image", "shape": [100, 100, 3], "names": ["height", "width", "channels"], }, } policy_image_features = { - "observation.images.laptop": PolicyFeature( + f"{OBS_IMAGES}.laptop": PolicyFeature( type=FeatureType.VISUAL, shape=(3, 50, 50), # Downsamples from 100x100 ) @@ -449,7 +450,7 @@ def test_image_processing_pipeline_preserves_content(): observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") - processed_img = observation["observation.images.laptop"].squeeze(0) # Remove batch dim + processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim # Check that the center region has higher values than corners # Due to bilinear interpolation, exact values will change but pattern should remain diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index c5c52460f..de441ff09 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -23,6 +23,7 @@ import pytest import torch from lerobot.configs.types import PolicyFeature +from lerobot.utils.constants import OBS_STATE from tests.utils import require_package # ----------------------------------------------------------------------------- @@ -44,7 +45,7 @@ class MockPolicy: def predict_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Return a chunk of 20 dummy actions.""" - batch_size = len(observation["observation.state"]) + batch_size = len(observation[OBS_STATE]) return torch.zeros(batch_size, 20, 6) def __init__(self): @@ -77,7 +78,7 @@ def policy_server(): # Add mock lerobot_features that the observation similarity functions need server.lerobot_features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": [6], "names": ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 8f8179c29..982f35c3f 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -28,6 +28,7 @@ from lerobot.datasets.compute_stats import ( sample_images, sample_indices, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def mock_load_image_as_numpy(path, dtype, channel_first): @@ -136,21 +137,21 @@ def test_get_feature_stats_single_value(): def test_compute_episode_stats(): episode_data = { - "observation.image": [f"image_{i}.jpg" for i in range(100)], - "observation.state": np.random.rand(100, 10), + OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)], + OBS_STATE: np.random.rand(100, 10), } features = { - "observation.image": {"dtype": "image"}, - "observation.state": {"dtype": "numeric"}, + OBS_IMAGE: {"dtype": "image"}, + OBS_STATE: {"dtype": "numeric"}, } with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): stats = compute_episode_stats(episode_data, features) - assert "observation.image" in stats and "observation.state" in stats - assert stats["observation.image"]["count"].item() == 100 - assert stats["observation.state"]["count"].item() == 100 - assert stats["observation.image"]["mean"].shape == (3, 1, 1) + assert OBS_IMAGE in stats and OBS_STATE in stats + assert stats[OBS_IMAGE]["count"].item() == 100 + assert stats[OBS_STATE]["count"].item() == 100 + assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1) def test_assert_type_and_shape_valid(): @@ -224,38 +225,38 @@ def test_aggregate_feature_stats(): def test_aggregate_stats(): all_stats = [ { - "observation.image": { + OBS_IMAGE: { "min": [1, 2, 3], "max": [10, 20, 30], "mean": [5.5, 10.5, 15.5], "std": [2.87, 5.87, 8.87], "count": 10, }, - "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + OBS_STATE: {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, }, { - "observation.image": { + OBS_IMAGE: { "min": [2, 1, 0], "max": [15, 10, 5], "mean": [8.5, 5.5, 2.5], "std": [3.42, 2.42, 1.42], "count": 15, }, - "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + OBS_STATE: {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, }, ] expected_agg_stats = { - "observation.image": { + OBS_IMAGE: { "min": [1, 1, 0], "max": [15, 20, 30], "mean": [7.3, 7.5, 7.7], "std": [3.5317, 4.8267, 8.5581], "count": 25, }, - "observation.state": { + OBS_STATE: { "min": 1, "max": 15, "mean": 7.3, @@ -283,7 +284,7 @@ def test_aggregate_stats(): for fkey, stats in ep_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) @@ -292,7 +293,7 @@ def test_aggregate_stats(): for fkey, stats in expected_agg_stats.items(): for k in stats: stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) - if fkey == "observation.image" and k != "count": + if fkey == OBS_IMAGE and k != "count": stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels else: stats[k] = stats[k].reshape(1) diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index f1ffd800a..c0b07ca65 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -21,6 +21,7 @@ from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch +from lerobot.utils.constants import OBS_IMAGES def test_default_parameters(): @@ -96,14 +97,14 @@ def test_merge_multiple_groups_order_and_dedup(): def test_non_vector_last_wins_for_images(): # Non-vector (images) with same name should be overwritten by the last image specified g1 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 480, 640), "names": ["channels", "height", "width"], } } g2 = { - "observation.images.front": { + f"{OBS_IMAGES}.front": { "dtype": "image", "shape": (3, 720, 1280), "names": ["channels", "height", "width"], @@ -111,8 +112,8 @@ def test_non_vector_last_wins_for_images(): } out = combine_feature_dicts(g1, g2) - assert out["observation.images.front"]["shape"] == (3, 720, 1280) - assert out["observation.images.front"]["dtype"] == "image" + assert out[f"{OBS_IMAGES}.front"]["shape"] == (3, 720, 1280) + assert out[f"{OBS_IMAGES}.front"]["dtype"] == "image" def test_dtype_mismatch_raises(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index d1d6dbdb2..1d461c8ba 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,6 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -75,7 +76,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) action_features = hw_to_dataset_features(robot.action_features, "action", True) - obs_features = hw_to_dataset_features(robot.observation_features, "observation", True) + obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True) dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" dataset_create = LeRobotDataset.create( @@ -397,7 +398,7 @@ def test_factory(env_name, repo_id, policy_name): ("frame_index", 0, True), ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? - ("observation.state", 1, True), + (OBS_STATE, 1, True), ("next.reward", 0, False), ("next.done", 0, False), ] @@ -662,7 +663,7 @@ def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.""" features = { - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], @@ -769,7 +770,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): def test_update_chunk_settings_video_dataset(tmp_path): """Test update_chunk_settings with a video dataset to ensure video-specific logic works.""" features = { - "observation.images.cam": { + f"{OBS_IMAGES}.cam": { "dtype": "video", "shape": (480, 640, 3), "names": ["height", "width", "channels"], diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index 0be1b9c7c..7a8782230 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -19,6 +19,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.utils.constants import OBS_IMAGE from tests.utils import require_package @@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params(): config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), @@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), } @@ -83,7 +84,7 @@ def test_multiclass_classifier(): num_classes = 5 config = RewardClassifierConfig() config.input_features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), @@ -95,7 +96,7 @@ def test_multiclass_classifier(): batch_size = 10 input = { - "observation.image": torch.rand((batch_size, 3, 128, 128)), + OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), "next.reward": torch.rand((batch_size, num_classes)), } diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index b577e5763..7752ad63f 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -41,7 +41,7 @@ from lerobot.policies.factory import ( make_pre_post_processors, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel @@ -52,7 +52,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p # Create only one camera input which is squared to fit all current policy constraints # e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared camera_features = { - "observation.images.laptop": { + f"{OBS_IMAGES}.laptop": { "shape": (84, 84, 3), "names": ["height", "width", "channels"], "info": None, @@ -64,7 +64,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], }, - "observation.state": { + OBS_STATE: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], @@ -281,7 +281,7 @@ def test_multikey_construction(multikey: bool): preventing erroneous creation of the policy object. """ input_features = { - "observation.state": PolicyFeature( + OBS_STATE: PolicyFeature( type=FeatureType.STATE, shape=(10,), ), @@ -297,9 +297,9 @@ def test_multikey_construction(multikey: bool): """Simulates the complete state/action is constructed from more granular multiple keys, of the same type as the overall state/action""" input_features = {} - input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) - input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) output_features = {} output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index a67815eed..59ed4af65 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import ( PolicyConfig, SACConfig, ) +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_sac_config_default_initialization(): @@ -37,11 +38,11 @@ def test_sac_config_default_initialization(): "ACTION": NormalizationMode.MIN_MAX, } assert config.dataset_stats == { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -90,11 +91,11 @@ def test_sac_config_default_initialization(): # Dataset stats defaults expected_dataset_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225], }, - "observation.state": { + OBS_STATE: { "min": [0.0, 0.0], "max": [1.0, 1.0], }, @@ -191,7 +192,7 @@ def test_sac_config_custom_initialization(): def test_validate_features(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) config.validate_features() @@ -210,7 +211,7 @@ def test_validate_features_missing_observation(): def test_validate_features_missing_action(): config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) with pytest.raises(ValueError, match="You must provide 'action' in the output features"): diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 7891c2e52..71e45e055 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed try: @@ -85,14 +86,14 @@ def test_sac_policy_with_default_args(): def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: return { - "observation.image": torch.randn(batch_size, 3, 84, 84), - "observation.state": torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), } @@ -126,14 +127,14 @@ def create_train_batch_with_visual_input( def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), + OBS_STATE: torch.randn(batch_size, state_dim), } def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: return { - "observation.state": torch.randn(batch_size, state_dim), - "observation.image": torch.randn(batch_size, 3, 84, 84), + OBS_STATE: torch.randn(batch_size, state_dim), + OBS_IMAGE: torch.randn(batch_size, 3, 84, 84), } @@ -180,10 +181,10 @@ def create_default_config( action_dim += 1 config = SACConfig( - input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, dataset_stats={ - "observation.state": { + OBS_STATE: { "min": [0.0] * state_dim, "max": [1.0] * state_dim, }, @@ -205,8 +206,8 @@ def create_config_with_visual_input( continuous_action_dim=continuous_action_dim, has_discrete_action=has_discrete_action, ) - config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) - config.dataset_stats["observation.image"] = { + config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats[OBS_IMAGE] = { "mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1), } diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py index 00a4dbb96..134cff684 100644 --- a/tests/processor/test_act_processor.py +++ b/tests/processor/test_act_processor.py @@ -342,7 +342,7 @@ def test_act_processor_batch_consistency(): batch = transition_to_batch(transition) processed = preprocessor(batch) - assert processed["observation.state"].shape[0] == 1 # Batched + assert processed[OBS_STATE].shape[0] == 1 # Batched # Test already batched data observation_batched = {OBS_STATE: torch.randn(8, 7)} # Batch of 8 diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 631ad7899..8bf24db02 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,14 +2,15 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch +from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE def _dummy_batch(): """Create a dummy batch using the new format with observation.* and next.* keys.""" return { - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.image.right": torch.randn(1, 3, 128, 128), - "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), + OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), "action": torch.tensor([[0.5]]), "next.reward": 1.0, "next.done": False, @@ -25,15 +26,15 @@ def test_observation_grouping_roundtrip(): batch_out = proc(batch_in) # Check that all observation.* keys are preserved - original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} - reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith(OBS_PREFIX)} assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) # Check tensor values - assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) - assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) - assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.left"], batch_in[f"{OBS_IMAGE}.left"]) + assert torch.allclose(batch_out[f"{OBS_IMAGE}.right"], batch_in[f"{OBS_IMAGE}.right"]) + assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE]) # Check other fields assert torch.allclose(batch_out["action"], batch_in["action"]) @@ -46,9 +47,9 @@ def test_observation_grouping_roundtrip(): def test_batch_to_transition_observation_grouping(): """Test that batch_to_transition correctly groups observation.* keys.""" batch = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, @@ -60,18 +61,18 @@ def test_batch_to_transition_observation_grouping(): # Check observation is a dict with all observation.* keys assert isinstance(transition[TransitionKey.OBSERVATION], dict) - assert "observation.image.top" in transition[TransitionKey.OBSERVATION] - assert "observation.image.left" in transition[TransitionKey.OBSERVATION] - assert "observation.state" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.top" in transition[TransitionKey.OBSERVATION] + assert f"{OBS_IMAGE}.left" in transition[TransitionKey.OBSERVATION] + assert OBS_STATE in transition[TransitionKey.OBSERVATION] # Check values are preserved assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.top"], batch[f"{OBS_IMAGE}.top"] ) assert torch.allclose( - transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + transition[TransitionKey.OBSERVATION][f"{OBS_IMAGE}.left"], batch[f"{OBS_IMAGE}.left"] ) - assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + assert transition[TransitionKey.OBSERVATION][OBS_STATE] == [1, 2, 3, 4] # Check other fields assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4])) @@ -85,9 +86,9 @@ def test_batch_to_transition_observation_grouping(): def test_transition_to_batch_observation_flattening(): """Test that transition_to_batch correctly flattens observation dict.""" observation_dict = { - "observation.image.top": torch.randn(1, 3, 128, 128), - "observation.image.left": torch.randn(1, 3, 128, 128), - "observation.state": [1, 2, 3, 4], + f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), + f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), + OBS_STATE: [1, 2, 3, 4], } transition = { @@ -103,14 +104,14 @@ def test_transition_to_batch_observation_flattening(): batch = transition_to_batch(transition) # Check that observation.* keys are flattened back to batch - assert "observation.image.top" in batch - assert "observation.image.left" in batch - assert "observation.state" in batch + assert f"{OBS_IMAGE}.top" in batch + assert f"{OBS_IMAGE}.left" in batch + assert OBS_STATE in batch # Check values are preserved - assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) - assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) - assert batch["observation.state"] == [1, 2, 3, 4] + assert torch.allclose(batch[f"{OBS_IMAGE}.top"], observation_dict[f"{OBS_IMAGE}.top"]) + assert torch.allclose(batch[f"{OBS_IMAGE}.left"], observation_dict[f"{OBS_IMAGE}.left"]) + assert batch[OBS_STATE] == [1, 2, 3, 4] # Check other fields are mapped to next.* format assert batch["action"] == "action_data" @@ -153,12 +154,12 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])} + batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])} transition = batch_to_transition(batch) # Check observation - assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.OBSERVATION] == {OBS_STATE: "minimal_state"} assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5])) # Check defaults @@ -170,7 +171,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch[OBS_STATE] == "minimal_state" assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] @@ -205,9 +206,9 @@ def test_empty_batch(): def test_complex_nested_observation(): """Test with complex nested observation data.""" batch = { - "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, - "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, - "observation.state": torch.randn(7), + f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + OBS_STATE: torch.randn(7), "action": torch.randn(8), "next.reward": 3.14, "next.done": False, @@ -219,20 +220,20 @@ def test_complex_nested_observation(): reconstructed_batch = transition_to_batch(transition) # Check that all observation keys are preserved - original_obs_keys = {k for k in batch if k.startswith("observation.")} - reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} + original_obs_keys = {k for k in batch if k.startswith(OBS_PREFIX)} + reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith(OBS_PREFIX)} assert original_obs_keys == reconstructed_obs_keys # Check tensor values - assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + assert torch.allclose(batch[OBS_STATE], reconstructed_batch[OBS_STATE]) # Check nested dict with tensors assert torch.allclose( - batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + batch[f"{OBS_IMAGE}.top"]["image"], reconstructed_batch[f"{OBS_IMAGE}.top"]["image"] ) assert torch.allclose( - batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + batch[f"{OBS_IMAGE}.left"]["image"], reconstructed_batch[f"{OBS_IMAGE}.left"]["image"] ) # Check action tensor @@ -264,7 +265,7 @@ def test_custom_converter(): processor = DataProcessorPipeline(steps=[], to_transition=to_tr, to_output=to_batch) batch = { - "observation.state": torch.randn(1, 4), + OBS_STATE: torch.randn(1, 4), "action": torch.randn(1, 2), "next.reward": 1.0, "next.done": False, @@ -274,5 +275,5 @@ def test_custom_converter(): # Check the reward was doubled by our custom converter assert result["next.reward"] == 2.0 - assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index fc91951de..b03d49214 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,6 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) +from lerobot.utils.constants import OBS_STATE, OBS_STR # Tests for the unified to_tensor function @@ -118,16 +119,16 @@ def test_to_tensor_dictionaries(): # Nested dictionary nested = { "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, - "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10}, } result = to_tensor(nested) assert isinstance(result, dict) assert isinstance(result["action"], dict) - assert isinstance(result["observation"], dict) + assert isinstance(result[OBS_STR], dict) assert isinstance(result["action"]["mean"], torch.Tensor) - assert isinstance(result["observation"]["mean"], torch.Tensor) + assert isinstance(result[OBS_STR]["mean"], torch.Tensor) assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) - assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6])) def test_to_tensor_none_filtering(): @@ -198,7 +199,7 @@ def test_batch_to_transition_with_index_fields(): # Create batch with index and task_index fields batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "next.reward": 1.5, "next.done": False, @@ -231,7 +232,7 @@ def testtransition_to_batch_with_index_fields(): # Create transition with index and task_index in complementary_data transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), reward=1.5, done=False, @@ -260,7 +261,7 @@ def test_batch_to_transition_without_index_fields(): # Batch without index/task_index batch = { - "observation.state": torch.randn(1, 7), + OBS_STATE: torch.randn(1, 7), "action": torch.randn(1, 4), "task": ["pick_cube"], } @@ -279,7 +280,7 @@ def test_transition_to_batch_without_index_fields(): # Transition without index/task_index transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data={"task": ["navigate"]}, ) diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 10ee313d7..36081e021 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -21,6 +21,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE def test_basic_functionality(): @@ -28,7 +29,7 @@ def test_basic_functionality(): processor = DeviceProcessorStep(device="cpu") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -41,8 +42,8 @@ def test_basic_functionality(): result = processor(transition) # Check that all tensors are on CPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cpu" assert result[TransitionKey.ACTION].device.type == "cpu" assert result[TransitionKey.REWARD].device.type == "cpu" assert result[TransitionKey.DONE].device.type == "cpu" @@ -55,7 +56,7 @@ def test_cuda_functionality(): processor = DeviceProcessorStep(device="cuda") # Create a transition with CPU tensors - observation = {"observation.state": torch.randn(10), "observation.image": torch.randn(3, 224, 224)} + observation = {OBS_STATE: torch.randn(10), OBS_IMAGE: torch.randn(3, 224, 224)} action = torch.randn(5) reward = torch.tensor(1.0) done = torch.tensor(False) @@ -68,8 +69,8 @@ def test_cuda_functionality(): result = processor(transition) # Check that all tensors are on CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -81,14 +82,14 @@ def test_specific_cuda_device(): """Test device processor with specific CUDA device.""" processor = DeviceProcessorStep(device="cuda:0") - observation = {"observation.state": torch.randn(10)} + observation = {OBS_STATE: torch.randn(10)} action = torch.randn(5) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.state"].device.index == 0 + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.index == 0 assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.ACTION].device.index == 0 @@ -98,7 +99,7 @@ def test_non_tensor_values(): processor = DeviceProcessorStep(device="cpu") observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "observation.metadata": {"key": "value"}, # Non-tensor data "observation.list": [1, 2, 3], # Non-tensor data } @@ -110,7 +111,7 @@ def test_non_tensor_values(): result = processor(transition) # Check tensors are processed - assert isinstance(result[TransitionKey.OBSERVATION]["observation.state"], torch.Tensor) + assert isinstance(result[TransitionKey.OBSERVATION][OBS_STATE], torch.Tensor) assert isinstance(result[TransitionKey.ACTION], torch.Tensor) # Check non-tensor values are preserved @@ -130,9 +131,9 @@ def test_none_values(): assert result[TransitionKey.ACTION].device.type == "cpu" # Test with None action - transition = create_transition(observation={"observation.state": torch.randn(10)}, action=None) + transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=None) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result[TransitionKey.ACTION] is None @@ -271,9 +272,7 @@ def test_features(): processor = DeviceProcessorStep(device="cpu") features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } @@ -376,7 +375,7 @@ def test_reward_done_truncated_types(): # Test with scalar values (not tensors) transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=1.0, # float done=False, # bool @@ -392,7 +391,7 @@ def test_reward_done_truncated_types(): # Test with tensor values transition = create_transition( - observation={"observation.state": torch.randn(5)}, + observation={OBS_STATE: torch.randn(5)}, action=torch.randn(3), reward=torch.tensor(1.0), done=torch.tensor(False), @@ -422,7 +421,7 @@ def test_complementary_data_preserved(): } transition = create_transition( - observation={"observation.state": torch.randn(5)}, complementary_data=complementary_data + observation={OBS_STATE: torch.randn(5)}, complementary_data=complementary_data ) result = processor(transition) @@ -491,13 +490,13 @@ def test_float_dtype_bfloat16(): """Test conversion to bfloat16.""" processor = DeviceProcessorStep(device="cpu", float_dtype="bfloat16") - observation = {"observation.state": torch.randn(5, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float32)} action = torch.randn(3, dtype=torch.float64) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.bfloat16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16 assert result[TransitionKey.ACTION].dtype == torch.bfloat16 @@ -505,13 +504,13 @@ def test_float_dtype_float64(): """Test conversion to float64.""" processor = DeviceProcessorStep(device="cpu", float_dtype="float64") - observation = {"observation.state": torch.randn(5, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(5, dtype=torch.float16)} action = torch.randn(3, dtype=torch.float32) transition = create_transition(observation=observation, action=action) result = processor(transition) - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float64 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float64 assert result[TransitionKey.ACTION].dtype == torch.float64 @@ -541,8 +540,8 @@ def test_float_dtype_with_mixed_tensors(): processor = DeviceProcessorStep(device="cpu", float_dtype="float32") observation = { - "observation.image": torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert - "observation.state": torch.randn(10, dtype=torch.float64), # Should convert + OBS_IMAGE: torch.randint(0, 255, (3, 64, 64), dtype=torch.uint8), # Should not convert + OBS_STATE: torch.randn(10, dtype=torch.float64), # Should convert "observation.mask": torch.tensor([True, False, True], dtype=torch.bool), # Should not convert "observation.indices": torch.tensor([1, 2, 3], dtype=torch.long), # Should not convert } @@ -552,8 +551,8 @@ def test_float_dtype_with_mixed_tensors(): result = processor(transition) # Check conversions - assert result[TransitionKey.OBSERVATION]["observation.image"].dtype == torch.uint8 # Unchanged - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.uint8 # Unchanged + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.OBSERVATION]["observation.mask"].dtype == torch.bool # Unchanged assert result[TransitionKey.OBSERVATION]["observation.indices"].dtype == torch.long # Unchanged assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted @@ -612,7 +611,7 @@ def test_complementary_data_index_fields(): "episode_id": 123, # Non-tensor field } transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, action=torch.randn(1, 4), complementary_data=complementary_data, ) @@ -736,7 +735,7 @@ def test_complementary_data_full_pipeline_cuda(): processor = DeviceProcessorStep(device="cuda:0", float_dtype="float16") # Create full transition with mixed CPU tensors - observation = {"observation.state": torch.randn(1, 7, dtype=torch.float32)} + observation = {OBS_STATE: torch.randn(1, 7, dtype=torch.float32)} action = torch.randn(1, 4, dtype=torch.float32) reward = torch.tensor(1.5, dtype=torch.float32) done = torch.tensor(False) @@ -757,7 +756,7 @@ def test_complementary_data_full_pipeline_cuda(): result = processor(transition) # Check all components moved to CUDA - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" assert result[TransitionKey.REWARD].device.type == "cuda" assert result[TransitionKey.DONE].device.type == "cuda" @@ -768,7 +767,7 @@ def test_complementary_data_full_pipeline_cuda(): assert processed_comp_data["task_index"].device.type == "cuda" # Check float conversion happened for float tensors - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float16 + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float16 assert result[TransitionKey.ACTION].dtype == torch.float16 assert result[TransitionKey.REWARD].dtype == torch.float16 @@ -782,7 +781,7 @@ def test_complementary_data_empty(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data={}, ) @@ -797,7 +796,7 @@ def test_complementary_data_none(): processor = DeviceProcessorStep(device="cpu") transition = create_transition( - observation={"observation.state": torch.randn(1, 7)}, + observation={OBS_STATE: torch.randn(1, 7)}, complementary_data=None, ) @@ -814,8 +813,8 @@ def test_preserves_gpu_placement(): # Create tensors already on GPU observation = { - "observation.state": torch.randn(10).cuda(), # Already on GPU - "observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU + OBS_STATE: torch.randn(10).cuda(), # Already on GPU + OBS_IMAGE: torch.randn(3, 224, 224).cuda(), # Already on GPU } action = torch.randn(5).cuda() # Already on GPU @@ -823,14 +822,12 @@ def test_preserves_gpu_placement(): result = processor(transition) # Check that tensors remain on their original GPU - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" - assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Verify no unnecessary copies were made (same data pointer) - assert torch.equal( - result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"] - ) + assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") @@ -842,8 +839,8 @@ def test_multi_gpu_preservation(): # Create tensors on cuda:1 (simulating Accelerate placement) cuda1_device = torch.device("cuda:1") observation = { - "observation.state": torch.randn(10).to(cuda1_device), - "observation.image": torch.randn(3, 224, 224).to(cuda1_device), + OBS_STATE: torch.randn(10).to(cuda1_device), + OBS_IMAGE: torch.randn(3, 224, 224).to(cuda1_device), } action = torch.randn(5).to(cuda1_device) @@ -851,20 +848,20 @@ def test_multi_gpu_preservation(): result = processor_gpu(transition) # Check that tensors remain on cuda:1 (not moved to cuda:0) - assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device - assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == cuda1_device + assert result[TransitionKey.OBSERVATION][OBS_IMAGE].device == cuda1_device assert result[TransitionKey.ACTION].device == cuda1_device # Test 2: GPU-to-CPU should move to CPU (not preserve GPU) processor_cpu = DeviceProcessorStep(device="cpu") transition_gpu = create_transition( - observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda() + observation={OBS_STATE: torch.randn(10).cuda()}, action=torch.randn(5).cuda() ) result_cpu = processor_cpu(transition_gpu) # Check that tensors are moved to CPU - assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu" + assert result_cpu[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu" assert result_cpu[TransitionKey.ACTION].device.type == "cpu" @@ -933,14 +930,14 @@ def test_simulated_accelerate_scenario(): # Simulate data already placed by Accelerate device = torch.device(f"cuda:{gpu_id}") - observation = {"observation.state": torch.randn(1, 10).to(device)} + observation = {OBS_STATE: torch.randn(1, 10).to(device)} action = torch.randn(1, 5).to(device) transition = create_transition(observation=observation, action=action) result = processor(transition) # Verify data stays on the GPU where Accelerate placed it - assert result[TransitionKey.OBSERVATION]["observation.state"].device == device + assert result[TransitionKey.OBSERVATION][OBS_STATE].device == device assert result[TransitionKey.ACTION].device == device @@ -1081,7 +1078,7 @@ def test_mps_float64_with_complementary_data(): } transition = create_transition( - observation={"observation.state": torch.randn(5, dtype=torch.float64)}, + observation={OBS_STATE: torch.randn(5, dtype=torch.float64)}, action=torch.randn(3, dtype=torch.float64), complementary_data=complementary_data, ) @@ -1089,7 +1086,7 @@ def test_mps_float64_with_complementary_data(): result = processor(transition) # Check that all tensors are on MPS device - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "mps" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "mps" assert result[TransitionKey.ACTION].device.type == "mps" processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA] @@ -1099,7 +1096,7 @@ def test_mps_float64_with_complementary_data(): assert processed_comp_data["float32_tensor"].device.type == "mps" # Check dtype conversions - assert result[TransitionKey.OBSERVATION]["observation.state"].dtype == torch.float32 # Converted + assert result[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.float32 # Converted assert result[TransitionKey.ACTION].dtype == torch.float32 # Converted assert processed_comp_data["float64_tensor"].dtype == torch.float32 # Converted assert processed_comp_data["float32_tensor"].dtype == torch.float32 # Unchanged diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py index 6bed8289d..b46cc6bdd 100644 --- a/tests/processor/test_migration_detection.py +++ b/tests/processor/test_migration_detection.py @@ -25,6 +25,7 @@ from pathlib import Path import pytest from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError +from lerobot.utils.constants import OBS_STATE def test_is_processor_config_valid_configs(): @@ -111,7 +112,7 @@ def test_should_suggest_migration_with_model_config_only(): # Create a model config (like old LeRobot format) model_config = { "type": "act", - "input_features": {"observation.state": {"shape": [7]}}, + "input_features": {OBS_STATE: {"shape": [7]}}, "output_features": {"action": {"shape": [7]}}, "hidden_dim": 256, "n_obs_steps": 1, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 5d7791919..616f33db9 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -29,22 +29,23 @@ from lerobot.processor import ( hotswap_stats, ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.utils import auto_select_torch_device def test_numpy_conversion(): stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), } } tensor_stats = to_tensor(stats) - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) - assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert torch.allclose(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.2, 0.2, 0.2])) def test_tensor_conversion(): @@ -75,15 +76,15 @@ def test_scalar_conversion(): def test_list_conversion(): stats = { - "observation.state": { + OBS_STATE: { "min": [0.0, -1.0, -2.0], "max": [1.0, 1.0, 2.0], } } tensor_stats = to_tensor(stats) - assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) - assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats[OBS_STATE]["max"], torch.tensor([1.0, 1.0, 2.0])) def test_unsupported_type(): @@ -99,8 +100,8 @@ def test_unsupported_type(): # Helper functions to create feature maps and norm maps def _create_observation_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } @@ -115,11 +116,11 @@ def _create_observation_norm_map(): @pytest.fixture def observation_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -136,8 +137,8 @@ def observation_normalizer(observation_stats): def test_mean_std_normalization(observation_normalizer): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -146,12 +147,12 @@ def test_mean_std_normalization(observation_normalizer): # Check mean/std normalization expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(normalized_obs["observation.image"], expected_image) + assert torch.allclose(normalized_obs[OBS_IMAGE], expected_image) def test_min_max_normalization(observation_normalizer): observation = { - "observation.state": torch.tensor([0.5, 0.0]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -162,7 +163,7 @@ def test_min_max_normalization(observation_normalizer): # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 expected_state = torch.tensor([0.0, 0.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6) def test_selective_normalization(observation_stats): @@ -172,12 +173,12 @@ def test_selective_normalization(observation_stats): features=features, norm_map=norm_map, stats=observation_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } transition = create_transition(observation=observation) @@ -185,9 +186,9 @@ def test_selective_normalization(observation_stats): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Only image should be normalized - assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + assert torch.allclose(normalized_obs[OBS_IMAGE], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) # State should remain unchanged - assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + assert torch.allclose(normalized_obs[OBS_STATE], observation[OBS_STATE]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -196,26 +197,26 @@ def test_device_compatibility(observation_stats): norm_map = _create_observation_norm_map() normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=observation_stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]).cuda(), } transition = create_transition(observation=observation) normalized_transition = normalizer(transition) normalized_obs = normalized_transition[TransitionKey.OBSERVATION] - assert normalized_obs["observation.image"].device.type == "cuda" + assert normalized_obs[OBS_IMAGE].device.type == "cuda" def test_from_lerobot_dataset(): # Mock dataset mock_dataset = Mock() mock_dataset.meta.stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0], "std": [1.0]}, } features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (1,)), } norm_map = { @@ -226,7 +227,7 @@ def test_from_lerobot_dataset(): normalizer = NormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # Both observation and action statistics should be present in tensor stats - assert "observation.image" in normalizer._tensor_stats + assert OBS_IMAGE in normalizer._tensor_stats assert "action" in normalizer._tensor_stats @@ -242,13 +243,13 @@ def test_state_dict_save_load(observation_normalizer): new_normalizer.load_state_dict(state_dict) # Test that it works the same - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} transition = create_transition(observation=observation) result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(result1["observation.image"], result2["observation.image"]) + assert torch.allclose(result1[OBS_IMAGE], result2[OBS_IMAGE]) # Fixtures for ActionUnnormalizer tests @@ -375,11 +376,11 @@ def test_action_from_lerobot_dataset(): @pytest.fixture def full_stats(): return { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -392,8 +393,8 @@ def full_stats(): def _create_full_features(): return { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } @@ -415,8 +416,8 @@ def normalizer_processor(full_stats): def test_combined_normalization(normalizer_processor): observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -434,7 +435,7 @@ def test_combined_normalization(normalizer_processor): # Check normalized observations processed_obs = processed_transition[TransitionKey.OBSERVATION] expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 - assert torch.allclose(processed_obs["observation.image"], expected_image) + assert torch.allclose(processed_obs[OBS_IMAGE], expected_image) # Check normalized action processed_action = processed_transition[TransitionKey.ACTION] @@ -455,11 +456,11 @@ def test_processor_from_lerobot_dataset(full_stats): norm_map = _create_full_norm_map() processor = NormalizerProcessorStep.from_lerobot_dataset( - mock_dataset, features, norm_map, normalize_observation_keys={"observation.image"} + mock_dataset, features, norm_map, normalize_observation_keys={OBS_IMAGE} ) - assert processor.normalize_observation_keys == {"observation.image"} - assert "observation.image" in processor._tensor_stats + assert processor.normalize_observation_keys == {OBS_IMAGE} + assert OBS_IMAGE in processor._tensor_stats assert "action" in processor._tensor_stats @@ -470,17 +471,17 @@ def test_get_config(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) config = processor.get_config() expected_config = { - "normalize_observation_keys": ["observation.image"], + "normalize_observation_keys": [OBS_IMAGE], "eps": 1e-6, "features": { - "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, - "observation.state": {"type": "STATE", "shape": (2,)}, + OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)}, + OBS_STATE: {"type": "STATE", "shape": (2,)}, "action": {"type": "ACTION", "shape": (2,)}, }, "norm_map": { @@ -499,8 +500,8 @@ def test_integration_with_robot_processor(normalizer_processor): ) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -522,8 +523,8 @@ def test_integration_with_robot_processor(normalizer_processor): # Edge case tests def test_empty_observation(): - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -534,37 +535,35 @@ def test_empty_observation(): def test_empty_stats(): - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={}) - observation = {"observation.image": torch.tensor([0.5])} + observation = {OBS_IMAGE: torch.tensor([0.5])} transition = create_transition(observation=observation) result = normalizer(transition) # Should return observation unchanged since no stats are available - assert torch.allclose( - result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] - ) + assert torch.allclose(result[TransitionKey.OBSERVATION][OBS_IMAGE], observation[OBS_IMAGE]) def test_partial_stats(): """If statistics are incomplete, the value should pass through unchanged.""" - stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + stats = {OBS_IMAGE: {"mean": [0.5]}} # Missing std / (min,max) + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} transition = create_transition(observation=observation) processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed["observation.image"], observation["observation.image"]) + assert torch.allclose(processed[OBS_IMAGE], observation[OBS_IMAGE]) def test_missing_action_stats_no_error(): mock_dataset = Mock() - mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + mock_dataset.meta.stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} - features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) @@ -580,7 +579,7 @@ def test_serialization_roundtrip(full_stats): features=features, norm_map=norm_map, stats=full_stats, - normalize_observation_keys={"observation.image"}, + normalize_observation_keys={OBS_IMAGE}, eps=1e-6, ) @@ -598,8 +597,8 @@ def test_serialization_roundtrip(full_stats): # Test that both processors work the same way observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition( @@ -617,8 +616,8 @@ def test_serialization_roundtrip(full_stats): # Compare results assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) @@ -644,23 +643,23 @@ def test_serialization_roundtrip(full_stats): def test_identity_normalization_observations(): """Test that IDENTITY mode skips normalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MEAN_STD, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -668,11 +667,11 @@ def test_identity_normalization_observations(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = (torch.tensor([1.0, -0.5]) - torch.tensor([0.0, 0.0])) / torch.tensor([1.0, 1.0]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_normalization_actions(): @@ -695,23 +694,23 @@ def test_identity_normalization_actions(): def test_identity_unnormalization_observations(): """Test that IDENTITY mode skips unnormalization for observations.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, # IDENTITY mode FeatureType.STATE: NormalizationMode.MIN_MAX, # Normal mode for comparison } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.0, -1.0]), # Normalized values in [-1, 1] } transition = create_transition(observation=observation) @@ -719,13 +718,13 @@ def test_identity_unnormalization_observations(): unnormalized_obs = unnormalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(unnormalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(unnormalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be unnormalized (MIN_MAX) # (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = 0.0 # (-1.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0) = -1.0 expected_state = torch.tensor([0.0, -1.0]) - assert torch.allclose(unnormalized_obs["observation.state"], expected_state) + assert torch.allclose(unnormalized_obs[OBS_STATE], expected_state) def test_identity_unnormalization_actions(): @@ -748,7 +747,7 @@ def test_identity_unnormalization_actions(): def test_identity_with_missing_stats(): """Test that IDENTITY mode works even when stats are missing.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -760,7 +759,7 @@ def test_identity_with_missing_stats(): normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -769,13 +768,13 @@ def test_identity_with_missing_stats(): unnormalized_transition = unnormalizer(transition) assert torch.allclose( - normalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + normalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(normalized_transition[TransitionKey.ACTION], action) assert torch.allclose( - unnormalized_transition[TransitionKey.OBSERVATION]["observation.image"], - observation["observation.image"], + unnormalized_transition[TransitionKey.OBSERVATION][OBS_IMAGE], + observation[OBS_IMAGE], ) assert torch.allclose(unnormalized_transition[TransitionKey.ACTION], action) @@ -783,8 +782,8 @@ def test_identity_with_missing_stats(): def test_identity_mixed_with_other_modes(): """Test IDENTITY mode mixed with other normalization modes.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -793,16 +792,16 @@ def test_identity_mixed_with_other_modes(): FeatureType.ACTION: NormalizationMode.MIN_MAX, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } action = torch.tensor([0.5, 0.0]) transition = create_transition(observation=observation, action=action) @@ -812,11 +811,11 @@ def test_identity_mixed_with_other_modes(): normalized_action = normalized_transition[TransitionKey.ACTION] # Image should remain unchanged (IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) # (x - 0) / 1 = x - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) # Action should be normalized (MIN_MAX) to [-1, 1] # 2 * (0.5 - (-1)) / (1 - (-1)) - 1 = 2 * 1.5 / 2 - 1 = 0.5 @@ -828,23 +827,23 @@ def test_identity_mixed_with_other_modes(): def test_identity_defaults_when_not_in_norm_map(): """Test that IDENTITY is used as default when feature type not in norm_map.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), } norm_map = { FeatureType.STATE: NormalizationMode.MEAN_STD, # VISUAL not specified, should default to IDENTITY } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([1.0, -0.5]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([1.0, -0.5]), } transition = create_transition(observation=observation) @@ -852,17 +851,17 @@ def test_identity_defaults_when_not_in_norm_map(): normalized_obs = normalized_transition[TransitionKey.OBSERVATION] # Image should remain unchanged (defaults to IDENTITY) - assert torch.allclose(normalized_obs["observation.image"], observation["observation.image"]) + assert torch.allclose(normalized_obs[OBS_IMAGE], observation[OBS_IMAGE]) # State should be normalized (explicitly MEAN_STD) expected_state = torch.tensor([1.0, -0.5]) - assert torch.allclose(normalized_obs["observation.state"], expected_state) + assert torch.allclose(normalized_obs[OBS_STATE], expected_state) def test_identity_roundtrip(): """Test that IDENTITY normalization and unnormalization are true inverses.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -870,14 +869,14 @@ def test_identity_roundtrip(): FeatureType.ACTION: NormalizationMode.IDENTITY, } stats = { - "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - original_observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + original_observation = {OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3])} original_action = torch.tensor([0.5, -0.2]) original_transition = create_transition(observation=original_observation, action=original_action) @@ -886,16 +885,14 @@ def test_identity_roundtrip(): roundtrip = unnormalizer(normalized) # Should be identical to original - assert torch.allclose( - roundtrip[TransitionKey.OBSERVATION]["observation.image"], original_observation["observation.image"] - ) + assert torch.allclose(roundtrip[TransitionKey.OBSERVATION][OBS_IMAGE], original_observation[OBS_IMAGE]) assert torch.allclose(roundtrip[TransitionKey.ACTION], original_action) def test_identity_config_serialization(): """Test that IDENTITY mode is properly saved and loaded in config.""" features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -903,7 +900,7 @@ def test_identity_config_serialization(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": {"mean": [0.5], "std": [0.2]}, + OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } @@ -925,7 +922,7 @@ def test_identity_config_serialization(): ) # Test that both work the same way - observation = {"observation.image": torch.tensor([0.7])} + observation = {OBS_IMAGE: torch.tensor([0.7])} action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -934,15 +931,15 @@ def test_identity_config_serialization(): # Results should be identical assert torch.allclose( - result1[TransitionKey.OBSERVATION]["observation.image"], - result2[TransitionKey.OBSERVATION]["observation.image"], + result1[TransitionKey.OBSERVATION][OBS_IMAGE], + result2[TransitionKey.OBSERVATION][OBS_IMAGE], ) assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) # def test_unsupported_normalization_mode_error(): # """Test that unsupported normalization modes raise appropriate errors.""" -# features = {"observation.state": PolicyFeature(FeatureType.STATE, (2,))} +# features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (2,))} # # Create an invalid norm_map (this would never happen in practice, but tests error handling) # from enum import Enum @@ -953,14 +950,14 @@ def test_identity_config_serialization(): # # We can't actually pass an invalid enum to the processor due to type checking, # # but we can test the error by manipulating the norm_map after creation # norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} -# stats = {"observation.state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} +# stats = {OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}} # normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # # Manually inject an invalid mode to test error handling # normalizer.norm_map[FeatureType.STATE] = "INVALID_MODE" -# observation = {"observation.state": torch.tensor([1.0, -0.5])} +# observation = {OBS_STATE: torch.tensor([1.0, -0.5])} # transition = create_transition(observation=observation) # with pytest.raises(ValueError, match="Unsupported normalization mode"): @@ -971,19 +968,19 @@ def test_hotswap_stats_basic_functionality(): """Test that hotswap_stats correctly updates stats in normalizer/unnormalizer steps.""" # Create initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create new stats for hotswapping new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create features and norm_map features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1021,15 +1018,15 @@ def test_hotswap_stats_basic_functionality(): def test_hotswap_stats_deep_copy(): """Test that hotswap_stats creates a deep copy and doesn't modify the original processor.""" initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1060,15 +1057,15 @@ def test_hotswap_stats_deep_copy(): def test_hotswap_stats_only_affects_normalizer_steps(): """Test that hotswap_stats only modifies NormalizerProcessorStep and UnnormalizerProcessorStep steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1093,13 +1090,13 @@ def test_hotswap_stats_only_affects_normalizer_steps(): def test_hotswap_stats_empty_stats(): """Test hotswap_stats with empty stats dictionary.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } empty_stats = {} features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -1117,7 +1114,7 @@ def test_hotswap_stats_empty_stats(): def test_hotswap_stats_no_normalizer_steps(): """Test hotswap_stats with a processor that has no normalizer/unnormalizer steps.""" stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # Create processor with only identity steps @@ -1139,18 +1136,18 @@ def test_hotswap_stats_no_normalizer_steps(): def test_hotswap_stats_preserves_other_attributes(): """Test that hotswap_stats preserves other processor attributes like features and norm_map.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - normalize_observation_keys = {"observation.image"} + normalize_observation_keys = {OBS_IMAGE} eps = 1e-6 normalizer = NormalizerProcessorStep( @@ -1179,17 +1176,17 @@ def test_hotswap_stats_preserves_other_attributes(): def test_hotswap_stats_multiple_normalizer_types(): """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, } new_stats = { - "observation.image": {"mean": np.array([0.3]), "std": np.array([0.1])}, + OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } norm_map = { @@ -1224,12 +1221,12 @@ def test_hotswap_stats_multiple_normalizer_types(): def test_hotswap_stats_with_different_data_types(): """Test hotswap_stats with various data types in stats.""" initial_stats = { - "observation.image": {"mean": np.array([0.5]), "std": np.array([0.2])}, + OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, } # New stats with different data types (int, float, list, tuple) new_stats = { - "observation.image": { + OBS_IMAGE: { "mean": [0.3, 0.4, 0.5], # list "std": (0.1, 0.2, 0.3), # tuple "min": 0, # int @@ -1242,7 +1239,7 @@ def test_hotswap_stats_with_different_data_types(): } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1261,43 +1258,43 @@ def test_hotswap_stats_with_different_data_types(): # Check that tensor conversion worked correctly tensor_stats = new_processor.steps[0]._tensor_stats - assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["min"], torch.Tensor) - assert isinstance(tensor_stats["observation.image"]["max"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["mean"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor) + assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor) assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) assert isinstance(tensor_stats["action"]["std"], torch.Tensor) # Check values - torch.testing.assert_close(tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.4, 0.5])) - torch.testing.assert_close(tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2, 0.3])) - torch.testing.assert_close(tensor_stats["observation.image"]["min"], torch.tensor(0.0)) - torch.testing.assert_close(tensor_stats["observation.image"]["max"], torch.tensor(1.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2, 0.3])) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["min"], torch.tensor(0.0)) + torch.testing.assert_close(tensor_stats[OBS_IMAGE]["max"], torch.tensor(1.0)) def test_hotswap_stats_functional_test(): """Test that hotswapped processor actually works functionally.""" # Create test data observation = { - "observation.image": torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), + OBS_IMAGE: torch.tensor([[[0.6, 0.7], [0.8, 0.9]], [[0.5, 0.6], [0.7, 0.8]]]), } action = torch.tensor([0.5, -0.5]) transition = create_transition(observation=observation, action=action) # Initial stats initial_stats = { - "observation.image": {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # New stats new_stats = { - "observation.image": {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1322,8 +1319,8 @@ def test_hotswap_stats_functional_test(): # Results should be different since normalization changed assert not torch.allclose( - original_result["observation"]["observation.image"], - new_result["observation"]["observation.image"], + original_result[OBS_STR][OBS_IMAGE], + new_result[OBS_STR][OBS_IMAGE], rtol=1e-3, atol=1e-3, ) @@ -1331,60 +1328,54 @@ def test_hotswap_stats_functional_test(): # Verify that the new processor is actually using the new stats by checking internal state assert new_processor.steps[0].stats == new_stats - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["mean"], torch.tensor([0.3, 0.2]) - ) - assert torch.allclose( - new_processor.steps[0]._tensor_stats["observation.image"]["std"], torch.tensor([0.1, 0.2]) - ) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) # Test that normalization actually happens (output should not equal input) - assert not torch.allclose( - new_result["observation"]["observation.image"], observation["observation.image"] - ) + assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE]) assert not torch.allclose(new_result["action"], action) def test_zero_std_uses_eps(): """When std == 0, (x-mean)/(std+eps) is well-defined; x==mean should map to 0.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.5]), "std": np.array([0.0])}} + stats = {OBS_STATE: {"mean": np.array([0.5]), "std": np.array([0.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([0.5])} # equals mean + observation = {OBS_STATE: torch.tensor([0.5])} # equals mean out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([0.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([0.0])) def test_min_equals_max_maps_to_minus_one(): """When min == max, MIN_MAX path maps to -1 after [-1,1] scaling for x==min.""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MIN_MAX} - stats = {"observation.state": {"min": np.array([2.0]), "max": np.array([2.0])}} + stats = {OBS_STATE: {"min": np.array([2.0]), "max": np.array([2.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats, eps=1e-6) - observation = {"observation.state": torch.tensor([2.0])} + observation = {OBS_STATE: torch.tensor([2.0])} out = normalizer(create_transition(observation=observation)) - assert torch.allclose(out[TransitionKey.OBSERVATION]["observation.state"], torch.tensor([-1.0])) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([-1.0])) def test_action_normalized_despite_normalize_observation_keys(): """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { - "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep( - features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={"observation.state"} + features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE} ) transition = create_transition( - observation={"observation.state": torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) + observation={OBS_STATE: torch.tensor([3.0])}, action=torch.tensor([3.0, 3.0]) ) out = normalizer(transition) # (3-1)/2 = 1.0 ; (3-(-1))/4 = 1.0 @@ -1421,12 +1412,12 @@ def test_unnormalize_observations_mean_std_and_min_max(): def test_unknown_observation_keys_ignored(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) - obs = {"observation.state": torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} + obs = {OBS_STATE: torch.tensor([1.0]), "observation.unknown": torch.tensor([5.0])} tr = create_transition(observation=obs) out = normalizer(tr) @@ -1447,13 +1438,13 @@ def test_batched_action_normalization(): def test_complementary_data_preservation(): - features = {"observation.state": PolicyFeature(FeatureType.STATE, (1,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (1,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) comp = {"existing": 123} - tr = create_transition(observation={"observation.state": torch.tensor([1.0])}, complementary_data=comp) + tr = create_transition(observation={OBS_STATE: torch.tensor([1.0])}, complementary_data=comp) out = normalizer(tr) new_comp = out[TransitionKey.COMPLEMENTARY_DATA] assert new_comp["existing"] == 123 @@ -1461,36 +1452,34 @@ def test_complementary_data_preservation(): def test_roundtrip_normalize_unnormalize_non_identity(): features = { - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} stats = { - "observation.state": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, + OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) # Add a time dimension in action for broadcasting check (B,T,D) - obs = {"observation.state": torch.tensor([[3.0, 3.0], [1.0, -1.0]])} + obs = {OBS_STATE: torch.tensor([[3.0, 3.0], [1.0, -1.0]])} act = torch.tensor([[[0.0, -1.0], [1.0, 1.0]]]) # shape (1,2,2) already in [-1,1] tr = create_transition(observation=obs, action=act) out = unnormalizer(normalizer(tr)) - assert torch.allclose( - out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 - ) + assert torch.allclose(out[TransitionKey.OBSERVATION][OBS_STATE], obs[OBS_STATE], atol=1e-5) assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) def test_dtype_adaptation_bfloat16_input_float32_normalizer(): """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output""" - features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (5,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} stats = { - "observation.state": { + OBS_STATE: { "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]), } @@ -1503,11 +1492,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify initial configuration assert normalizer.dtype == torch.float32 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.float32 # Create bfloat16 input tensor - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)} transition = create_transition(observation=observation) # Process the transition @@ -1516,11 +1505,11 @@ def test_dtype_adaptation_bfloat16_input_float32_normalizer(): # Verify that: # 1. Stats were automatically adapted to bfloat16 assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 # 2. Output is in bfloat16 - output_tensor = result[TransitionKey.OBSERVATION]["observation.state"] + output_tensor = result[TransitionKey.OBSERVATION][OBS_STATE] assert output_tensor.dtype == torch.bfloat16 # 3. Normalization was applied correctly (mean should be close to original - mean) / std @@ -1540,18 +1529,18 @@ def test_stats_override_preservation_in_load_state_dict(): """ # Create original stats original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create override stats (what user wants to use) override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1611,12 +1600,12 @@ def test_stats_without_override_loads_normally(): load_state_dict works as before. """ original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1651,12 +1640,12 @@ def test_stats_without_override_loads_normally(): def test_stats_explicit_provided_flag_detection(): """Test that the _stats_explicitly_provided flag is set correctly in different scenarios.""" features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), } norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} # Test 1: Explicitly provided stats (non-empty dict) - stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} normalizer1 = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) assert normalizer1._stats_explicitly_provided is True @@ -1684,7 +1673,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Create test data features = { - "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), + OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { @@ -1693,12 +1682,12 @@ def test_pipeline_from_pretrained_with_stats_overrides(): } original_stats = { - "observation.image": {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, + OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } override_stats = { - "observation.image": {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, + OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } @@ -1740,7 +1729,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Test that the override stats are actually used in processing observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1770,9 +1759,9 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output""" from lerobot.processor import DeviceProcessorStep - features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))} + features = {OBS_STATE: PolicyFeature(FeatureType.STATE, (3,))} norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD} - stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} + stats = {OBS_STATE: {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}} # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32) device_processor = DeviceProcessorStep(device=str(auto_select_torch_device()), float_dtype="bfloat16") @@ -1784,18 +1773,18 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): assert normalizer.dtype == torch.float32 # Create CPU input - observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} + observation = {OBS_STATE: torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)} transition = create_transition(observation=observation) # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA processed_1 = device_processor(transition) - intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"] + intermediate_tensor = processed_1[TransitionKey.OBSERVATION][OBS_STATE] assert intermediate_tensor.dtype == torch.bfloat16 assert intermediate_tensor.device.type == str(auto_select_torch_device()) # Step 2: NormalizerProcessor receives bfloat16 input and adapts final_result = normalizer(processed_1) - final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"] + final_tensor = final_result[TransitionKey.OBSERVATION][OBS_STATE] # Verify final output is bfloat16 (automatic adaptation worked) assert final_tensor.dtype == torch.bfloat16 @@ -1803,7 +1792,7 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): # Verify normalizer adapted its internal state assert normalizer.dtype == torch.bfloat16 - for stat_tensor in normalizer._tensor_stats["observation.state"].values(): + for stat_tensor in normalizer._tensor_stats[OBS_STATE].values(): assert stat_tensor.dtype == torch.bfloat16 assert stat_tensor.device.type == str(auto_select_torch_device()) @@ -1821,8 +1810,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Create normalizer with stats features = { - "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), "action": PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { @@ -1831,11 +1820,11 @@ def test_stats_reconstruction_after_load_state_dict(): FeatureType.ACTION: NormalizationMode.MEAN_STD, } stats = { - "observation.image": { + OBS_IMAGE: { "mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2]), }, - "observation.state": { + OBS_STATE: { "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, @@ -1861,15 +1850,15 @@ def test_stats_reconstruction_after_load_state_dict(): assert new_normalizer.stats != {} # Check that all expected keys are present - assert "observation.image" in new_normalizer.stats - assert "observation.state" in new_normalizer.stats + assert OBS_IMAGE in new_normalizer.stats + assert OBS_STATE in new_normalizer.stats assert "action" in new_normalizer.stats # Check that values are correct (converted back from tensors) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5]) - np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0]) - np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) + np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) + np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) @@ -1885,8 +1874,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 2: hotswap_stats should work new_stats = { - "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, - "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, + OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, + OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, } @@ -1900,8 +1889,8 @@ def test_stats_reconstruction_after_load_state_dict(): # Test 3: The normalizer should work functionally the same as the original observation = { - "observation.image": torch.tensor([0.7, 0.5, 0.3]), - "observation.state": torch.tensor([0.5, 0.0]), + OBS_IMAGE: torch.tensor([0.7, 0.5, 0.3]), + OBS_STATE: torch.tensor([0.5, 0.0]), } action = torch.tensor([1.0, -0.5]) transition = create_transition(observation=observation, action=action) @@ -1911,11 +1900,11 @@ def test_stats_reconstruction_after_load_state_dict(): # Results should be identical (within floating point precision) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.image"], - new_result[TransitionKey.OBSERVATION]["observation.image"], + original_result[TransitionKey.OBSERVATION][OBS_IMAGE], + new_result[TransitionKey.OBSERVATION][OBS_IMAGE], ) torch.testing.assert_close( - original_result[TransitionKey.OBSERVATION]["observation.state"], - new_result[TransitionKey.OBSERVATION]["observation.state"], + original_result[TransitionKey.OBSERVATION][OBS_STATE], + new_result[TransitionKey.OBSERVATION][OBS_STATE], ) torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION]) diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py index 6abc9edef..11b58a66c 100644 --- a/tests/processor/test_observation_processor.py +++ b/tests/processor/test_observation_processor.py @@ -39,8 +39,8 @@ def test_process_single_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that the image was processed correctly - assert "observation.image" in processed_obs - processed_img = processed_obs["observation.image"] + assert OBS_IMAGE in processed_obs + processed_img = processed_obs[OBS_IMAGE] # Check shape: should be (1, 3, 64, 64) - batch, channels, height, width assert processed_img.shape == (1, 3, 64, 64) @@ -66,12 +66,12 @@ def test_process_image_dict(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both images were processed - assert "observation.images.camera1" in processed_obs - assert "observation.images.camera2" in processed_obs + assert f"{OBS_IMAGES}.camera1" in processed_obs + assert f"{OBS_IMAGES}.camera2" in processed_obs # Check shapes - assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) - assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + assert processed_obs[f"{OBS_IMAGES}.camera1"].shape == (1, 3, 32, 32) + assert processed_obs[f"{OBS_IMAGES}.camera2"].shape == (1, 3, 48, 48) def test_process_batched_image(): @@ -88,7 +88,7 @@ def test_process_batched_image(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimension is preserved - assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + assert processed_obs[OBS_IMAGE].shape == (2, 3, 64, 64) def test_invalid_image_format(): @@ -173,10 +173,10 @@ def test_process_environment_state(): processed_obs = result[TransitionKey.OBSERVATION] # Check that environment_state was renamed and processed - assert "observation.environment_state" in processed_obs + assert OBS_ENV_STATE in processed_obs assert "environment_state" not in processed_obs - processed_state = processed_obs["observation.environment_state"] + processed_state = processed_obs[OBS_ENV_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) @@ -194,10 +194,10 @@ def test_process_agent_pos(): processed_obs = result[TransitionKey.OBSERVATION] # Check that agent_pos was renamed and processed - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs - processed_state = processed_obs["observation.state"] + processed_state = processed_obs[OBS_STATE] assert processed_state.shape == (1, 3) # Batch dimension added assert processed_state.dtype == torch.float32 torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) @@ -217,8 +217,8 @@ def test_process_batched_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that batch dimensions are preserved - assert processed_obs["observation.environment_state"].shape == (2, 2) - assert processed_obs["observation.state"].shape == (2, 2) + assert processed_obs[OBS_ENV_STATE].shape == (2, 2) + assert processed_obs[OBS_STATE].shape == (2, 2) def test_process_both_states(): @@ -235,8 +235,8 @@ def test_process_both_states(): processed_obs = result[TransitionKey.OBSERVATION] # Check that both states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "environment_state" not in processed_obs @@ -281,12 +281,12 @@ def test_complete_observation_processing(): processed_obs = result[TransitionKey.OBSERVATION] # Check that image was processed - assert "observation.image" in processed_obs - assert processed_obs["observation.image"].shape == (1, 3, 32, 32) + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_IMAGE].shape == (1, 3, 32, 32) # Check that states were processed - assert "observation.environment_state" in processed_obs - assert "observation.state" in processed_obs + assert OBS_ENV_STATE in processed_obs + assert OBS_STATE in processed_obs # Check that original keys were removed assert "pixels" not in processed_obs @@ -308,7 +308,7 @@ def test_image_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.image" in processed_obs + assert OBS_IMAGE in processed_obs assert len(processed_obs) == 1 @@ -323,7 +323,7 @@ def test_state_only_processing(): result = processor(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs + assert OBS_STATE in processed_obs assert "agent_pos" not in processed_obs @@ -504,7 +504,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): proc = VanillaObservationProcessorStep() features = { PipelineFeatureType.OBSERVATION: { - "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + OBS_ENV_STATE: policy_feature_factory(FeatureType.STATE, (2,)), "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), }, } @@ -513,7 +513,7 @@ def test_state_processor_features_prefixed_inputs(policy_feature_factory): assert ( OBS_ENV_STATE in out[PipelineFeatureType.OBSERVATION] and out[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] - == features[PipelineFeatureType.OBSERVATION]["observation.environment_state"] + == features[PipelineFeatureType.OBSERVATION][OBS_ENV_STATE] ) assert ( OBS_STATE in out[PipelineFeatureType.OBSERVATION] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 0d17fed00..6d056e4dc 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,6 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -255,7 +256,7 @@ def test_step_through_with_dict(): pipeline = DataProcessorPipeline([step1, step2]) batch = { - "observation.image": None, + OBS_IMAGE: None, "action": None, "next.reward": 0.0, "next.done": False, @@ -1840,7 +1841,7 @@ def test_save_load_with_custom_converter_functions(): # Verify it uses default converters by checking with standard batch format batch = { - "observation.image": torch.randn(1, 3, 32, 32), + OBS_IMAGE: torch.randn(1, 3, 32, 32), "action": torch.randn(1, 7), "next.reward": torch.tensor([1.0]), "next.done": torch.tensor([False]), @@ -1851,7 +1852,7 @@ def test_save_load_with_custom_converter_functions(): # Should work with standard format (wouldn't work with custom converter) result = loaded(batch) # With new behavior, default to_output is _default_transition_to_batch, so result is batch dict - assert "observation.image" in result + assert OBS_IMAGE in result class NonCompliantStep: @@ -2075,10 +2076,10 @@ class AddObservationStateFeatures(ProcessorStep): self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: # State features (mix EE and a joint state) - features[PipelineFeatureType.OBSERVATION]["observation.state.ee.x"] = float - features[PipelineFeatureType.OBSERVATION]["observation.state.j1.pos"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.ee.x"] = float + features[PipelineFeatureType.OBSERVATION][f"{OBS_STATE}.j1.pos"] = float if self.add_front_image: - features[PipelineFeatureType.OBSERVATION]["observation.images.front"] = self.front_image_shape + features[PipelineFeatureType.OBSERVATION][f"{OBS_IMAGES}.front"] = self.front_image_shape return features @@ -2094,7 +2095,7 @@ def test_aggregate_joint_action_only(): ) # Expect only "action" with joint names - assert "action" in out and "observation.state" not in out + assert "action" in out and OBS_STATE not in out assert out["action"]["dtype"] == "float32" assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} assert out["action"]["shape"] == (len(out["action"]["names"]),) @@ -2108,7 +2109,7 @@ def test_aggregate_ee_action_and_observation_with_videos(): pipeline=rp, initial_features={PipelineFeatureType.OBSERVATION: initial, PipelineFeatureType.ACTION: {}}, use_videos=True, - patterns=["action.ee", "observation.state"], + patterns=["action.ee", OBS_STATE], ) # Action should pack only EE names @@ -2117,13 +2118,13 @@ def test_aggregate_ee_action_and_observation_with_videos(): assert out["action"]["dtype"] == "float32" # Observation state should pack both ee.x and j1.pos as a vector - assert "observation.state" in out - assert set(out["observation.state"]["names"]) == {"ee.x", "j1.pos"} - assert out["observation.state"]["dtype"] == "float32" + assert OBS_STATE in out + assert set(out[OBS_STATE]["names"]) == {"ee.x", "j1.pos"} + assert out[OBS_STATE]["dtype"] == "float32" # Cameras from initial_features appear as videos for cam in ("front", "side"): - key = f"observation.images.{cam}" + key = f"{OBS_IMAGES}.{cam}" assert key in out assert out[key]["dtype"] == "video" assert out[key]["shape"] == initial[cam] @@ -2156,8 +2157,8 @@ def test_aggregate_images_when_use_videos_false(): patterns=None, ) - key = "observation.images.back" - key_front = "observation.images.front" + key = f"{OBS_IMAGES}.back" + key_front = f"{OBS_IMAGES}.front" assert key not in out assert key_front not in out @@ -2173,8 +2174,8 @@ def test_aggregate_images_when_use_videos_true(): patterns=None, ) - key = "observation.images.front" - key_back = "observation.images.back" + key = f"{OBS_IMAGES}.front" + key_back = f"{OBS_IMAGES}.back" assert key in out assert key_back in out assert out[key]["dtype"] == "video" @@ -2194,9 +2195,9 @@ def test_initial_camera_not_overridden_by_step_image(): pipeline=rp, initial_features={PipelineFeatureType.ACTION: {}, PipelineFeatureType.OBSERVATION: initial}, use_videos=True, - patterns=["observation.images.front"], + patterns=[f"{OBS_IMAGES}.front"], ) - key = "observation.images.front" + key = f"{OBS_IMAGES}.front" assert key in out assert out[key]["shape"] == (240, 320, 3) # from the step, not from initial diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index 5f2b48576..c6aa303f1 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -28,6 +28,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.rename_processor import rename_stats +from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -121,13 +122,13 @@ def test_overlapping_rename(): def test_partial_rename(): """Test renaming only some keys.""" rename_map = { - "observation.state": "observation.proprio_state", - "pixels": "observation.image", + OBS_STATE: "observation.proprio_state", + "pixels": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.state": torch.randn(10), + OBS_STATE: torch.randn(10), "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), "reward": 1.0, "info": {"episode": 1}, @@ -139,8 +140,8 @@ def test_partial_rename(): # Check renamed keys assert "observation.proprio_state" in processed_obs - assert "observation.image" in processed_obs - assert "observation.state" not in processed_obs + assert OBS_IMAGE in processed_obs + assert OBS_STATE not in processed_obs assert "pixels" not in processed_obs # Check unchanged keys @@ -174,8 +175,8 @@ def test_state_dict(): def test_integration_with_robot_processor(): """Test integration with RobotProcessor pipeline.""" rename_map = { - "agent_pos": "observation.state", - "pixels": "observation.image", + "agent_pos": OBS_STATE, + "pixels": OBS_IMAGE, } rename_processor = RenameObservationsProcessorStep(rename_map=rename_map) @@ -196,8 +197,8 @@ def test_integration_with_robot_processor(): processed_obs = result[TransitionKey.OBSERVATION] # Check renaming worked through pipeline - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs assert "agent_pos" not in processed_obs assert "pixels" not in processed_obs assert processed_obs["other_data"] == "preserve_me" @@ -210,8 +211,8 @@ def test_integration_with_robot_processor(): def test_save_and_load_pretrained(): """Test saving and loading processor with RobotProcessor.""" rename_map = { - "old_state": "observation.state", - "old_image": "observation.image", + "old_state": OBS_STATE, + "old_image": OBS_IMAGE, } processor = RenameObservationsProcessorStep(rename_map=rename_map) pipeline = DataProcessorPipeline([processor], name="TestRenameProcessorStep") @@ -253,10 +254,10 @@ def test_save_and_load_pretrained(): result = loaded_pipeline(transition) processed_obs = result[TransitionKey.OBSERVATION] - assert "observation.state" in processed_obs - assert "observation.image" in processed_obs - assert processed_obs["observation.state"] == [1, 2, 3] - assert processed_obs["observation.image"] == "image_data" + assert OBS_STATE in processed_obs + assert OBS_IMAGE in processed_obs + assert processed_obs[OBS_STATE] == [1, 2, 3] + assert processed_obs[OBS_IMAGE] == "image_data" def test_registry_functionality(): @@ -317,8 +318,8 @@ def test_chained_rename_processors(): # Second processor: rename to final format processor2 = RenameObservationsProcessorStep( rename_map={ - "agent_position": "observation.state", - "camera_image": "observation.image", + "agent_position": OBS_STATE, + "camera_image": OBS_IMAGE, } ) @@ -342,8 +343,8 @@ def test_chained_rename_processors(): # After second processor final_obs = results[2][TransitionKey.OBSERVATION] - assert "observation.state" in final_obs - assert "observation.image" in final_obs + assert OBS_STATE in final_obs + assert OBS_IMAGE in final_obs assert final_obs["extra"] == "keep_me" # Original keys should be gone @@ -356,15 +357,15 @@ def test_chained_rename_processors(): def test_nested_observation_rename(): """Test renaming with nested observation structures.""" rename_map = { - "observation.images.left": "observation.camera.left_view", - "observation.images.right": "observation.camera.right_view", + f"{OBS_IMAGES}.left": "observation.camera.left_view", + f"{OBS_IMAGES}.right": "observation.camera.right_view", "observation.proprio": "observation.proprioception", } processor = RenameObservationsProcessorStep(rename_map=rename_map) observation = { - "observation.images.left": torch.randn(3, 64, 64), - "observation.images.right": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.left": torch.randn(3, 64, 64), + f"{OBS_IMAGES}.right": torch.randn(3, 64, 64), "observation.proprio": torch.randn(7), "observation.gripper": torch.tensor([0.0]), # Not renamed } @@ -382,8 +383,8 @@ def test_nested_observation_rename(): assert "observation.gripper" in processed_obs # Check old keys removed - assert "observation.images.left" not in processed_obs - assert "observation.images.right" not in processed_obs + assert f"{OBS_IMAGES}.left" not in processed_obs + assert f"{OBS_IMAGES}.right" not in processed_obs assert "observation.proprio" not in processed_obs @@ -464,7 +465,7 @@ def test_features_chained_processors(policy_feature_factory): # Chain two rename processors at the contract level processor1 = RenameObservationsProcessorStep(rename_map={"pos": "agent_position", "img": "camera_image"}) processor2 = RenameObservationsProcessorStep( - rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + rename_map={"agent_position": OBS_STATE, "camera_image": OBS_IMAGE} ) pipeline = DataProcessorPipeline([processor1, processor2]) @@ -477,27 +478,21 @@ def test_features_chained_processors(policy_feature_factory): } out = pipeline.transform_features(initial_features=spec) - assert set(out[PipelineFeatureType.OBSERVATION]) == {"observation.state", "observation.image", "extra"} - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.state"] - == spec[PipelineFeatureType.OBSERVATION]["pos"] - ) - assert ( - out[PipelineFeatureType.OBSERVATION]["observation.image"] - == spec[PipelineFeatureType.OBSERVATION]["img"] - ) + assert set(out[PipelineFeatureType.OBSERVATION]) == {OBS_STATE, OBS_IMAGE, "extra"} + assert out[PipelineFeatureType.OBSERVATION][OBS_STATE] == spec[PipelineFeatureType.OBSERVATION]["pos"] + assert out[PipelineFeatureType.OBSERVATION][OBS_IMAGE] == spec[PipelineFeatureType.OBSERVATION]["img"] assert out[PipelineFeatureType.OBSERVATION]["extra"] == spec[PipelineFeatureType.OBSERVATION]["extra"] assert_contract_is_typed(out) def test_rename_stats_basic(): orig = { - "observation.state": {"mean": np.array([0.0]), "std": np.array([1.0])}, + OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}, "action": {"mean": np.array([0.0])}, } - mapping = {"observation.state": "observation.robot_state"} + mapping = {OBS_STATE: "observation.robot_state"} renamed = rename_stats(orig, mapping) - assert "observation.robot_state" in renamed and "observation.state" not in renamed + assert "observation.robot_state" in renamed and OBS_STATE not in renamed # Ensure deep copy: mutate original and verify renamed unaffected - orig["observation.state"]["mean"][0] = 42.0 + orig[OBS_STATE]["mean"][0] = 42.0 assert renamed["observation.robot_state"]["mean"][0] != 42.0 diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 9e6c8de2f..35bbcfd8a 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -11,7 +11,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_LANGUAGE +from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE from tests.utils import require_package @@ -503,16 +503,14 @@ def test_features_basic(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128) input_features = { - PipelineFeatureType.OBSERVATION: { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)) - }, + PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } output_features = processor.transform_features(input_features) # Check that original features are preserved - assert "observation.state" in output_features[PipelineFeatureType.OBSERVATION] + assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION] assert "action" in output_features[PipelineFeatureType.ACTION] # Check that tokenized features are added @@ -797,7 +795,7 @@ def test_device_detection_cpu(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CPU tensors - observation = {"observation.state": torch.randn(10)} # CPU tensor + observation = {OBS_STATE: torch.randn(10)} # CPU tensor action = torch.randn(5) # CPU tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -821,7 +819,7 @@ def test_device_detection_cuda(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with CUDA tensors - observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor action = torch.randn(5).cuda() # CUDA tensor transition = create_transition( observation=observation, action=action, complementary_data={"task": "test task"} @@ -847,7 +845,7 @@ def test_device_detection_multi_gpu(): # Test with tensors on cuda:1 device = torch.device("cuda:1") - observation = {"observation.state": torch.randn(10).to(device)} + observation = {OBS_STATE: torch.randn(10).to(device)} action = torch.randn(5).to(device) transition = create_transition( observation=observation, action=action, complementary_data={"task": "multi gpu test"} @@ -943,7 +941,7 @@ def test_device_detection_preserves_dtype(): processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) # Create transition with float tensor (to test dtype isn't affected) - observation = {"observation.state": torch.randn(10, dtype=torch.float16)} + observation = {OBS_STATE: torch.randn(10, dtype=torch.float16)} transition = create_transition(observation=observation, complementary_data={"task": "dtype test"}) result = processor(transition) @@ -977,7 +975,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): # Start with CPU tensors transition = create_transition( - observation={"observation.state": torch.randn(10)}, # CPU + observation={OBS_STATE: torch.randn(10)}, # CPU action=torch.randn(5), # CPU complementary_data={"task": "pipeline test"}, ) @@ -985,7 +983,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer): result = robot_processor(transition) # All tensors should end up on CUDA (moved by DeviceProcessorStep) - assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda" + assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda" assert result[TransitionKey.ACTION].device.type == "cuda" # Tokenized tensors should also be on CUDA @@ -1005,8 +1003,8 @@ def test_simulated_accelerate_scenario(): # Simulate Accelerate scenario: batch already on GPU device = torch.device("cuda:0") observation = { - "observation.state": torch.randn(1, 10).to(device), # Batched, on GPU - "observation.image": torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU + OBS_STATE: torch.randn(1, 10).to(device), # Batched, on GPU + OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), # Batched, on GPU } action = torch.randn(1, 5).to(device) # Batched, on GPU diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index aa9913bb2..ec67f1889 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -21,6 +21,7 @@ import pytest import torch from torch.multiprocessing import Event, Queue +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -110,12 +111,12 @@ def test_push_transitions_to_transport_queue(): transitions = [] for i in range(3): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(False), truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i)}, ) transitions.append(transition) diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 43a6b0957..5d95dee04 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -24,6 +24,7 @@ from torch.multiprocessing import Event, Queue from lerobot.configs.train import TrainRLServerPipelineConfig from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.utils.constants import OBS_STR from lerobot.utils.transition import Transition from tests.utils import require_package @@ -33,12 +34,12 @@ def create_test_transitions(count: int = 3) -> list[Transition]: transitions = [] for i in range(count): transition = Transition( - state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, action=torch.randn(5), reward=torch.tensor(1.0 + i), done=torch.tensor(i == count - 1), # Last transition is done truncated=torch.tensor(False), - next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + next_state={OBS_STR: torch.randn(3, 64, 64), "state": torch.randn(10)}, complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, ) transitions.append(transition) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b5254f393..6820d321f 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,11 +22,12 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_REPO_ID def state_dims() -> list[str]: - return ["observation.image", "observation.state"] + return [OBS_IMAGE, OBS_STATE] @pytest.fixture @@ -61,10 +62,10 @@ def create_random_image() -> torch.Tensor: def create_dummy_transition() -> dict: return { - "observation.image": create_random_image(), + OBS_IMAGE: create_random_image(), "action": torch.randn(4), "reward": torch.tensor(1.0), - "observation.state": torch.randn( + OBS_STATE: torch.randn( 10, ), "done": torch.tensor(False), @@ -98,8 +99,8 @@ def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayB def create_dummy_state() -> dict: return { - "observation.image": create_random_image(), - "observation.state": torch.randn( + OBS_IMAGE: create_random_image(), + OBS_STATE: torch.randn( 10, ), } @@ -180,7 +181,7 @@ def test_empty_buffer_sample_raises_error(replay_buffer): def test_zero_capacity_buffer_raises_error(): with pytest.raises(ValueError, match="Capacity must be greater than 0."): - ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"]) def test_add_transition(replay_buffer, dummy_state, dummy_action): @@ -203,7 +204,7 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action): def test_add_over_capacity(): - replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"]) dummy_state_1 = create_dummy_state() dummy_action_1 = create_dummy_action() @@ -373,7 +374,7 @@ def test_to_lerobot_dataset(tmp_path): assert ds.num_frames == 4 for j, value in enumerate(ds): - print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j])) for i in range(len(ds)): for feature, value in ds[i].items(): @@ -383,12 +384,12 @@ def test_to_lerobot_dataset(tmp_path): assert torch.equal(value, buffer.rewards[i]) elif feature == "next.done": assert torch.equal(value, buffer.dones[i]) - elif feature == "observation.image": + elif feature == OBS_IMAGE: # Tensor -> numpy is not precise, so we have some diff there # TODO: Check and fix it - torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) - elif feature == "observation.state": - assert torch.equal(value, buffer.states["observation.state"][i]) + torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003) + elif feature == OBS_STATE: + assert torch.equal(value, buffer.states[OBS_STATE][i]) def test_from_lerobot_dataset(tmp_path): @@ -436,14 +437,14 @@ def test_from_lerobot_dataset(tmp_path): ) assert torch.equal( - replay_buffer.states["observation.state"][: len(replay_buffer)], - reconverted_buffer.states["observation.state"][: len(replay_buffer)], + replay_buffer.states[OBS_STATE][: len(replay_buffer)], + reconverted_buffer.states[OBS_STATE][: len(replay_buffer)], ), "State should be the same after converting to dataset and return back" for i in range(4): torch.testing.assert_close( - replay_buffer.states["observation.image"][i], - reconverted_buffer.states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][i], + reconverted_buffer.states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) @@ -454,16 +455,16 @@ def test_from_lerobot_dataset(tmp_path): next_index = (i + 1) % 4 torch.testing.assert_close( - replay_buffer.states["observation.image"][next_index], - reconverted_buffer.next_states["observation.image"][i], + replay_buffer.states[OBS_IMAGE][next_index], + reconverted_buffer.next_states[OBS_IMAGE][i], rtol=0.4, atol=0.004, ) for i in range(2, 4): assert torch.equal( - replay_buffer.states["observation.state"][i], - reconverted_buffer.next_states["observation.state"][i], + replay_buffer.states[OBS_STATE][i], + reconverted_buffer.next_states[OBS_STATE][i], ) @@ -563,10 +564,8 @@ def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_functio replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) sampled_transitions = replay_buffer.sample(1) - assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( - "Image augmentations should be applied" - ) - assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied" + assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), ( "Image augmentations should be applied" ) @@ -580,8 +579,8 @@ def test_check_image_augmentations_with_drq_and_default_image_augmentation_funct # Let's check that it doesn't fail and shapes are correct sampled_transitions = replay_buffer.sample(1) - assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) - assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84) def test_random_crop_vectorized_basic(): @@ -620,7 +619,7 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: buffer = ReplayBuffer( capacity=capacity, device="cpu", - state_keys=["observation.image", "observation.state"], + state_keys=[OBS_IMAGE, OBS_STATE], storage_device="cpu", ) @@ -628,8 +627,8 @@ def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: img = torch.ones(3, 128, 128) * i state_vec = torch.arange(11).float() + i state = { - "observation.image": img, - "observation.state": state_vec, + OBS_IMAGE: img, + OBS_STATE: state_vec, } buffer.add( state=state, @@ -648,14 +647,14 @@ def test_async_iterator_shapes_basic(): iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) @@ -668,13 +667,13 @@ def test_async_iterator_multiple_iterations(): for _ in range(5): batch = next(iterator) - images = batch["state"]["observation.image"] - states = batch["state"]["observation.state"] + images = batch["state"][OBS_IMAGE] + states = batch["state"][OBS_STATE] assert images.shape == (batch_size, 3, 128, 128) assert states.shape == (batch_size, 11) - next_images = batch["next_state"]["observation.image"] - next_states = batch["next_state"]["observation.state"] + next_images = batch["next_state"][OBS_IMAGE] + next_states = batch["next_state"][OBS_STATE] assert next_images.shape == (batch_size, 3, 128, 128) assert next_states.shape == (batch_size, 11) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 29b7bf70a..65a97c6a3 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -6,6 +6,7 @@ import numpy as np import pytest from lerobot.processor import TransitionKey +from lerobot.utils.constants import OBS_STATE @pytest.fixture @@ -72,7 +73,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # Build EnvTransition dict obs = { - "observation.state.temperature": np.float32(25.0), + f"{OBS_STATE}.temperature": np.float32(25.0), # CHW image should be converted to HWC for rr.Image "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), } @@ -97,7 +98,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # - action.throttle -> Scalar # - action.vector_0, action.vector_1 -> Scalars expected_keys = { - "observation.state.temperature", + f"{OBS_STATE}.temperature", "observation.camera", "action.throttle", "action.vector_0", @@ -106,7 +107,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): assert set(_keys(calls)) == expected_keys # Check scalar types and values - temp_obj = _obj_for(calls, "observation.state.temperature") + temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature") assert type(temp_obj).__name__ == "DummyScalar" assert temp_obj.value == pytest.approx(25.0) From 9627765ce20ac7404898394bcd18a48b077ec82c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 26 Sep 2025 11:53:27 +0200 Subject: [PATCH 128/158] chore(mypy): add mypy configuration and module overrides for gradual type checking (#2052) --- pyproject.toml | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index d2f1e502a..44e29043b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -267,8 +267,83 @@ default.extend-ignore-identifiers-re = [ # color = true # paths = ["src/lerobot"] +# TODO: Enable mypy gradually module by module across multiple PRs +# Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations + # [tool.mypy] # python_version = "3.10" # warn_return_any = true # warn_unused_configs = true # ignore_missing_imports = false +# strict = true +# disallow_untyped_defs = true +# disallow_incomplete_defs = true +# check_untyped_defs = true + +# [[tool.mypy.overrides]] +# module = "lerobot.utils.*" +# # include = "src/lerobot/utils/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.configs.*" +# # include = "src/lerobot/configs/**/*.py" + +# # Data processing modules +# [[tool.mypy.overrides]] +# module = "lerobot.processor.*" +# # include = "src/lerobot/processor/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.datasets.*" +# # include = "src/lerobot/datasets/**/*.py" + +# # Core machine learning modules +# [[tool.mypy.overrides]] +# module = "lerobot.optim.*" +# # include = "src/lerobot/optim/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.model.*" +# # include = "src/lerobot/model/**/*.py" + +# # Hardware interfaces +# [[tool.mypy.overrides]] +# module = "lerobot.cameras.*" +# # include = "src/lerobot/cameras/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.motors.*" +# # include = "src/lerobot/motors/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.robots.*" +# # include = "src/lerobot/robots/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.teleoperators.*" +# # include = "src/lerobot/teleoperators/**/*.py" + +# # Complex modules (enable these last) +# [[tool.mypy.overrides]] +# module = "lerobot.policies.*" +# # include = "src/lerobot/policies/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.rl.*" +# # include = "src/lerobot/rl/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.envs.*" +# # include = "src/lerobot/envs/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.async_inference.*" +# # include = "src/lerobot/async_inference/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.transport.*" +# # include = "src/lerobot/transport/**/*.py" + +# [[tool.mypy.overrides]] +# module = "lerobot.scripts.*" +# # include = "src/lerobot/scripts/**/*.py" From d2782cf66b0b22d9bfae8912ee2bbc7f63c3615e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 26 Sep 2025 13:33:18 +0200 Subject: [PATCH 129/158] chore: replace hard-coded action values with constants throughout all the source code (#2055) * chore: replace hard-coded 'action' values with constants throughout all the source code * chore(tests): replace hard-coded action values with constants throughout all the test code --- examples/backward_compatibility/replay.py | 7 +- examples/lekiwi/evaluate.py | 4 +- examples/lekiwi/record.py | 4 +- examples/lekiwi/replay.py | 5 +- examples/phone_to_so100/replay.py | 5 +- examples/so100_to_so100_EE/replay.py | 5 +- src/lerobot/datasets/factory.py | 4 +- src/lerobot/datasets/pipeline_features.py | 2 +- src/lerobot/datasets/utils.py | 6 +- src/lerobot/envs/configs.py | 16 +- src/lerobot/policies/act/modeling_act.py | 6 +- .../policies/diffusion/modeling_diffusion.py | 10 +- .../conversion_scripts/compare_with_jax.py | 6 +- src/lerobot/policies/sac/configuration_sac.py | 2 +- src/lerobot/policies/sac/modeling_sac.py | 6 +- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 2 +- src/lerobot/processor/converters.py | 8 +- .../processor/migrate_policy_normalization.py | 7 +- src/lerobot/processor/normalize_processor.py | 3 +- src/lerobot/processor/policy_robot_bridge.py | 3 +- src/lerobot/rl/buffer.py | 18 +-- src/lerobot/rl/gym_manipulator.py | 10 +- src/lerobot/rl/learner.py | 11 +- src/lerobot/robots/lekiwi/lekiwi_client.py | 4 +- src/lerobot/scripts/lerobot_dataset_viz.py | 8 +- src/lerobot/scripts/lerobot_eval.py | 8 +- src/lerobot/scripts/lerobot_record.py | 6 +- src/lerobot/scripts/lerobot_replay.py | 7 +- src/lerobot/utils/transition.py | 4 +- tests/datasets/test_dataset_utils.py | 28 ++-- tests/datasets/test_datasets.py | 18 +-- tests/datasets/test_streaming.py | 5 +- tests/fixtures/constants.py | 4 +- tests/policies/test_policies.py | 6 +- tests/policies/test_sac_config.py | 10 +- tests/policies/test_sac_policy.py | 10 +- tests/processor/test_batch_conversion.py | 28 ++-- tests/processor/test_converters.py | 14 +- tests/processor/test_device_processor.py | 4 +- tests/processor/test_migration_detection.py | 4 +- tests/processor/test_normalize_processor.py | 140 +++++++++--------- tests/processor/test_pipeline.py | 28 ++-- tests/processor/test_policy_robot_bridge.py | 15 +- tests/processor/test_rename_processor.py | 4 +- tests/processor/test_tokenizer_processor.py | 6 +- tests/transport/test_transport_utils.py | 3 +- tests/utils/test_replay_buffer.py | 10 +- 47 files changed, 269 insertions(+), 255 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 6c680f204..6bca0570f 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -44,6 +44,7 @@ from lerobot.robots import ( # noqa: F401 so100_follower, so101_follower, ) +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -78,16 +79,16 @@ def replay(cfg: ReplayConfig): robot = make_robot_from_config(cfg.robot) dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) - actions = dataset.hf_dataset.select_columns("action") + actions = dataset.hf_dataset.select_columns(ACTION) robot.connect() log_say("Replaying episode", cfg.play_sounds, blocking=True) for idx in range(dataset.num_frames): start_episode_t = time.perf_counter() - action_array = actions[idx]["action"] + action_array = actions[idx][ACTION] action = {} - for i, name in enumerate(dataset.features["action"]["names"]): + for i, name in enumerate(dataset.features[ACTION]["names"]): key = f"{name.removeprefix('main_')}.pos" action[key] = action_array[i].item() diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 174486eb8..8a62d92a9 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -21,7 +21,7 @@ from lerobot.policies.factory import make_pre_post_processors from lerobot.processor import make_default_processors from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.scripts.lerobot_record import record_loop -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -42,7 +42,7 @@ robot = LeKiwiClient(robot_config) policy = ACTPolicy.from_pretrained(HF_MODEL_ID) # Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") +action_features = hw_to_dataset_features(robot.action_features, ACTION) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 471cb3668..9070741bf 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -22,7 +22,7 @@ from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -48,7 +48,7 @@ keyboard = KeyboardTeleop(keyboard_config) teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() # Configure the dataset features -action_features = hw_to_dataset_features(robot.action_features, "action") +action_features = hw_to_dataset_features(robot.action_features, ACTION) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) dataset_features = {**action_features, **obs_features} diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 0f8eabdff..3ae915286 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -19,6 +19,7 @@ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -34,7 +35,7 @@ robot = LeKiwiClient(robot_config) dataset = LeRobotDataset("/", episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -49,7 +50,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Send action to robot diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 80c65a4c2..f1181143c 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -28,6 +28,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -66,7 +67,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -81,7 +82,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset ee_action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Get robot observation diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 6987f4839..ea78d4e66 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -29,6 +29,7 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -67,7 +68,7 @@ robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotOb dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX]) # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX) -actions = episode_frames.select_columns("action") +actions = episode_frames.select_columns(ACTION) # Connect to the robot robot.connect() @@ -82,7 +83,7 @@ for idx in range(len(episode_frames)): # Get recorded action from dataset ee_action = { - name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) } # Get robot observation diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index 2bac84aed..f74b6ac4f 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import OBS_PREFIX +from lerobot.utils.constants import ACTION, OBS_PREFIX IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -57,7 +57,7 @@ def resolve_delta_timestamps( for key in ds_meta.features: if key == "next.reward" and cfg.reward_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] - if key == "action" and cfg.action_delta_indices is not None: + if key == ACTION and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices] diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 13555dd31..4fad7bd20 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -132,7 +132,7 @@ def aggregate_pipeline_dataset_features( # Convert the processed features into the final dataset format. dataset_features = {} if processed_features[ACTION]: - dataset_features.update(hw_to_dataset_features(processed_features["action"], ACTION, use_videos)) + dataset_features.update(hw_to_dataset_features(processed_features[ACTION], ACTION, use_videos)) if processed_features[OBS_STR]: dataset_features.update(hw_to_dataset_features(processed_features[OBS_STR], OBS_STR, use_videos)) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 96ae2eca6..35313bde5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -43,7 +43,7 @@ from lerobot.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.utils.constants import OBS_ENV_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR from lerobot.utils.utils import is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk @@ -646,7 +646,7 @@ def hw_to_dataset_features( } cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} - if joint_fts and prefix == "action": + if joint_fts and prefix == ACTION: features[prefix] = { "dtype": "float32", "shape": (len(joint_fts),), @@ -733,7 +733,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea type = FeatureType.ENV elif key.startswith(OBS_STR): type = FeatureType.STATE - elif key.startswith("action"): + elif key.startswith(ACTION): type = FeatureType.ACTION else: continue diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 4456c51a5..8cbc597dc 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -53,12 +53,12 @@ class AlohaEnv(EnvConfig): render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(14,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "top": f"{OBS_IMAGE}.top", "pixels/top": f"{OBS_IMAGES}.top", @@ -93,13 +93,13 @@ class PushtEnv(EnvConfig): visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "environment_state": OBS_ENV_STATE, "pixels": OBS_IMAGE, @@ -135,13 +135,13 @@ class XarmEnv(EnvConfig): visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "pixels": OBS_IMAGE, } @@ -259,12 +259,12 @@ class LiberoEnv(EnvConfig): camera_name_mapping: dict[str, str] | None = (None,) features: dict[str, PolicyFeature] = field( default_factory=lambda: { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { - "action": ACTION, + ACTION: ACTION, "agent_pos": OBS_STATE, "pixels/agentview_image": f"{OBS_IMAGES}.image", "pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2", diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index f8261bb7f..e987f9070 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -394,7 +394,7 @@ class ACT(nn.Module): latent dimension. """ if self.config.use_vae and self.training: - assert "action" in batch, ( + assert ACTION in batch, ( "actions must be provided when using the variational objective in training mode." ) @@ -404,7 +404,7 @@ class ACT(nn.Module): batch_size = batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. - if self.config.use_vae and "action" in batch and self.training: + if self.config.use_vae and ACTION in batch and self.training: # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. cls_embed = einops.repeat( self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size @@ -412,7 +412,7 @@ class ACT(nn.Module): if self.config.robot_state_feature: robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[OBS_STATE]) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + action_embed = self.vae_encoder_action_input_proj(batch[ACTION]) # (B, S, D) if self.config.robot_state_feature: vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index af1327ba2..ad808d7c7 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -82,7 +82,7 @@ class DiffusionPolicy(PreTrainedPolicy): """Clear observation and action queues. Should be called on `env.reset()`""" self._queues = { OBS_STATE: deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.n_action_steps), + ACTION: deque(maxlen=self.config.n_action_steps), } if self.config.image_features: self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps) @@ -306,10 +306,10 @@ class DiffusionModel(nn.Module): } """ # Input validation. - assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"}) + assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"}) assert OBS_IMAGES in batch or OBS_ENV_STATE in batch n_obs_steps = batch[OBS_STATE].shape[1] - horizon = batch["action"].shape[1] + horizon = batch[ACTION].shape[1] assert horizon == self.config.horizon assert n_obs_steps == self.config.n_obs_steps @@ -317,7 +317,7 @@ class DiffusionModel(nn.Module): global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # Forward diffusion. - trajectory = batch["action"] + trajectory = batch[ACTION] # Sample noise to add to the trajectory. eps = torch.randn(trajectory.shape, device=trajectory.device) # Sample a random noising timestep for each item in the batch. @@ -338,7 +338,7 @@ class DiffusionModel(nn.Module): if self.config.prediction_type == "epsilon": target = eps elif self.config.prediction_type == "sample": - target = batch["action"] + target = batch[ACTION] else: raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py index fe9865697..dad7d002e 100644 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py @@ -21,7 +21,7 @@ import torch from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.factory import make_policy -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE def display(tensor: torch.Tensor): @@ -73,7 +73,7 @@ def main(): for cam_key, uint_chw_array in example["images"].items(): batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 batch[OBS_STATE] = torch.from_numpy(example["state"]) - batch["action"] = torch.from_numpy(outputs["actions"]) + batch[ACTION] = torch.from_numpy(outputs["actions"]) batch["task"] = example["prompt"] if model_name == "pi0_aloha_towel": @@ -117,7 +117,7 @@ def main(): actions.append(action) actions = torch.stack(actions, dim=1) - pi_actions = batch["action"] + pi_actions = batch[ACTION] print("actions") display(actions) print() diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index a42758b85..6b5ad5b59 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -225,7 +225,7 @@ class SACConfig(PreTrainedConfig): "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" ) - if "action" not in self.output_features: + if ACTION not in self.output_features: raise ValueError("You must provide 'action' in the output features") @property diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index a6ed79d4e..c66044406 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -31,7 +31,7 @@ from torch.distributions import MultivariateNormal, TanhTransform, Transform, Tr from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig, is_image_feature from lerobot.policies.utils import get_device_from_parameters -from lerobot.utils.constants import OBS_ENV_STATE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STATE DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension @@ -51,7 +51,7 @@ class SACPolicy( self.config = config # Determine action dimension and initialize all components - continuous_action_dim = config.output_features["action"].shape[0] + continuous_action_dim = config.output_features[ACTION].shape[0] self._init_encoders() self._init_critics(continuous_action_dim) self._init_actor(continuous_action_dim) @@ -158,7 +158,7 @@ class SACPolicy( The computed loss tensor """ # Extract common components from batch - actions: Tensor = batch["action"] + actions: Tensor = batch[ACTION] observations: dict[str, Tensor] = batch["state"] observation_features: Tensor = batch.get("observation_feature") diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 4b5e8b7bd..195cf6154 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -92,7 +92,7 @@ class TDMPCPolicy(PreTrainedPolicy): """ self._queues = { OBS_STATE: deque(maxlen=1), - "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + ACTION: deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), } if self.config.image_features: self._queues[OBS_IMAGE] = deque(maxlen=1) diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 2e80cf4bb..68f9dd6fa 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,7 +23,7 @@ from typing import Any import numpy as np import torch -from lerobot.utils.constants import OBS_PREFIX +from lerobot.utils.constants import ACTION, OBS_PREFIX from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -344,7 +344,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: if not isinstance(batch, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}") - action = batch.get("action") + action = batch.get(ACTION) if action is not None and not isinstance(action, PolicyAction): raise ValueError(f"Action should be a PolicyAction type got {type(action)}") @@ -354,7 +354,7 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: return create_transition( observation=observation_keys if observation_keys else None, - action=batch.get("action"), + action=batch.get(ACTION), reward=batch.get("next.reward", 0.0), done=batch.get("next.done", False), truncated=batch.get("next.truncated", False), @@ -379,7 +379,7 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: raise ValueError(f"Transition should be a EnvTransition type (dict) got {type(transition)}") batch = { - "action": transition.get(TransitionKey.ACTION), + ACTION: transition.get(TransitionKey.ACTION), "next.reward": transition.get(TransitionKey.REWARD, 0.0), "next.done": transition.get(TransitionKey.DONE, False), "next.truncated": transition.get(TransitionKey.TRUNCATED, False), diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 131f799d6..319145d1a 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -59,6 +59,7 @@ from safetensors.torch import load_file as load_safetensors from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.factory import get_policy_class, make_policy_config, make_pre_post_processors +from lerobot.utils.constants import ACTION def extract_normalization_stats(state_dict: dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: @@ -196,7 +197,7 @@ def detect_features_and_norm_modes( feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE # Default @@ -215,7 +216,7 @@ def detect_features_and_norm_modes( feature_type = FeatureType.VISUAL elif "state" in key or "joint" in key or "position" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE @@ -321,7 +322,7 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ feature_type = FeatureType.VISUAL elif "state" in key: feature_type = FeatureType.STATE - elif "action" in key: + elif ACTION in key: feature_type = FeatureType.ACTION else: feature_type = FeatureType.STATE diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index bece54f0b..c4ded722f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -26,6 +26,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PipelineFeatureType, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor from .core import EnvTransition, PolicyAction, TransitionKey @@ -272,7 +273,7 @@ class _NormalizationMixin: Returns: The transformed action tensor. """ - processed_action = self._apply_transform(action, "action", FeatureType.ACTION, inverse=inverse) + processed_action = self._apply_transform(action, ACTION, FeatureType.ACTION, inverse=inverse) return processed_action def _apply_transform( diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 74c534998..845ee065a 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -5,6 +5,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import ActionProcessorStep, PolicyAction, ProcessorStepRegistry, RobotAction +from lerobot.utils.constants import ACTION @dataclass @@ -23,7 +24,7 @@ class RobotActionToPolicyActionProcessorStep(ActionProcessorStep): return asdict(self) def transform_features(self, features): - features[PipelineFeatureType.ACTION]["action"] = PolicyFeature( + features[PipelineFeatureType.ACTION][ACTION] = PolicyFeature( type=FeatureType.ACTION, shape=(len(self.motor_names),) ) return features diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index fbf36de36..b572bbce5 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import OBS_IMAGE +from lerobot.utils.constants import ACTION, OBS_IMAGE from lerobot.utils.transition import Transition @@ -467,7 +467,7 @@ class ReplayBuffer: if list_transition: first_transition = list_transition[0] first_state = {k: v.to(device) for k, v in first_transition["state"].items()} - first_action = first_transition["action"].to(device) + first_action = first_transition[ACTION].to(device) # Get complementary info if available first_complementary_info = None @@ -492,7 +492,7 @@ class ReplayBuffer: elif isinstance(v, torch.Tensor): data[k] = v.to(storage_device) - action = data["action"] + action = data[ACTION] replay_buffer.add( state=data["state"], @@ -530,8 +530,8 @@ class ReplayBuffer: # Add "action" sample_action = self.actions[0] - act_info = guess_feature_info(t=sample_action, name="action") - features["action"] = act_info + act_info = guess_feature_info(t=sample_action, name=ACTION) + features[ACTION] = act_info # Add "reward" and "done" features["next.reward"] = {"dtype": "float32", "shape": (1,)} @@ -577,7 +577,7 @@ class ReplayBuffer: frame_dict[key] = self.states[key][actual_idx].cpu() # Fill action, reward, done - frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict[ACTION] = self.actions[actual_idx].cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() frame_dict["task"] = task_name @@ -668,7 +668,7 @@ class ReplayBuffer: current_state[key] = val.unsqueeze(0) # Add batch dimension # ----- 2) Action ----- - action = current_sample["action"].unsqueeze(0) # Add batch dimension + action = current_sample[ACTION].unsqueeze(0) # Add batch dimension # ----- 3) Reward and done ----- reward = float(current_sample["next.reward"].item()) # ensure float @@ -788,8 +788,8 @@ def concatenate_batch_transitions( } # Concatenate basic fields - left_batch_transitions["action"] = torch.cat( - [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 + left_batch_transitions[ACTION] = torch.cat( + [left_batch_transitions[ACTION], right_batch_transition[ACTION]], dim=0 ) left_batch_transitions["reward"] = torch.cat( [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 393135708..fa9f4e3e1 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,7 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -601,7 +601,7 @@ def control_loop( if cfg.mode == "record": action_features = teleop_device.action_features features = { - "action": action_features, + ACTION: action_features, "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, "next.done": {"dtype": "bool", "shape": (1,), "names": None}, } @@ -672,7 +672,7 @@ def control_loop( ) frame = { **observations, - "action": action_to_record.cpu(), + ACTION: action_to_record.cpu(), "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), "next.done": np.array([terminated or truncated], dtype=bool), } @@ -733,7 +733,7 @@ def replay_trajectory( download_videos=False, ) episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode) - actions = episode_frames.select_columns("action") + actions = episode_frames.select_columns(ACTION) _, info = env.reset() @@ -741,7 +741,7 @@ def replay_trajectory( start_time = time.perf_counter() transition = create_transition( observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}, - action=action_data["action"], + action=action_data[ACTION], ) transition = action_processor(transition) env.step(transition[TransitionKey.ACTION]) diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index 0faa460ef..b7cfdb30c 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -80,6 +80,7 @@ from lerobot.transport.utils import ( state_to_bytes, ) from lerobot.utils.constants import ( + ACTION, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK, PRETRAINED_MODEL_DIR, @@ -402,7 +403,7 @@ def add_actor_information_and_train( left_batch_transitions=batch, right_batch_transition=batch_offline ) - actions = batch["action"] + actions = batch[ACTION] rewards = batch["reward"] observations = batch["state"] next_observations = batch["next_state"] @@ -415,7 +416,7 @@ def add_actor_information_and_train( # Create a batch dictionary with all required elements for the forward method forward_batch = { - "action": actions, + ACTION: actions, "reward": rewards, "state": observations, "next_state": next_observations, @@ -460,7 +461,7 @@ def add_actor_information_and_train( left_batch_transitions=batch, right_batch_transition=batch_offline ) - actions = batch["action"] + actions = batch[ACTION] rewards = batch["reward"] observations = batch["state"] next_observations = batch["next_state"] @@ -474,7 +475,7 @@ def add_actor_information_and_train( # Create a batch dictionary with all required elements for the forward method forward_batch = { - "action": actions, + ACTION: actions, "reward": rewards, "state": observations, "next_state": next_observations, @@ -1155,7 +1156,7 @@ def process_transitions( # Skip transitions with NaN values if check_nan_in_transition( observations=transition["state"], - actions=transition["action"], + actions=transition[ACTION], next_state=transition["next_state"], ): logging.warning("[LEARNER] NaN detected in transition, skipping") diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 392d6d575..19744e244 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -23,7 +23,7 @@ from typing import Any import cv2 import numpy as np -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -330,7 +330,7 @@ class LeKiwiClient(Robot): actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32) action_sent = {key: actions[i] for i, key in enumerate(self._state_order)} - action_sent["action"] = actions + action_sent[ACTION] = actions return action_sent def disconnect(self): diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 5c0d31f73..adff5c085 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,7 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE class EpisodeSampler(torch.utils.data.Sampler): @@ -157,9 +157,9 @@ def visualize_dataset( rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) # display each dimension of action space (e.g. actuators command) - if "action" in batch: - for dim_idx, val in enumerate(batch["action"][i]): - rr.log(f"action/{dim_idx}", rr.Scalar(val.item())) + if ACTION in batch: + for dim_idx, val in enumerate(batch[ACTION][i]): + rr.log(f"{ACTION}/{dim_idx}", rr.Scalar(val.item())) # display each dimension of observed state space (e.g. agent position in joint space) if OBS_STATE in batch: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 310f771a9..882aeacc3 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,7 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -213,7 +213,7 @@ def rollout( # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. ret = { - "action": torch.stack(all_actions, dim=1), + ACTION: torch.stack(all_actions, dim=1), "reward": torch.stack(all_rewards, dim=1), "success": torch.stack(all_successes, dim=1), "done": torch.stack(all_dones, dim=1), @@ -440,14 +440,14 @@ def _compile_episode_data( """ ep_dicts = [] total_frames = 0 - for ep_ix in range(rollout_data["action"].shape[0]): + for ep_ix in range(rollout_data[ACTION].shape[0]): # + 2 to include the first done frame and the last observation frame. num_frames = done_indices[ep_ix].item() + 2 total_frames += num_frames # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. ep_dict = { - "action": rollout_data["action"][ep_ix, : num_frames - 1], + ACTION: rollout_data[ACTION][ep_ix, : num_frames - 1], "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index f1d026a39..d097a9d2f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -109,7 +109,7 @@ from lerobot.teleoperators import ( # noqa: F401 so101_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop -from lerobot.utils.constants import OBS_STR +from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import ( init_keyboard_listener, is_headless, @@ -319,7 +319,7 @@ def record_loop( robot_type=robot.robot_type, ) - action_names = dataset.features["action"]["names"] + action_names = dataset.features[ACTION]["names"] act_processed_policy: RobotAction = { f"{name}": float(action_values[i]) for i, name in enumerate(action_names) } @@ -361,7 +361,7 @@ def record_loop( # Write to dataset if dataset is not None: - action_frame = build_dataset_frame(dataset.features, action_values, prefix="action") + action_frame = build_dataset_frame(dataset.features, action_values, prefix=ACTION) frame = {**observation_frame, **action_frame, "task": single_task} dataset.add_frame(frame) diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 6761e3f4f..b899745b6 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401 so100_follower, so101_follower, ) +from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -99,7 +100,7 @@ def replay(cfg: ReplayConfig): # Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0 episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.episode) - actions = episode_frames.select_columns("action") + actions = episode_frames.select_columns(ACTION) robot.connect() @@ -107,9 +108,9 @@ def replay(cfg: ReplayConfig): for idx in range(len(episode_frames)): start_episode_t = time.perf_counter() - action_array = actions[idx]["action"] + action_array = actions[idx][ACTION] action = {} - for i, name in enumerate(dataset.features["action"]["names"]): + for i, name in enumerate(dataset.features[ACTION]["names"]): action[name] = action_array[i] robot_obs = robot.get_observation() diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py index db413c388..e874bd096 100644 --- a/src/lerobot/utils/transition.py +++ b/src/lerobot/utils/transition.py @@ -18,6 +18,8 @@ from typing import TypedDict import torch +from lerobot.utils.constants import ACTION + class Transition(TypedDict): state: dict[str, torch.Tensor] @@ -39,7 +41,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr } # Move action to device - transition["action"] = transition["action"].to(device, non_blocking=non_blocking) + transition[ACTION] = transition[ACTION].to(device, non_blocking=non_blocking) # Move reward and done if they are tensors if isinstance(transition["reward"], torch.Tensor): diff --git a/tests/datasets/test_dataset_utils.py b/tests/datasets/test_dataset_utils.py index c0b07ca65..99b832e55 100644 --- a/tests/datasets/test_dataset_utils.py +++ b/tests/datasets/test_dataset_utils.py @@ -21,7 +21,7 @@ from huggingface_hub import DatasetCard from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch -from lerobot.utils.constants import OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_IMAGES def test_default_parameters(): @@ -59,14 +59,14 @@ def test_calculate_episode_data_index(): def test_merge_simple_vectors(): g1 = { - "action": { + ACTION: { "dtype": "float32", "shape": (2,), "names": ["ee.x", "ee.y"], } } g2 = { - "action": { + ACTION: { "dtype": "float32", "shape": (2,), "names": ["ee.y", "ee.z"], @@ -75,23 +75,23 @@ def test_merge_simple_vectors(): out = combine_feature_dicts(g1, g2) - assert "action" in out - assert out["action"]["dtype"] == "float32" + assert ACTION in out + assert out[ACTION]["dtype"] == "float32" # Names merged with preserved order and de-dupuplication - assert out["action"]["names"] == ["ee.x", "ee.y", "ee.z"] + assert out[ACTION]["names"] == ["ee.x", "ee.y", "ee.z"] # Shape correctly recomputed from names length - assert out["action"]["shape"] == (3,) + assert out[ACTION]["shape"] == (3,) def test_merge_multiple_groups_order_and_dedup(): - g1 = {"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} - g2 = {"action": {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} - g3 = {"action": {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} + g1 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["a", "b"]}} + g2 = {ACTION: {"dtype": "float32", "shape": (2,), "names": ["b", "c"]}} + g3 = {ACTION: {"dtype": "float32", "shape": (3,), "names": ["a", "c", "d"]}} out = combine_feature_dicts(g1, g2, g3) - assert out["action"]["names"] == ["a", "b", "c", "d"] - assert out["action"]["shape"] == (4,) + assert out[ACTION]["names"] == ["a", "b", "c", "d"] + assert out[ACTION]["shape"] == (4,) def test_non_vector_last_wins_for_images(): @@ -117,8 +117,8 @@ def test_non_vector_last_wins_for_images(): def test_dtype_mismatch_raises(): - g1 = {"action": {"dtype": "float32", "shape": (1,), "names": ["a"]}} - g2 = {"action": {"dtype": "float64", "shape": (1,), "names": ["b"]}} + g1 = {ACTION: {"dtype": "float32", "shape": (1,), "names": ["a"]}} + g2 = {ACTION: {"dtype": "float64", "shape": (1,), "names": ["b"]}} with pytest.raises(ValueError, match="dtype mismatch for 'action'"): _ = combine_feature_dicts(g1, g2) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1d461c8ba..fcfef677b 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,7 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -75,7 +75,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ # Instantiate both ways robot = make_robot_from_config(MockRobotConfig()) - action_features = hw_to_dataset_features(robot.action_features, "action", True) + action_features = hw_to_dataset_features(robot.action_features, ACTION, True) obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR, True) dataset_features = {**action_features, **obs_features} root_create = tmp_path / "create" @@ -393,7 +393,7 @@ def test_factory(env_name, repo_id, policy_name): item = dataset[0] keys_ndim_required = [ - ("action", 1, True), + (ACTION, 1, True), ("episode_index", 0, True), ("frame_index", 0, True), ("timestamp", 0, True), @@ -668,7 +668,7 @@ def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory): "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], }, - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"], @@ -775,7 +775,7 @@ def test_update_chunk_settings_video_dataset(tmp_path): "shape": (480, 640, 3), "names": ["height", "width", "channels"], }, - "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + ACTION: {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, } # Create video dataset @@ -842,7 +842,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact """Test episode metadata consistency across multiple episodes.""" features = { "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]}, - "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}, + ACTION: {"dtype": "float32", "shape": (2,), "names": ["v", "w"]}, } dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) @@ -852,7 +852,7 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact for episode_idx in range(num_episodes): for _ in range(frames_per_episode[episode_idx]): - dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": tasks[episode_idx]}) + dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]}) dataset.save_episode() # Load and validate episode metadata @@ -927,7 +927,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) """Test that statistics are properly computed and stored for all features.""" features = { "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]}, - "action": {"dtype": "float32", "shape": (1,), "names": ["force"]}, + ACTION: {"dtype": "float32", "shape": (1,), "names": ["force"]}, } dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) @@ -941,7 +941,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) for frame_idx in range(frames_per_episode[episode_idx]): state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32) action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32) - dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"}) + dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"}) dataset.save_episode() loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root) diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py index 506be3ecf..1bd4c1787 100644 --- a/tests/datasets/test_streaming.py +++ b/tests/datasets/test_streaming.py @@ -19,6 +19,7 @@ import torch from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.utils import safe_shard +from lerobot.utils.constants import ACTION from tests.fixtures.constants import DUMMY_REPO_ID @@ -234,7 +235,7 @@ def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_ delta_timestamps = { camera_key: state_deltas, "state": state_deltas, - "action": action_deltas, + ACTION: action_deltas, } ds = lerobot_dataset_factory( @@ -319,7 +320,7 @@ def test_frames_with_delta_consistency_with_shards( delta_timestamps = { camera_key: state_deltas, "state": state_deltas, - "action": action_deltas, + ACTION: action_deltas, } ds = lerobot_dataset_factory( diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 973c5b050..35d8776ce 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" DUMMY_ROBOT_TYPE = "dummy_robot" DUMMY_MOTOR_FEATURES = { - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 7752ad63f..34fa89390 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -59,7 +59,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p }, } motor_features = { - "action": { + ACTION: { "dtype": "float32", "shape": (6,), "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], @@ -287,7 +287,7 @@ def test_multikey_construction(multikey: bool): ), } output_features = { - "action": PolicyFeature( + ACTION: PolicyFeature( type=FeatureType.ACTION, shape=(5,), ), @@ -304,7 +304,7 @@ def test_multikey_construction(multikey: bool): output_features = {} output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) - output_features["action"] = PolicyFeature( + output_features[ACTION] = PolicyFeature( type=FeatureType.ACTION, shape=(5,), ) diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index 59ed4af65..be6a8d26e 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -25,7 +25,7 @@ from lerobot.policies.sac.configuration_sac import ( PolicyConfig, SACConfig, ) -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def test_sac_config_default_initialization(): @@ -46,7 +46,7 @@ def test_sac_config_default_initialization(): "min": [0.0, 0.0], "max": [1.0, 1.0], }, - "action": { + ACTION: { "min": [0.0, 0.0, 0.0], "max": [1.0, 1.0, 1.0], }, @@ -99,7 +99,7 @@ def test_sac_config_default_initialization(): "min": [0.0, 0.0], "max": [1.0, 1.0], }, - "action": { + ACTION: { "min": [0.0, 0.0, 0.0], "max": [1.0, 1.0, 1.0], }, @@ -193,7 +193,7 @@ def test_sac_config_custom_initialization(): def test_validate_features(): config = SACConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) config.validate_features() @@ -201,7 +201,7 @@ def test_validate_features(): def test_validate_features_missing_observation(): config = SACConfig( input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, ) with pytest.raises( ValueError, match="You must provide either 'observation.state' or an image observation" diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 71e45e055..8576883bd 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -23,7 +23,7 @@ from torch import Tensor, nn from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.modeling_sac import MLP, SACPolicy -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE from lerobot.utils.random_utils import seeded_context, set_seed try: @@ -105,7 +105,7 @@ def create_default_train_batch( batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 ) -> dict[str, Tensor]: return { - "action": create_dummy_action(batch_size, action_dim), + ACTION: create_dummy_action(batch_size, action_dim), "reward": torch.randn(batch_size), "state": create_dummy_state(batch_size, state_dim), "next_state": create_dummy_state(batch_size, state_dim), @@ -117,7 +117,7 @@ def create_train_batch_with_visual_input( batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 ) -> dict[str, Tensor]: return { - "action": create_dummy_action(batch_size, action_dim), + ACTION: create_dummy_action(batch_size, action_dim), "reward": torch.randn(batch_size), "state": create_dummy_with_visual_input(batch_size, state_dim), "next_state": create_dummy_with_visual_input(batch_size, state_dim), @@ -182,13 +182,13 @@ def create_default_config( config = SACConfig( input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, - output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, dataset_stats={ OBS_STATE: { "min": [0.0] * state_dim, "max": [1.0] * state_dim, }, - "action": { + ACTION: { "min": [0.0] * continuous_action_dim, "max": [1.0] * continuous_action_dim, }, diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 8bf24db02..0f7018972 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,7 +2,7 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch -from lerobot.utils.constants import OBS_IMAGE, OBS_PREFIX, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE def _dummy_batch(): @@ -11,7 +11,7 @@ def _dummy_batch(): f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), - "action": torch.tensor([[0.5]]), + ACTION: torch.tensor([[0.5]]), "next.reward": 1.0, "next.done": False, "next.truncated": False, @@ -37,7 +37,7 @@ def test_observation_grouping_roundtrip(): assert torch.allclose(batch_out[OBS_STATE], batch_in[OBS_STATE]) # Check other fields - assert torch.allclose(batch_out["action"], batch_in["action"]) + assert torch.allclose(batch_out[ACTION], batch_in[ACTION]) assert batch_out["next.reward"] == batch_in["next.reward"] assert batch_out["next.done"] == batch_in["next.done"] assert batch_out["next.truncated"] == batch_in["next.truncated"] @@ -50,7 +50,7 @@ def test_batch_to_transition_observation_grouping(): f"{OBS_IMAGE}.top": torch.randn(1, 3, 128, 128), f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), OBS_STATE: [1, 2, 3, 4], - "action": torch.tensor([0.1, 0.2, 0.3, 0.4]), + ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]), "next.reward": 1.5, "next.done": True, "next.truncated": False, @@ -114,7 +114,7 @@ def test_transition_to_batch_observation_flattening(): assert batch[OBS_STATE] == [1, 2, 3, 4] # Check other fields are mapped to next.* format - assert batch["action"] == "action_data" + assert batch[ACTION] == "action_data" assert batch["next.reward"] == 1.5 assert batch["next.done"] assert not batch["next.truncated"] @@ -124,7 +124,7 @@ def test_transition_to_batch_observation_flattening(): def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { - "action": torch.tensor([1.0, 2.0]), + ACTION: torch.tensor([1.0, 2.0]), "next.reward": 2.0, "next.done": False, "next.truncated": True, @@ -145,7 +145,7 @@ def test_no_observation_keys(): # Round trip should work reconstructed_batch = transition_to_batch(transition) - assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0])) + assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0])) assert reconstructed_batch["next.reward"] == 2.0 assert not reconstructed_batch["next.done"] assert reconstructed_batch["next.truncated"] @@ -154,7 +154,7 @@ def test_no_observation_keys(): def test_minimal_batch(): """Test with minimal batch containing only observation.* and action.""" - batch = {OBS_STATE: "minimal_state", "action": torch.tensor([0.5])} + batch = {OBS_STATE: "minimal_state", ACTION: torch.tensor([0.5])} transition = batch_to_transition(batch) @@ -172,7 +172,7 @@ def test_minimal_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch[OBS_STATE] == "minimal_state" - assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5])) + assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5])) assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -196,7 +196,7 @@ def test_empty_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) - assert reconstructed_batch["action"] is None + assert reconstructed_batch[ACTION] is None assert reconstructed_batch["next.reward"] == 0.0 assert not reconstructed_batch["next.done"] assert not reconstructed_batch["next.truncated"] @@ -209,7 +209,7 @@ def test_complex_nested_observation(): f"{OBS_IMAGE}.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, OBS_STATE: torch.randn(7), - "action": torch.randn(8), + ACTION: torch.randn(8), "next.reward": 3.14, "next.done": False, "next.truncated": True, @@ -237,7 +237,7 @@ def test_complex_nested_observation(): ) # Check action tensor - assert torch.allclose(batch["action"], reconstructed_batch["action"]) + assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION]) # Check other fields assert batch["next.reward"] == reconstructed_batch["next.reward"] @@ -266,7 +266,7 @@ def test_custom_converter(): batch = { OBS_STATE: torch.randn(1, 4), - "action": torch.randn(1, 2), + ACTION: torch.randn(1, 2), "next.reward": 1.0, "next.done": False, } @@ -276,4 +276,4 @@ def test_custom_converter(): # Check the reward was doubled by our custom converter assert result["next.reward"] == 2.0 assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) - assert torch.allclose(result["action"], batch["action"]) + assert torch.allclose(result[ACTION], batch[ACTION]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index b03d49214..d347858dc 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,7 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) -from lerobot.utils.constants import OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR # Tests for the unified to_tensor function @@ -118,16 +118,16 @@ def test_to_tensor_dictionaries(): # Nested dictionary nested = { - "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, + ACTION: {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, OBS_STR: {"mean": np.array([0.5, 0.6]), "count": 10}, } result = to_tensor(nested) assert isinstance(result, dict) - assert isinstance(result["action"], dict) + assert isinstance(result[ACTION], dict) assert isinstance(result[OBS_STR], dict) - assert isinstance(result["action"]["mean"], torch.Tensor) + assert isinstance(result[ACTION]["mean"], torch.Tensor) assert isinstance(result[OBS_STR]["mean"], torch.Tensor) - assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result[ACTION]["mean"], torch.tensor([0.1, 0.2])) assert torch.allclose(result[OBS_STR]["mean"], torch.tensor([0.5, 0.6])) @@ -200,7 +200,7 @@ def test_batch_to_transition_with_index_fields(): # Create batch with index and task_index fields batch = { OBS_STATE: torch.randn(1, 7), - "action": torch.randn(1, 4), + ACTION: torch.randn(1, 4), "next.reward": 1.5, "next.done": False, "task": ["pick_cube"], @@ -262,7 +262,7 @@ def test_batch_to_transition_without_index_fields(): # Batch without index/task_index batch = { OBS_STATE: torch.randn(1, 7), - "action": torch.randn(1, 4), + ACTION: torch.randn(1, 4), "task": ["pick_cube"], } diff --git a/tests/processor/test_device_processor.py b/tests/processor/test_device_processor.py index 36081e021..bb7d467bf 100644 --- a/tests/processor/test_device_processor.py +++ b/tests/processor/test_device_processor.py @@ -21,7 +21,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, DeviceProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE def test_basic_functionality(): @@ -273,7 +273,7 @@ def test_features(): features = { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } result = processor.transform_features(features) diff --git a/tests/processor/test_migration_detection.py b/tests/processor/test_migration_detection.py index b46cc6bdd..1ddc87d1e 100644 --- a/tests/processor/test_migration_detection.py +++ b/tests/processor/test_migration_detection.py @@ -25,7 +25,7 @@ from pathlib import Path import pytest from lerobot.processor.pipeline import DataProcessorPipeline, ProcessorMigrationError -from lerobot.utils.constants import OBS_STATE +from lerobot.utils.constants import ACTION, OBS_STATE def test_is_processor_config_valid_configs(): @@ -113,7 +113,7 @@ def test_should_suggest_migration_with_model_config_only(): model_config = { "type": "act", "input_features": {OBS_STATE: {"shape": [7]}}, - "output_features": {"action": {"shape": [7]}}, + "output_features": {ACTION: {"shape": [7]}}, "hidden_dim": 256, "n_obs_steps": 1, "n_action_steps": 1, diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 616f33db9..98c9e0b23 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -29,7 +29,7 @@ from lerobot.processor import ( hotswap_stats, ) from lerobot.processor.converters import create_transition, identity_transition, to_tensor -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR from lerobot.utils.utils import auto_select_torch_device @@ -50,15 +50,15 @@ def test_numpy_conversion(): def test_tensor_conversion(): stats = { - "action": { + ACTION: { "mean": torch.tensor([0.0, 0.0]), "std": torch.tensor([1.0, 1.0]), } } tensor_stats = to_tensor(stats) - assert tensor_stats["action"]["mean"].dtype == torch.float32 - assert tensor_stats["action"]["std"].dtype == torch.float32 + assert tensor_stats[ACTION]["mean"].dtype == torch.float32 + assert tensor_stats[ACTION]["std"].dtype == torch.float32 def test_scalar_conversion(): @@ -212,12 +212,12 @@ def test_from_lerobot_dataset(): mock_dataset = Mock() mock_dataset.meta.stats = { OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0], "std": [1.0]}, + ACTION: {"mean": [0.0], "std": [1.0]}, } features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "action": PolicyFeature(FeatureType.ACTION, (1,)), + ACTION: PolicyFeature(FeatureType.ACTION, (1,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -228,7 +228,7 @@ def test_from_lerobot_dataset(): # Both observation and action statistics should be present in tensor stats assert OBS_IMAGE in normalizer._tensor_stats - assert "action" in normalizer._tensor_stats + assert ACTION in normalizer._tensor_stats def test_state_dict_save_load(observation_normalizer): @@ -271,7 +271,7 @@ def action_stats_min_max(): def _create_action_features(): return { - "action": PolicyFeature(FeatureType.ACTION, (3,)), + ACTION: PolicyFeature(FeatureType.ACTION, (3,)), } @@ -291,7 +291,7 @@ def test_mean_std_unnormalization(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) normalized_action = torch.tensor([1.0, -0.5, 2.0]) @@ -309,7 +309,7 @@ def test_min_max_unnormalization(action_stats_min_max): features = _create_action_features() norm_map = _create_action_norm_map_min_max() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_min_max} + features=features, norm_map=norm_map, stats={ACTION: action_stats_min_max} ) # Actions in [-1, 1] @@ -335,7 +335,7 @@ def test_tensor_action_input(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32) @@ -353,7 +353,7 @@ def test_none_action(action_stats_mean_std): features = _create_action_features() norm_map = _create_action_norm_map_mean_std() unnormalizer = UnnormalizerProcessorStep( - features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + features=features, norm_map=norm_map, stats={ACTION: action_stats_mean_std} ) transition = create_transition() @@ -365,11 +365,11 @@ def test_none_action(action_stats_mean_std): def test_action_from_lerobot_dataset(): mock_dataset = Mock() - mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} - features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} + mock_dataset.meta.stats = {ACTION: {"mean": [0.0], "std": [1.0]}} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (1,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} unnormalizer = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) - assert "mean" in unnormalizer._tensor_stats["action"] + assert "mean" in unnormalizer._tensor_stats[ACTION] # Fixtures for NormalizerProcessorStep tests @@ -384,7 +384,7 @@ def full_stats(): "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, - "action": { + ACTION: { "mean": np.array([0.0, 0.0]), "std": np.array([1.0, 2.0]), }, @@ -395,7 +395,7 @@ def _create_full_features(): return { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } @@ -461,7 +461,7 @@ def test_processor_from_lerobot_dataset(full_stats): assert processor.normalize_observation_keys == {OBS_IMAGE} assert OBS_IMAGE in processor._tensor_stats - assert "action" in processor._tensor_stats + assert ACTION in processor._tensor_stats def test_get_config(full_stats): @@ -482,7 +482,7 @@ def test_get_config(full_stats): "features": { OBS_IMAGE: {"type": "VISUAL", "shape": (3, 96, 96)}, OBS_STATE: {"type": "STATE", "shape": (2,)}, - "action": {"type": "ACTION", "shape": (2,)}, + ACTION: {"type": "ACTION", "shape": (2,)}, }, "norm_map": { "VISUAL": "MEAN_STD", @@ -568,7 +568,7 @@ def test_missing_action_stats_no_error(): processor = UnnormalizerProcessorStep.from_lerobot_dataset(mock_dataset, features, norm_map) # The tensor stats should not contain the 'action' key - assert "action" not in processor._tensor_stats + assert ACTION not in processor._tensor_stats def test_serialization_roundtrip(full_stats): @@ -676,9 +676,9 @@ def test_identity_normalization_observations(): def test_identity_normalization_actions(): """Test that IDENTITY mode skips normalization for actions.""" - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} - stats = {"action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} + stats = {ACTION: {"mean": [0.0, 0.0], "std": [1.0, 2.0]}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -729,9 +729,9 @@ def test_identity_unnormalization_observations(): def test_identity_unnormalization_actions(): """Test that IDENTITY mode skips unnormalization for actions.""" - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.IDENTITY} - stats = {"action": {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} + stats = {ACTION: {"min": [-1.0, -2.0], "max": [1.0, 2.0]}} unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -748,7 +748,7 @@ def test_identity_with_missing_stats(): """Test that IDENTITY mode works even when stats are missing.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -784,7 +784,7 @@ def test_identity_mixed_with_other_modes(): features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -794,7 +794,7 @@ def test_identity_mixed_with_other_modes(): stats = { OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, # Will be ignored OBS_STATE: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, - "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -862,7 +862,7 @@ def test_identity_roundtrip(): """Test that IDENTITY normalization and unnormalization are true inverses.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -870,7 +870,7 @@ def test_identity_roundtrip(): } stats = { OBS_IMAGE: {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, - "action": {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, + ACTION: {"min": [-1.0, -1.0], "max": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -893,7 +893,7 @@ def test_identity_config_serialization(): """Test that IDENTITY mode is properly saved and loaded in config.""" features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.IDENTITY, @@ -901,7 +901,7 @@ def test_identity_config_serialization(): } stats = { OBS_IMAGE: {"mean": [0.5], "std": [0.2]}, - "action": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, + ACTION: {"mean": [0.0, 0.0], "std": [1.0, 1.0]}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -969,19 +969,19 @@ def test_hotswap_stats_basic_functionality(): # Create initial stats initial_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create new stats for hotswapping new_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create features and norm_map features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1177,17 +1177,17 @@ def test_hotswap_stats_multiple_normalizer_types(): """Test hotswap_stats with multiple normalizer and unnormalizer steps.""" initial_stats = { OBS_IMAGE: {"mean": np.array([0.5]), "std": np.array([0.2])}, - "action": {"min": np.array([-1.0]), "max": np.array([1.0])}, + ACTION: {"min": np.array([-1.0]), "max": np.array([1.0])}, } new_stats = { OBS_IMAGE: {"mean": np.array([0.3]), "std": np.array([0.1])}, - "action": {"min": np.array([-2.0]), "max": np.array([2.0])}, + ACTION: {"min": np.array([-2.0]), "max": np.array([2.0])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(1,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(1,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1232,7 +1232,7 @@ def test_hotswap_stats_with_different_data_types(): "min": 0, # int "max": 1.0, # float }, - "action": { + ACTION: { "mean": np.array([0.1, 0.2]), # numpy array "std": torch.tensor([0.5, 0.6]), # torch tensor }, @@ -1240,7 +1240,7 @@ def test_hotswap_stats_with_different_data_types(): features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1262,8 +1262,8 @@ def test_hotswap_stats_with_different_data_types(): assert isinstance(tensor_stats[OBS_IMAGE]["std"], torch.Tensor) assert isinstance(tensor_stats[OBS_IMAGE]["min"], torch.Tensor) assert isinstance(tensor_stats[OBS_IMAGE]["max"], torch.Tensor) - assert isinstance(tensor_stats["action"]["mean"], torch.Tensor) - assert isinstance(tensor_stats["action"]["std"], torch.Tensor) + assert isinstance(tensor_stats[ACTION]["mean"], torch.Tensor) + assert isinstance(tensor_stats[ACTION]["std"], torch.Tensor) # Check values torch.testing.assert_close(tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.4, 0.5])) @@ -1284,18 +1284,18 @@ def test_hotswap_stats_functional_test(): # Initial stats initial_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.4]), "std": np.array([0.2, 0.3])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # New stats new_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.2]), "std": np.array([0.1, 0.2])}, - "action": {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, -0.1]), "std": np.array([0.5, 0.5])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(2, 2, 2)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1324,18 +1324,18 @@ def test_hotswap_stats_functional_test(): rtol=1e-3, atol=1e-3, ) - assert not torch.allclose(original_result["action"], new_result["action"], rtol=1e-3, atol=1e-3) + assert not torch.allclose(original_result[ACTION], new_result[ACTION], rtol=1e-3, atol=1e-3) # Verify that the new processor is actually using the new stats by checking internal state assert new_processor.steps[0].stats == new_stats assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["mean"], torch.tensor([0.3, 0.2])) assert torch.allclose(new_processor.steps[0]._tensor_stats[OBS_IMAGE]["std"], torch.tensor([0.1, 0.2])) - assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["mean"], torch.tensor([0.1, -0.1])) - assert torch.allclose(new_processor.steps[0]._tensor_stats["action"]["std"], torch.tensor([0.5, 0.5])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["mean"], torch.tensor([0.1, -0.1])) + assert torch.allclose(new_processor.steps[0]._tensor_stats[ACTION]["std"], torch.tensor([0.5, 0.5])) # Test that normalization actually happens (output should not equal input) assert not torch.allclose(new_result[OBS_STR][OBS_IMAGE], observation[OBS_IMAGE]) - assert not torch.allclose(new_result["action"], action) + assert not torch.allclose(new_result[ACTION], action) def test_zero_std_uses_eps(): @@ -1366,10 +1366,10 @@ def test_action_normalized_despite_normalize_observation_keys(): """Action normalization is independent of normalize_observation_keys filter for observations.""" features = { OBS_STATE: PolicyFeature(FeatureType.STATE, (1,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.IDENTITY, FeatureType.ACTION: NormalizationMode.MEAN_STD} - stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep( features=features, norm_map=norm_map, stats=stats, normalize_observation_keys={OBS_STATE} ) @@ -1426,9 +1426,9 @@ def test_unknown_observation_keys_ignored(): def test_batched_action_normalization(): - features = {"action": PolicyFeature(FeatureType.ACTION, (2,))} + features = {ACTION: PolicyFeature(FeatureType.ACTION, (2,))} norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} - stats = {"action": {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} + stats = {ACTION: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}} normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) actions = torch.tensor([[1.0, -1.0], [3.0, 3.0]]) # first equals mean → zeros; second → [1, 1] @@ -1453,12 +1453,12 @@ def test_complementary_data_preservation(): def test_roundtrip_normalize_unnormalize_non_identity(): features = { OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MIN_MAX} stats = { OBS_STATE: {"mean": np.array([1.0, -1.0]), "std": np.array([2.0, 4.0])}, - "action": {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, + ACTION: {"min": np.array([-2.0, 0.0]), "max": np.array([2.0, 4.0])}, } normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) @@ -1530,18 +1530,18 @@ def test_stats_override_preservation_in_load_state_dict(): # Create original stats original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } # Create override stats (what user wants to use) override_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1601,12 +1601,12 @@ def test_stats_without_override_loads_normally(): """ original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1674,7 +1674,7 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # Create test data features = { OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 32, 32)), - "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1683,12 +1683,12 @@ def test_pipeline_from_pretrained_with_stats_overrides(): original_stats = { OBS_IMAGE: {"mean": np.array([0.5, 0.5, 0.5]), "std": np.array([0.2, 0.2, 0.2])}, - "action": {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, + ACTION: {"mean": np.array([0.0, 0.0]), "std": np.array([1.0, 1.0])}, } override_stats = { OBS_IMAGE: {"mean": np.array([0.3, 0.3, 0.3]), "std": np.array([0.1, 0.1, 0.1])}, - "action": {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, + ACTION: {"mean": np.array([0.1, 0.1]), "std": np.array([0.5, 0.5])}, } # Create and save a pipeline with the original stats @@ -1751,8 +1751,8 @@ def test_pipeline_from_pretrained_with_stats_overrides(): # The critical part was verified above: loaded_normalizer.stats == override_stats # This confirms that override stats are preserved during load_state_dict. # Let's just verify the pipeline processes data successfully. - assert "action" in override_result - assert isinstance(override_result["action"], torch.Tensor) + assert ACTION in override_result + assert isinstance(override_result[ACTION], torch.Tensor) def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32(): @@ -1812,7 +1812,7 @@ def test_stats_reconstruction_after_load_state_dict(): features = { OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), OBS_STATE: PolicyFeature(FeatureType.STATE, (2,)), - "action": PolicyFeature(FeatureType.ACTION, (2,)), + ACTION: PolicyFeature(FeatureType.ACTION, (2,)), } norm_map = { FeatureType.VISUAL: NormalizationMode.MEAN_STD, @@ -1828,7 +1828,7 @@ def test_stats_reconstruction_after_load_state_dict(): "min": np.array([0.0, -1.0]), "max": np.array([1.0, 1.0]), }, - "action": { + ACTION: { "mean": np.array([0.0, 0.0]), "std": np.array([1.0, 2.0]), }, @@ -1852,15 +1852,15 @@ def test_stats_reconstruction_after_load_state_dict(): # Check that all expected keys are present assert OBS_IMAGE in new_normalizer.stats assert OBS_STATE in new_normalizer.stats - assert "action" in new_normalizer.stats + assert ACTION in new_normalizer.stats # Check that values are correct (converted back from tensors) np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["mean"], [0.5, 0.5, 0.5]) np.testing.assert_allclose(new_normalizer.stats[OBS_IMAGE]["std"], [0.2, 0.2, 0.2]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["min"], [0.0, -1.0]) np.testing.assert_allclose(new_normalizer.stats[OBS_STATE]["max"], [1.0, 1.0]) - np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0]) - np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0]) + np.testing.assert_allclose(new_normalizer.stats[ACTION]["mean"], [0.0, 0.0]) + np.testing.assert_allclose(new_normalizer.stats[ACTION]["std"], [1.0, 2.0]) # Test that methods that depend on self.stats work correctly after loading # This would fail before the bug fix because self.stats was empty @@ -1876,7 +1876,7 @@ def test_stats_reconstruction_after_load_state_dict(): new_stats = { OBS_IMAGE: {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]}, OBS_STATE: {"min": [-1.0, -2.0], "max": [2.0, 2.0]}, - "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, + ACTION: {"mean": [0.1, 0.1], "std": [0.5, 0.5]}, } pipeline = DataProcessorPipeline([new_normalizer]) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 6d056e4dc..6dbf37450 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,7 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -257,7 +257,7 @@ def test_step_through_with_dict(): batch = { OBS_IMAGE: None, - "action": None, + ACTION: None, "next.reward": 0.0, "next.done": False, "next.truncated": False, @@ -1842,7 +1842,7 @@ def test_save_load_with_custom_converter_functions(): # Verify it uses default converters by checking with standard batch format batch = { OBS_IMAGE: torch.randn(1, 3, 32, 32), - "action": torch.randn(1, 7), + ACTION: torch.randn(1, 7), "next.reward": torch.tensor([1.0]), "next.done": torch.tensor([False]), "next.truncated": torch.tensor([False]), @@ -2094,11 +2094,11 @@ def test_aggregate_joint_action_only(): patterns=["action.j1.pos", "action.j2.pos"], ) - # Expect only "action" with joint names - assert "action" in out and OBS_STATE not in out - assert out["action"]["dtype"] == "float32" - assert set(out["action"]["names"]) == {"j1.pos", "j2.pos"} - assert out["action"]["shape"] == (len(out["action"]["names"]),) + # Expect only ACTION with joint names + assert ACTION in out and OBS_STATE not in out + assert out[ACTION]["dtype"] == "float32" + assert set(out[ACTION]["names"]) == {"j1.pos", "j2.pos"} + assert out[ACTION]["shape"] == (len(out[ACTION]["names"]),) def test_aggregate_ee_action_and_observation_with_videos(): @@ -2113,9 +2113,9 @@ def test_aggregate_ee_action_and_observation_with_videos(): ) # Action should pack only EE names - assert "action" in out - assert set(out["action"]["names"]) == {"ee.x", "ee.y"} - assert out["action"]["dtype"] == "float32" + assert ACTION in out + assert set(out[ACTION]["names"]) == {"ee.x", "ee.y"} + assert out[ACTION]["dtype"] == "float32" # Observation state should pack both ee.x and j1.pos as a vector assert OBS_STATE in out @@ -2140,10 +2140,10 @@ def test_aggregate_both_action_types(): patterns=["action.ee", "action.j1", "action.j2.pos"], ) - assert "action" in out + assert ACTION in out expected = {"ee.x", "ee.y", "j1.pos", "j2.pos"} - assert set(out["action"]["names"]) == expected - assert out["action"]["shape"] == (len(expected),) + assert set(out[ACTION]["names"]) == expected + assert out[ACTION]["shape"] == (len(expected),) def test_aggregate_images_when_use_videos_false(): diff --git a/tests/processor/test_policy_robot_bridge.py b/tests/processor/test_policy_robot_bridge.py index f3bbd9a74..6269c508f 100644 --- a/tests/processor/test_policy_robot_bridge.py +++ b/tests/processor/test_policy_robot_bridge.py @@ -28,6 +28,7 @@ from lerobot.processor import ( RobotActionToPolicyActionProcessorStep, ) from lerobot.processor.converters import identity_transition +from lerobot.utils.constants import ACTION from tests.conftest import assert_contract_is_typed @@ -134,8 +135,8 @@ def test_robot_to_policy_transform_features(): transformed = processor.transform_features(features) - assert "action" in transformed[PipelineFeatureType.ACTION] - action_feature = transformed[PipelineFeatureType.ACTION]["action"] + assert ACTION in transformed[PipelineFeatureType.ACTION] + action_feature = transformed[PipelineFeatureType.ACTION][ACTION] assert action_feature.type == FeatureType.ACTION assert action_feature.shape == (3,) @@ -251,7 +252,7 @@ def test_policy_to_robot_transform_features(): features = { PipelineFeatureType.ACTION: { - "action": {"type": FeatureType.ACTION, "shape": (2,)}, + ACTION: {"type": FeatureType.ACTION, "shape": (2,)}, "other_data": {"type": FeatureType.ENV, "shape": (1,)}, } } @@ -266,7 +267,7 @@ def test_policy_to_robot_transform_features(): assert motor_feature.type == FeatureType.ACTION assert motor_feature.shape == (1,) - assert "action" in transformed[PipelineFeatureType.ACTION] + assert ACTION in transformed[PipelineFeatureType.ACTION] assert "other_data" in transformed[PipelineFeatureType.ACTION] @@ -447,8 +448,8 @@ def test_robot_to_policy_features_contract(policy_feature_factory): assert_contract_is_typed(out) - assert "action" in out[PipelineFeatureType.ACTION] - action_feature = out[PipelineFeatureType.ACTION]["action"] + assert ACTION in out[PipelineFeatureType.ACTION] + action_feature = out[PipelineFeatureType.ACTION][ACTION] assert action_feature.type == FeatureType.ACTION assert action_feature.shape == (2,) @@ -458,7 +459,7 @@ def test_policy_to_robot_features_contract(policy_feature_factory): processor = PolicyActionToRobotActionProcessorStep(motor_names=["m1", "m2", "m3"]) features = { PipelineFeatureType.ACTION: { - "action": policy_feature_factory(FeatureType.ACTION, (3,)), + ACTION: policy_feature_factory(FeatureType.ACTION, (3,)), "other": policy_feature_factory(FeatureType.ENV, (1,)), } } diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py index c6aa303f1..efb9f9328 100644 --- a/tests/processor/test_rename_processor.py +++ b/tests/processor/test_rename_processor.py @@ -28,7 +28,7 @@ from lerobot.processor import ( ) from lerobot.processor.converters import create_transition, identity_transition from lerobot.processor.rename_processor import rename_stats -from lerobot.utils.constants import OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE from tests.conftest import assert_contract_is_typed @@ -488,7 +488,7 @@ def test_features_chained_processors(policy_feature_factory): def test_rename_stats_basic(): orig = { OBS_STATE: {"mean": np.array([0.0]), "std": np.array([1.0])}, - "action": {"mean": np.array([0.0])}, + ACTION: {"mean": np.array([0.0])}, } mapping = {OBS_STATE: "observation.robot_state"} renamed = rename_stats(orig, mapping) diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 35bbcfd8a..503f2e036 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -11,7 +11,7 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE from tests.utils import require_package @@ -504,14 +504,14 @@ def test_features_basic(): input_features = { PipelineFeatureType.OBSERVATION: {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))}, - PipelineFeatureType.ACTION: {"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, + PipelineFeatureType.ACTION: {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,))}, } output_features = processor.transform_features(input_features) # Check that original features are preserved assert OBS_STATE in output_features[PipelineFeatureType.OBSERVATION] - assert "action" in output_features[PipelineFeatureType.ACTION] + assert ACTION in output_features[PipelineFeatureType.ACTION] # Check that tokenized features are added assert f"{OBS_LANGUAGE}.tokens" in output_features[PipelineFeatureType.OBSERVATION] diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 79edad4e4..52825a24e 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -21,6 +21,7 @@ from pickle import UnpicklingError import pytest import torch +from lerobot.utils.constants import ACTION from lerobot.utils.transition import Transition from tests.utils import require_cuda, require_package @@ -512,7 +513,7 @@ def test_transitions_to_bytes_single_transition(): def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) - assert torch.allclose(t1["action"], t2["action"]) + assert torch.allclose(t1[ACTION], t2[ACTION]) assert torch.allclose(t1["reward"], t2["reward"]) assert torch.equal(t1["done"], t2["done"]) assert_observation_equal(t1["next_state"], t2["next_state"]) diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 6820d321f..1e6c0df95 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,7 +22,7 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from lerobot.utils.constants import OBS_IMAGE, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR from tests.fixtures.constants import DUMMY_REPO_ID @@ -63,7 +63,7 @@ def create_random_image() -> torch.Tensor: def create_dummy_transition() -> dict: return { OBS_IMAGE: create_random_image(), - "action": torch.randn(4), + ACTION: torch.randn(4), "reward": torch.tensor(1.0), OBS_STATE: torch.randn( 10, @@ -341,7 +341,7 @@ def test_sample_batch(replay_buffer): f"{k} should be equal to one of the dummy states." ) - for got_action_item in got_batch_transition["action"]: + for got_action_item in got_batch_transition[ACTION]: assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( "Actions should be equal to the dummy actions." ) @@ -378,7 +378,7 @@ def test_to_lerobot_dataset(tmp_path): for i in range(len(ds)): for feature, value in ds[i].items(): - if feature == "action": + if feature == ACTION: assert torch.equal(value, buffer.actions[i]) elif feature == "next.reward": assert torch.equal(value, buffer.rewards[i]) @@ -495,7 +495,7 @@ def test_buffer_sample_alignment(): for i in range(50): state_sig = batch["state"]["state_value"][i].item() - action_val = batch["action"][i].item() + action_val = batch[ACTION][i].item() reward_val = batch["reward"][i].item() next_state_sig = batch["next_state"]["state_value"][i].item() is_done = batch["done"][i].item() > 0.5 From ec40ccde0d892d6ada167f371e93eef997b2e669 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 26 Sep 2025 14:28:58 +0200 Subject: [PATCH 130/158] Bug in conversion from v2.1 script (#2057) * False logic in setting the dataset to index in the meta data when converting from v2.1' * Improved logging --- .../datasets/v30/convert_dataset_v21_to_v30.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index e5a6e3c9a..ac9d41cf7 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -34,6 +34,7 @@ python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ """ import argparse +import logging import shutil from pathlib import Path from typing import Any @@ -71,6 +72,7 @@ from lerobot.datasets.utils import ( ) from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.utils import init_logging V21 = "v2.1" @@ -144,6 +146,7 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]: def convert_tasks(root, new_root): + logging.info(f"Converting tasks from {root} to {new_root}") tasks, _ = legacy_load_tasks(root) task_indices = tasks.keys() task_strings = tasks.values() @@ -185,7 +188,10 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): num_frames = 0 paths_to_cat = [] episodes_metadata = [] - for ep_path in ep_paths: + + logging.info(f"Converting data files from {len(ep_paths)} episodes") + + for ep_path in tqdm.tqdm(ep_paths, desc="convert data files"): ep_size_in_mb = get_parquet_file_size_in_mb(ep_path) ep_num_frames = get_parquet_num_frames(ep_path) ep_metadata = { @@ -209,7 +215,6 @@ def convert_data(root: Path, new_root: Path, data_file_size_in_mb: int): # Reset for the next file size_in_mb = ep_size_in_mb - num_frames = ep_num_frames paths_to_cat = [ep_path] chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) @@ -236,6 +241,8 @@ def get_image_keys(root): def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int): + logging.info(f"Converting videos from {root} to {new_root}") + video_keys = get_video_keys(root) if len(video_keys) == 0: return None @@ -254,7 +261,7 @@ def convert_videos(root: Path, new_root: Path, video_file_size_in_mb: int): episods_metadata = [] num_cameras = len(video_keys) num_episodes = num_eps_per_cam[0] - for ep_idx in range(num_episodes): + for ep_idx in tqdm.tqdm(range(num_episodes), desc="convert videos"): # Sanity check ep_ids = [eps_metadata_per_cam[cam_idx][ep_idx]["episode_index"] for cam_idx in range(num_cameras)] ep_ids += [ep_idx] @@ -281,6 +288,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f duration_in_s = 0.0 paths_to_cat = [] episodes_metadata = [] + for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"): ep_size_in_mb = get_video_size_in_mb(ep_path) ep_duration_in_s = get_video_duration_in_s(ep_path) @@ -374,6 +382,8 @@ def generate_episode_metadata_dict( def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None): + logging.info(f"Converting episodes metadata from {root} to {new_root}") + episodes_legacy_metadata = legacy_load_episodes(root) episodes_stats = legacy_load_episodes_stats(root) @@ -405,6 +415,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb): info["data_path"] = DEFAULT_DATA_PATH info["video_path"] = DEFAULT_VIDEO_PATH info["fps"] = int(info["fps"]) + logging.info(f"Converting info from {root} to {new_root}") for key in info["features"]: if info["features"][key]["dtype"] == "video": # already has fps in video_info @@ -469,6 +480,7 @@ def convert_dataset( if __name__ == "__main__": + init_logging() parser = argparse.ArgumentParser() parser.add_argument( "--repo-id", From c5b5955c5acf15e1b3f0ace6ae72612f98c1fe06 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 26 Sep 2025 14:30:07 +0200 Subject: [PATCH 131/158] chore: replace hard-coded next values with constants throughout all the source code (#2056) --- src/lerobot/datasets/factory.py | 4 +- src/lerobot/processor/converters.py | 14 ++-- src/lerobot/rl/buffer.py | 16 ++--- src/lerobot/rl/crop_dataset_roi.py | 3 +- src/lerobot/rl/gym_manipulator.py | 10 +-- src/lerobot/scripts/lerobot_dataset_viz.py | 10 +-- src/lerobot/scripts/lerobot_eval.py | 6 +- tests/datasets/test_datasets.py | 6 +- .../hilserl/test_modeling_classifier.py | 10 +-- tests/processor/test_batch_conversion.py | 68 +++++++++---------- tests/processor/test_converters.py | 6 +- tests/processor/test_pipeline.py | 14 ++-- tests/utils/test_replay_buffer.py | 6 +- 13 files changed, 87 insertions(+), 86 deletions(-) diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index f74b6ac4f..f3ceb2b0c 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset from lerobot.datasets.transforms import ImageTransforms -from lerobot.utils.constants import ACTION, OBS_PREFIX +from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD IMAGENET_STATS = { "mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1) @@ -55,7 +55,7 @@ def resolve_delta_timestamps( """ delta_timestamps = {} for key in ds_meta.features: - if key == "next.reward" and cfg.reward_delta_indices is not None: + if key == REWARD and cfg.reward_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices] if key == ACTION and cfg.action_delta_indices is not None: delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices] diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 68f9dd6fa..6b0b67598 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,7 +23,7 @@ from typing import Any import numpy as np import torch -from lerobot.utils.constants import ACTION, OBS_PREFIX +from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -355,9 +355,9 @@ def batch_to_transition(batch: dict[str, Any]) -> EnvTransition: return create_transition( observation=observation_keys if observation_keys else None, action=batch.get(ACTION), - reward=batch.get("next.reward", 0.0), - done=batch.get("next.done", False), - truncated=batch.get("next.truncated", False), + reward=batch.get(REWARD, 0.0), + done=batch.get(DONE, False), + truncated=batch.get(TRUNCATED, False), info=batch.get("info", {}), complementary_data=complementary_data if complementary_data else None, ) @@ -380,9 +380,9 @@ def transition_to_batch(transition: EnvTransition) -> dict[str, Any]: batch = { ACTION: transition.get(TransitionKey.ACTION), - "next.reward": transition.get(TransitionKey.REWARD, 0.0), - "next.done": transition.get(TransitionKey.DONE, False), - "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + REWARD: transition.get(TransitionKey.REWARD, 0.0), + DONE: transition.get(TransitionKey.DONE, False), + TRUNCATED: transition.get(TransitionKey.TRUNCATED, False), "info": transition.get(TransitionKey.INFO, {}), } diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index b572bbce5..d30b65082 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -24,7 +24,7 @@ import torch.nn.functional as F # noqa: N812 from tqdm import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import ACTION, OBS_IMAGE +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, REWARD from lerobot.utils.transition import Transition @@ -534,8 +534,8 @@ class ReplayBuffer: features[ACTION] = act_info # Add "reward" and "done" - features["next.reward"] = {"dtype": "float32", "shape": (1,)} - features["next.done"] = {"dtype": "bool", "shape": (1,)} + features[REWARD] = {"dtype": "float32", "shape": (1,)} + features[DONE] = {"dtype": "bool", "shape": (1,)} # Add state keys for key in self.states: @@ -578,8 +578,8 @@ class ReplayBuffer: # Fill action, reward, done frame_dict[ACTION] = self.actions[actual_idx].cpu() - frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() - frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() + frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() frame_dict["task"] = task_name # Add complementary_info if available @@ -648,7 +648,7 @@ class ReplayBuffer: # Check if the dataset has "next.done" key sample = dataset[0] - has_done_key = "next.done" in sample + has_done_key = DONE in sample # Check for complementary_info keys complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] @@ -671,11 +671,11 @@ class ReplayBuffer: action = current_sample[ACTION].unsqueeze(0) # Add batch dimension # ----- 3) Reward and done ----- - reward = float(current_sample["next.reward"].item()) # ensure float + reward = float(current_sample[REWARD].item()) # ensure float # Determine done flag - use next.done if available, otherwise infer from episode boundaries if has_done_key: - done = bool(current_sample["next.done"].item()) # ensure bool + done = bool(current_sample[DONE].item()) # ensure bool else: # If this is the last frame or if next frame is in a different episode, mark as done done = False diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index c4318c415..281069e14 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -25,6 +25,7 @@ import torchvision.transforms.functional as F # type: ignore # noqa: N812 from tqdm import tqdm # type: ignore from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import DONE, REWARD def select_rect_roi(img): @@ -212,7 +213,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( for key, value in frame.items(): if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"): continue - if key in ("next.done", "next.reward"): + if key in (DONE, REWARD): # if not isinstance(value, str) and len(value.shape) == 0: value = value.unsqueeze(0) diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index fa9f4e3e1..ad36f1b36 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -73,7 +73,7 @@ from lerobot.teleoperators import ( ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import log_say @@ -602,8 +602,8 @@ def control_loop( action_features = teleop_device.action_features features = { ACTION: action_features, - "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, - "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + REWARD: {"dtype": "float32", "shape": (1,), "names": None}, + DONE: {"dtype": "bool", "shape": (1,), "names": None}, } if use_gripper: features["complementary_info.discrete_penalty"] = { @@ -673,8 +673,8 @@ def control_loop( frame = { **observations, ACTION: action_to_record.cpu(), - "next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32), - "next.done": np.array([terminated or truncated], dtype=bool), + REWARD: np.array([transition[TransitionKey.REWARD]], dtype=np.float32), + DONE: np.array([terminated or truncated], dtype=bool), } if use_gripper: discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index adff5c085..55708d9a9 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -75,7 +75,7 @@ import torch.utils.data import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.utils.constants import ACTION, OBS_STATE +from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD class EpisodeSampler(torch.utils.data.Sampler): @@ -166,11 +166,11 @@ def visualize_dataset( for dim_idx, val in enumerate(batch[OBS_STATE][i]): rr.log(f"state/{dim_idx}", rr.Scalar(val.item())) - if "next.done" in batch: - rr.log("next.done", rr.Scalar(batch["next.done"][i].item())) + if DONE in batch: + rr.log(DONE, rr.Scalar(batch[DONE][i].item())) - if "next.reward" in batch: - rr.log("next.reward", rr.Scalar(batch["next.reward"][i].item())) + if REWARD in batch: + rr.log(REWARD, rr.Scalar(batch[REWARD][i].item())) if "next.success" in batch: rr.log("next.success", rr.Scalar(batch["next.success"][i].item())) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 882aeacc3..d45be5c42 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -81,7 +81,7 @@ from lerobot.envs.utils import ( from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import PolicyAction, PolicyProcessorPipeline -from lerobot.utils.constants import ACTION, OBS_STR +from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD from lerobot.utils.io_utils import write_video from lerobot.utils.random_utils import set_seed from lerobot.utils.utils import ( @@ -451,9 +451,9 @@ def _compile_episode_data( "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, - "next.done": rollout_data["done"][ep_ix, : num_frames - 1], + DONE: rollout_data["done"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1], - "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), + REWARD: rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), } # For the last observation frame, all other keys will just be copy padded. diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index fcfef677b..b9e966fe6 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -46,7 +46,7 @@ from lerobot.datasets.utils import ( from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.mocks.mock_robot import MockRobotConfig from tests.utils import require_x86_64_kernel @@ -399,8 +399,8 @@ def test_factory(env_name, repo_id, policy_name): ("timestamp", 0, True), # TODO(rcadene): should we rename it agent_pos? (OBS_STATE, 1, True), - ("next.reward", 0, False), - ("next.done", 0, False), + (REWARD, 0, False), + (DONE, 0, False), ] # test number of dimensions diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index 7a8782230..a572ea9e1 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -19,7 +19,7 @@ import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput -from lerobot.utils.constants import OBS_IMAGE +from lerobot.utils.constants import OBS_IMAGE, REWARD from tests.utils import require_package @@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params(): OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { - "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), + REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)), } config.normalization_mapping = { "VISUAL": NormalizationMode.IDENTITY, @@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params(): input = { OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), - "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), + REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(), } images, labels = classifier.extract_images_and_labels(input) @@ -87,7 +87,7 @@ def test_multiclass_classifier(): OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), } config.output_features = { - "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), + REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), } config.num_cameras = 1 config.num_classes = num_classes @@ -97,7 +97,7 @@ def test_multiclass_classifier(): input = { OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)), - "next.reward": torch.rand((batch_size, num_classes)), + REWARD: torch.rand((batch_size, num_classes)), } images, labels = classifier.extract_images_and_labels(input) diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 0f7018972..88b873128 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -2,7 +2,7 @@ import torch from lerobot.processor import DataProcessorPipeline, TransitionKey from lerobot.processor.converters import batch_to_transition, transition_to_batch -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_PREFIX, OBS_STATE +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_PREFIX, OBS_STATE, REWARD, TRUNCATED def _dummy_batch(): @@ -12,9 +12,9 @@ def _dummy_batch(): f"{OBS_IMAGE}.right": torch.randn(1, 3, 128, 128), OBS_STATE: torch.tensor([[0.1, 0.2, 0.3, 0.4]]), ACTION: torch.tensor([[0.5]]), - "next.reward": 1.0, - "next.done": False, - "next.truncated": False, + REWARD: 1.0, + DONE: False, + TRUNCATED: False, "info": {"key": "value"}, } @@ -38,9 +38,9 @@ def test_observation_grouping_roundtrip(): # Check other fields assert torch.allclose(batch_out[ACTION], batch_in[ACTION]) - assert batch_out["next.reward"] == batch_in["next.reward"] - assert batch_out["next.done"] == batch_in["next.done"] - assert batch_out["next.truncated"] == batch_in["next.truncated"] + assert batch_out[REWARD] == batch_in[REWARD] + assert batch_out[DONE] == batch_in[DONE] + assert batch_out[TRUNCATED] == batch_in[TRUNCATED] assert batch_out["info"] == batch_in["info"] @@ -51,9 +51,9 @@ def test_batch_to_transition_observation_grouping(): f"{OBS_IMAGE}.left": torch.randn(1, 3, 128, 128), OBS_STATE: [1, 2, 3, 4], ACTION: torch.tensor([0.1, 0.2, 0.3, 0.4]), - "next.reward": 1.5, - "next.done": True, - "next.truncated": False, + REWARD: 1.5, + DONE: True, + TRUNCATED: False, "info": {"episode": 42}, } @@ -115,9 +115,9 @@ def test_transition_to_batch_observation_flattening(): # Check other fields are mapped to next.* format assert batch[ACTION] == "action_data" - assert batch["next.reward"] == 1.5 - assert batch["next.done"] - assert not batch["next.truncated"] + assert batch[REWARD] == 1.5 + assert batch[DONE] + assert not batch[TRUNCATED] assert batch["info"] == {"episode": 42} @@ -125,9 +125,9 @@ def test_no_observation_keys(): """Test behavior when there are no observation.* keys.""" batch = { ACTION: torch.tensor([1.0, 2.0]), - "next.reward": 2.0, - "next.done": False, - "next.truncated": True, + REWARD: 2.0, + DONE: False, + TRUNCATED: True, "info": {"test": "no_obs"}, } @@ -146,9 +146,9 @@ def test_no_observation_keys(): # Round trip should work reconstructed_batch = transition_to_batch(transition) assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([1.0, 2.0])) - assert reconstructed_batch["next.reward"] == 2.0 - assert not reconstructed_batch["next.done"] - assert reconstructed_batch["next.truncated"] + assert reconstructed_batch[REWARD] == 2.0 + assert not reconstructed_batch[DONE] + assert reconstructed_batch[TRUNCATED] assert reconstructed_batch["info"] == {"test": "no_obs"} @@ -173,9 +173,9 @@ def test_minimal_batch(): reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch[OBS_STATE] == "minimal_state" assert torch.allclose(reconstructed_batch[ACTION], torch.tensor([0.5])) - assert reconstructed_batch["next.reward"] == 0.0 - assert not reconstructed_batch["next.done"] - assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch[REWARD] == 0.0 + assert not reconstructed_batch[DONE] + assert not reconstructed_batch[TRUNCATED] assert reconstructed_batch["info"] == {} @@ -197,9 +197,9 @@ def test_empty_batch(): # Round trip reconstructed_batch = transition_to_batch(transition) assert reconstructed_batch[ACTION] is None - assert reconstructed_batch["next.reward"] == 0.0 - assert not reconstructed_batch["next.done"] - assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch[REWARD] == 0.0 + assert not reconstructed_batch[DONE] + assert not reconstructed_batch[TRUNCATED] assert reconstructed_batch["info"] == {} @@ -210,9 +210,9 @@ def test_complex_nested_observation(): f"{OBS_IMAGE}.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, OBS_STATE: torch.randn(7), ACTION: torch.randn(8), - "next.reward": 3.14, - "next.done": False, - "next.truncated": True, + REWARD: 3.14, + DONE: False, + TRUNCATED: True, "info": {"episode_length": 200, "success": True}, } @@ -240,9 +240,9 @@ def test_complex_nested_observation(): assert torch.allclose(batch[ACTION], reconstructed_batch[ACTION]) # Check other fields - assert batch["next.reward"] == reconstructed_batch["next.reward"] - assert batch["next.done"] == reconstructed_batch["next.done"] - assert batch["next.truncated"] == reconstructed_batch["next.truncated"] + assert batch[REWARD] == reconstructed_batch[REWARD] + assert batch[DONE] == reconstructed_batch[DONE] + assert batch[TRUNCATED] == reconstructed_batch[TRUNCATED] assert batch["info"] == reconstructed_batch["info"] @@ -267,13 +267,13 @@ def test_custom_converter(): batch = { OBS_STATE: torch.randn(1, 4), ACTION: torch.randn(1, 2), - "next.reward": 1.0, - "next.done": False, + REWARD: 1.0, + DONE: False, } result = processor(batch) # Check the reward was doubled by our custom converter - assert result["next.reward"] == 2.0 + assert result[REWARD] == 2.0 assert torch.allclose(result[OBS_STATE], batch[OBS_STATE]) assert torch.allclose(result[ACTION], batch[ACTION]) diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index d347858dc..bc58f7a61 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -9,7 +9,7 @@ from lerobot.processor.converters import ( to_tensor, transition_to_batch, ) -from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, DONE, OBS_STATE, OBS_STR, REWARD # Tests for the unified to_tensor function @@ -201,8 +201,8 @@ def test_batch_to_transition_with_index_fields(): batch = { OBS_STATE: torch.randn(1, 7), ACTION: torch.randn(1, 4), - "next.reward": 1.5, - "next.done": False, + REWARD: 1.5, + DONE: False, "task": ["pick_cube"], "index": torch.tensor([42], dtype=torch.int64), "task_index": torch.tensor([3], dtype=torch.int64), diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 6dbf37450..904fd6fc1 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -35,7 +35,7 @@ from lerobot.processor import ( TransitionKey, ) from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, REWARD, TRUNCATED from tests.conftest import assert_contract_is_typed @@ -258,9 +258,9 @@ def test_step_through_with_dict(): batch = { OBS_IMAGE: None, ACTION: None, - "next.reward": 0.0, - "next.done": False, - "next.truncated": False, + REWARD: 0.0, + DONE: False, + TRUNCATED: False, "info": {}, } @@ -1843,9 +1843,9 @@ def test_save_load_with_custom_converter_functions(): batch = { OBS_IMAGE: torch.randn(1, 3, 32, 32), ACTION: torch.randn(1, 7), - "next.reward": torch.tensor([1.0]), - "next.done": torch.tensor([False]), - "next.truncated": torch.tensor([False]), + REWARD: torch.tensor([1.0]), + DONE: torch.tensor([False]), + TRUNCATED: torch.tensor([False]), "info": {}, } diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 1e6c0df95..ddf0771f1 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -22,7 +22,7 @@ import torch from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE, OBS_STR +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD from tests.fixtures.constants import DUMMY_REPO_ID @@ -380,9 +380,9 @@ def test_to_lerobot_dataset(tmp_path): for feature, value in ds[i].items(): if feature == ACTION: assert torch.equal(value, buffer.actions[i]) - elif feature == "next.reward": + elif feature == REWARD: assert torch.equal(value, buffer.rewards[i]) - elif feature == "next.done": + elif feature == DONE: assert torch.equal(value, buffer.dones[i]) elif feature == OBS_IMAGE: # Tensor -> numpy is not precise, so we have some diff there From 49918efbc1946da822f743902ba36da8e1e398b4 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 26 Sep 2025 14:30:17 +0200 Subject: [PATCH 132/158] chore(utils): remove unused code (#2059) --- src/lerobot/utils/control_utils.py | 59 ------------------------------ src/lerobot/utils/import_utils.py | 4 -- 2 files changed, 63 deletions(-) diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 47beb5746..17371921c 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -27,7 +27,6 @@ from typing import Any import numpy as np import torch from deepdiff import DeepDiff -from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES @@ -36,64 +35,6 @@ from lerobot.processor import PolicyAction, PolicyProcessorPipeline from lerobot.robots import Robot -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): - """ - Logs performance metrics for a single step of the robot control loop. - - This function formats and prints a single line of log information, including episode/frame counters, - total loop time (dt), and detailed timings for various robot and camera operations. It can also - highlight performance drops in yellow if the actual FPS is lower than the target FPS. - - Args: - robot: The `Robot` instance, used to access its internal logs for detailed timings. - dt_s: The total duration of the control loop step in seconds. - episode_index: The index of the current episode. - frame_index: The index of the current frame within the episode. - fps: The target frames per second, used to check for performance degradation. - """ - log_items = [] - if episode_index is not None: - log_items.append(f"ep:{episode_index}") - if frame_index is not None: - log_items.append(f"frame:{frame_index}") - - def log_dt(shortname, dt_val_s): - nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" - if fps is not None: - actual_fps = 1 / dt_val_s - if actual_fps < fps - 1: - info_str = colored(info_str, "yellow") - log_items.append(info_str) - - # total step time displayed in milliseconds and its frequency - log_dt("dt", dt_s) - - # TODO(aliberts): move robot-specific logs logic in robot.print_logs() - if not robot.robot_type.startswith("stretch"): - for name in robot.leader_arms: - key = f"read_leader_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRlead", robot.logs[key]) - - for name in robot.follower_arms: - key = f"write_follower_{name}_goal_pos_dt_s" - if key in robot.logs: - log_dt("dtWfoll", robot.logs[key]) - - key = f"read_follower_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRfoll", robot.logs[key]) - - for name in robot.cameras: - key = f"read_camera_{name}_dt_s" - if key in robot.logs: - log_dt(f"dtR{name}", robot.logs[key]) - - info_str = " ".join(log_items) - logging.info(info_str) - - @cache def is_headless(): """ diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 09e649372..5f41ea3a3 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -57,8 +57,4 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b return package_exists -_torch_available, _torch_version = is_package_available("torch", return_version=True) _transformers_available = is_package_available("transformers") -_gym_xarm_available = is_package_available("gym_xarm") -_gym_aloha_available = is_package_available("gym_aloha") -_gym_pusht_available = is_package_available("gym_pusht") From ddfff054bc80438e70a64088d73564aba0aec67b Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 26 Sep 2025 14:32:29 +0200 Subject: [PATCH 133/158] feat(train): enhance processor overrides with normalizer and unnormalizer stats (#2038) --- src/lerobot/scripts/lerobot_train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 5ef8c7263..86b2bbae5 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -185,7 +185,13 @@ def train(cfg: TrainPipelineConfig): processor_kwargs["dataset_stats"] = dataset.meta.stats if cfg.policy.pretrained_path is not None: - processor_kwargs["preprocessor_overrides"] = {"device_processor": {"device": device.type}} + processor_kwargs["preprocessor_overrides"] = { + "device_processor": {"device": device.type}, + "normalizer_processor": {"stats": dataset.meta.stats}, + } + processor_kwargs["postprocessor_overrides"] = { + "unnormalizer_processor": {"stats": dataset.meta.stats}, + } preprocessor, postprocessor = make_pre_post_processors( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs From 5b647e3bcbedc2a734714a1f7f219a861ec843ab Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 26 Sep 2025 15:09:42 +0200 Subject: [PATCH 134/158] docs(fix): libero example command (#2060) Signed-off-by: Jade Choghari --- docs/source/libero.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 17e12d45e..eafe3e78b 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -106,6 +106,7 @@ For reference, here is the **original dataset** published by Physical Intelligen lerobot-train \ --policy.type=smolvla \ --policy.repo_id=${HF_USER}/libero-test \ + --policy.load_vlm_weights=true \ --dataset.repo_id=HuggingFaceVLA/libero \ --env.type=libero \ --env.task=libero_10 \ From e3b572992e2d66d44e0e6e9f02fdf40d6f484866 Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Sat, 27 Sep 2025 16:07:53 +0200 Subject: [PATCH 135/158] Save Cropped Dataset to Hub (#2071) * fix: cast fps argument from dataset to int * fix: typo * fix: specify repo-id --- src/lerobot/rl/crop_dataset_roi.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index 281069e14..4345fed3c 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -160,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset): return image_dict -def convert_lerobot_dataset_to_cropper_lerobot_dataset( +def convert_lerobot_dataset_to_cropped_lerobot_dataset( original_dataset: LeRobotDataset, crop_params_dict: dict[str, tuple[int, int, int, int]], new_repo_id: str, @@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( # 1. Create a new (empty) LeRobotDataset for writing. new_dataset = LeRobotDataset.create( repo_id=new_repo_id, - fps=original_dataset.fps, + fps=int(original_dataset.fps), root=new_dataset_root, robot_type=original_dataset.meta.robot_type, features=original_dataset.meta.info["features"], @@ -275,6 +275,12 @@ if __name__ == "__main__": default="", help="The natural language task to describe the dataset.", ) + parser.add_argument( + "--new-repo-id", + type=str, + default=None, + help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.", + ) args = parser.parse_args() dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) @@ -294,10 +300,16 @@ if __name__ == "__main__": for key, roi in rois.items(): print(f"{key}: {roi}") - new_repo_id = args.repo_id + "_cropped_resized" - new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized" - cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + if args.new_repo_id: + new_dataset_name = args.new_repo_id.split("/")[-1] + # Parent 1: HF user, Parent 2: HF LeRobot Home + new_dataset_root = dataset.root.parent.parent / new_dataset_name + else: + new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + + cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset( original_dataset=dataset, crop_params_dict=rois, new_repo_id=new_repo_id, From 62e9849ffd4a600ca48dbc0f14d3f860e633c568 Mon Sep 17 00:00:00 2001 From: Qizhi Chen Date: Sun, 28 Sep 2025 20:18:22 +0800 Subject: [PATCH 136/158] use abs path when concatenating (#2076) --- src/lerobot/datasets/video_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 9da89022b..b4c036e6c 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -428,7 +428,7 @@ def concatenate_video_files( with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file: tmp_concatenate_file.write("ffconcat version 1.0\n") for input_path in input_video_paths: - tmp_concatenate_file.write(f"file '{str(input_path)}'\n") + tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n") tmp_concatenate_file.flush() tmp_concatenate_path = tmp_concatenate_file.name From f59eb54f5c0e3f46286e15305402bed2830a7848 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 29 Sep 2025 10:49:36 +0200 Subject: [PATCH 137/158] chore: remove unused code (#2062) --- docs/source/hilserl.mdx | 3 - docs/source/phone_teleop.mdx | 3 +- docs/source/processors_robots_teleop.mdx | 2 +- examples/phone_to_so100/record.py | 1 - examples/phone_to_so100/teleoperate.py | 1 - examples/so100_to_so100_EE/record.py | 1 - examples/so100_to_so100_EE/teleoperate.py | 1 - src/lerobot/async_inference/helpers.py | 3 +- src/lerobot/cameras/utils.py | 4 -- src/lerobot/configs/default.py | 3 - src/lerobot/configs/types.py | 5 -- src/lerobot/datasets/lerobot_dataset.py | 10 ---- .../datasets/push_dataset_to_hub/utils.py | 57 ------------------- src/lerobot/datasets/streaming_dataset.py | 4 +- src/lerobot/datasets/utils.py | 49 ---------------- src/lerobot/datasets/video_utils.py | 13 ----- src/lerobot/envs/configs.py | 2 - src/lerobot/motors/motors_bus.py | 6 -- src/lerobot/policies/sac/configuration_sac.py | 2 - src/lerobot/policies/sac/modeling_sac.py | 12 ---- .../policies/vqbet/configuration_vqbet.py | 2 - src/lerobot/policies/vqbet/vqbet_utils.py | 10 ---- .../processor/delta_action_processor.py | 2 - src/lerobot/rl/actor.py | 2 - src/lerobot/rl/learner.py | 2 - src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- .../robot_kinematic_processor.py | 6 -- .../robots/stretch3/configuration_stretch3.py | 2 - src/lerobot/robots/stretch3/robot_stretch3.py | 4 -- .../teleoperators/gamepad/gamepad_utils.py | 22 ------- .../keyboard/configuration_keyboard.py | 3 +- .../configuration_stretch3.py | 2 +- .../stretch3_gamepad/stretch3_gamepad.py | 4 -- src/lerobot/utils/constants.py | 1 - src/lerobot/utils/errors.py | 11 ---- src/lerobot/utils/utils.py | 4 -- tests/policies/test_sac_config.py | 1 - 37 files changed, 8 insertions(+), 254 deletions(-) diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index bc38408e6..ad1c74f9a 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -95,7 +95,6 @@ class HILSerlProcessorConfig: class ObservationConfig: add_joint_velocity_to_observation: bool = False # Add joint velocities to state add_current_to_observation: bool = False # Add motor currents to state - add_ee_pose_to_observation: bool = False # Add end-effector pose to state display_cameras: bool = False # Display camera feeds during execution class ImagePreprocessingConfig: @@ -105,7 +104,6 @@ class ImagePreprocessingConfig: class GripperConfig: use_gripper: bool = True # Enable gripper control gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage - gripper_penalty_in_reward: bool = False # Include gripper penalty in reward class ResetConfig: fixed_reset_joint_positions: Any | None = None # Joint positions for reset @@ -288,7 +286,6 @@ You can enable multiple observation processing features simultaneously: "observation": { "add_joint_velocity_to_observation": true, "add_current_to_observation": true, - "add_ee_pose_to_observation": false, "display_cameras": false } } diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index bab0ac28e..22159193c 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -136,13 +136,12 @@ Additionally you can customize mapping or safety limits by editing the processor ), ``` -- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits. +- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits. ```examples/phone_to_so100/teleoperate.py EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, - max_ee_twist_step_rad=0.50, ) ``` diff --git a/docs/source/processors_robots_teleop.mdx b/docs/source/processors_robots_teleop.mdx index c4fcbe03d..3d8dcb409 100644 --- a/docs/source/processors_robots_teleop.mdx +++ b/docs/source/processors_robots_teleop.mdx @@ -38,7 +38,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotActi kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()), ), EEBoundsAndSafety( - end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50, + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, ), GripperVelocityToJoint(), ], diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index bb2e2f5f7..d3ef293a7 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -84,7 +84,6 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, Rob EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, - max_ee_twist_step_rad=0.50, ), GripperVelocityToJoint(speed_factor=20.0), ], diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 6c49a8453..783dce242 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -67,7 +67,6 @@ phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, Robo EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, - max_ee_twist_step_rad=0.50, ), GripperVelocityToJoint( speed_factor=20.0, diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index 6c38553e2..9ed6e51a9 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -101,7 +101,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, - max_ee_twist_step_rad=0.50, ), InverseKinematicsEEToJoints( kinematics=follower_kinematics_solver, diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index aa9755788..b1a8c8c27 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -78,7 +78,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, - max_ee_twist_step_rad=0.50, ), InverseKinematicsEEToJoints( kinematics=follower_kinematics_solver, diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 75d81a0f3..a336d3a63 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -31,7 +31,6 @@ from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import init_logging Action = torch.Tensor -ActionChunk = torch.Tensor # observation as received from the robot RawObservation = dict[str, torch.Tensor] @@ -46,7 +45,7 @@ Observation = dict[str, torch.Tensor] def visualize_action_queue_size(action_queue_size: list[int]) -> None: import matplotlib.pyplot as plt - fig, ax = plt.subplots() + _, ax = plt.subplots() ax.set_title("Action Queue Size Over Time") ax.set_xlabel("Environment steps") ax.set_ylabel("Action Queue Size") diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index dfac33e17..4a23843b2 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -15,14 +15,10 @@ # limitations under the License. import platform -from pathlib import Path -from typing import TypeAlias from .camera import Camera from .configs import CameraConfig, Cv2Rotation -IndexOrPath: TypeAlias = int | Path - def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]: cameras = {} diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 1bc2b8d16..afd644e1c 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -16,9 +16,6 @@ from dataclasses import dataclass, field -from lerobot import ( - policies, # noqa: F401 -) from lerobot.datasets.transforms import ImageTransformsConfig from lerobot.datasets.video_utils import get_safe_default_codec diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index e02527840..754aca1ab 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -15,7 +15,6 @@ # https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json from dataclasses import dataclass from enum import Enum -from typing import Any, Protocol class FeatureType(str, Enum): @@ -38,10 +37,6 @@ class NormalizationMode(str, Enum): IDENTITY = "IDENTITY" -class DictLike(Protocol): - def __getitem__(self, key: Any) -> Any: ... - - @dataclass class PolicyFeature: type: FeatureType diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 9eebcea4b..b8aa880da 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -848,11 +848,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return item - def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: - for key, val in padding.items(): - item[key] = torch.BoolTensor(val) - return item - def __len__(self): return self.num_frames @@ -1396,11 +1391,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): """ return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} - @property - def repo_index_to_id(self): - """Return the inverse mapping if repo_id_to_index.""" - return {v: k for k, v in self.repo_id_to_index} - @property def fps(self) -> int: """Frames per second used during data collection. diff --git a/src/lerobot/datasets/push_dataset_to_hub/utils.py b/src/lerobot/datasets/push_dataset_to_hub/utils.py index 5f6363a77..48214e1bf 100644 --- a/src/lerobot/datasets/push_dataset_to_hub/utils.py +++ b/src/lerobot/datasets/push_dataset_to_hub/utils.py @@ -13,67 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from concurrent.futures import ThreadPoolExecutor -from pathlib import Path import datasets -import numpy -import PIL import torch -from lerobot.datasets.video_utils import encode_video_frames - - -def concatenate_episodes(ep_dicts): - data_dict = {} - - keys = ep_dicts[0].keys() - for key in keys: - if torch.is_tensor(ep_dicts[0][key][0]): - data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts]) - else: - if key not in data_dict: - data_dict[key] = [] - for ep_dict in ep_dicts: - for x in ep_dict[key]: - data_dict[key].append(x) - - total_frames = data_dict["frame_index"].shape[0] - data_dict["index"] = torch.arange(0, total_frames, 1) - return data_dict - - -def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4): - out_dir = Path(out_dir) - out_dir.mkdir(parents=True, exist_ok=True) - - def save_image(img_array, i, out_dir): - img = PIL.Image.fromarray(img_array) - img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100) - - num_images = len(imgs_array) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] - - -def get_default_encoding() -> dict: - """Returns the default ffmpeg encoding parameters used by `encode_video_frames`.""" - signature = inspect.signature(encode_video_frames) - return { - k: v.default - for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"] - } - - -def check_repo_id(repo_id: str) -> None: - if len(repo_id.split("/")) != 2: - raise ValueError( - f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset - (e.g. 'lerobot/pusht'), but contains '{repo_id}'.""" - ) - # TODO(aliberts): remove def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py index c3c48d90d..454389d46 100644 --- a/src/lerobot/datasets/streaming_dataset.py +++ b/src/lerobot/datasets/streaming_dataset.py @@ -298,9 +298,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): return padding_mask - def make_frame( - self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None - ) -> Generator: + def make_frame(self, dataset_iterator: Backtrackable) -> Generator: """Makes a frame starting from a dataset iterator""" item = next(dataset_iterator) item = item_to_torch(item) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 35313bde5..81b361ab6 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -67,18 +67,6 @@ DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{fram LEGACY_EPISODES_PATH = "meta/episodes.jsonl" LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" LEGACY_TASKS_PATH = "meta/tasks.jsonl" -LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" -LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet" - -DATASET_CARD_TEMPLATE = """ ---- -# Metadata will go there ---- -This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). - -## {} - -""" DEFAULT_FEATURES = { "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, @@ -383,12 +371,6 @@ def load_episodes(local_dir: Path) -> datasets.Dataset: return episodes -def backward_compatible_episodes_stats( - stats: dict[str, dict[str, np.ndarray]], episodes: list[int] -) -> dict[int, dict[str, dict[str, np.ndarray]]]: - return dict.fromkeys(episodes, stats) - - def load_image_as_numpy( fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True ) -> np.ndarray: @@ -1346,12 +1328,6 @@ class Backtrackable(Generic[T]): # When cursor<0, slice so the order remains chronological return list(self._back_buf)[: self._cursor or None] - def lookahead_buffer(self) -> list[T]: - """ - Return a copy of the current lookahead buffer. - """ - return list(self._ahead_buf) - def can_peek_back(self, steps: int = 1) -> bool: """ Check if we can go back `steps` items without raising an IndexError. @@ -1377,31 +1353,6 @@ class Backtrackable(Generic[T]): except StopIteration: return False - def reset_cursor(self) -> None: - """ - Reset cursor to the most recent position (equivalent to calling next() - until you're back to the latest item). - """ - self._cursor = 0 - - def clear_ahead_buffer(self) -> None: - """ - Clear the ahead buffer, discarding any pre-fetched items. - """ - self._ahead_buf.clear() - - def switch_source_iterable(self, new_source: Iterable[T]) -> None: - """ - Switch the source of the backtrackable to a new iterable, keeping the history. - - This is useful when iterating over a sequence of datasets. The history from the - previous source is kept, but the lookahead buffer is cleared. The cursor is reset - to the present. - """ - self._source = iter(new_source) - self.clear_ahead_buffer() - self.reset_cursor() - def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: """ diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index b4c036e6c..5f8b207e0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -585,19 +585,6 @@ def get_video_pixel_channels(pix_fmt: str) -> int: raise ValueError("Unknown format") -def get_image_pixel_channels(image: Image): - if image.mode == "L": - return 1 # Grayscale - elif image.mode == "LA": - return 2 # Grayscale + Alpha - elif image.mode == "RGB": - return 3 # RGB - elif image.mode == "RGBA": - return 4 # RGBA - else: - raise ValueError("Unknown format") - - def get_video_duration_in_s(video_path: Path | str) -> float: """ Get the duration of a video file in seconds using PyAV. diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 8cbc597dc..8c0c8b3ab 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -193,7 +193,6 @@ class ObservationConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False - add_ee_pose_to_observation: bool = False display_cameras: bool = False @@ -203,7 +202,6 @@ class GripperConfig: use_gripper: bool = True gripper_penalty: float = 0.0 - gripper_penalty_in_reward: bool = False @dataclass diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index dca7650e0..8603d81a9 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -99,12 +99,6 @@ class Motor: norm_mode: MotorNormMode -class JointOutOfRangeError(Exception): - def __init__(self, message="Joint is out of range"): - self.message = message - super().__init__(self.message) - - class PortHandler(Protocol): def __init__(self, port_name): self.is_open: bool diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index 6b5ad5b59..ada12330c 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -139,8 +139,6 @@ class SACConfig(PreTrainedConfig): # Training parameter # Number of steps for online training online_steps: int = 1000000 - # Seed for the online environment - online_env_seed: int = 10000 # Capacity of the online replay buffer online_buffer_capacity: int = 100000 # Capacity of the offline replay buffer diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index c66044406..c7c6798ed 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -1061,15 +1061,3 @@ class TanhMultivariateNormalDiag(TransformedDistribution): x = transform(x) return x - - -def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: - converted_params = {} - for outer_key, inner_dict in normalization_params.items(): - converted_params[outer_key] = {} - for key, value in inner_dict.items(): - converted_params[outer_key][key] = torch.tensor(value) - if "image" in outer_key: - converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) - - return converted_params diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index d7a79f189..44ada9f17 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -82,7 +82,6 @@ class VQBeTConfig(PreTrainedConfig): gpt_n_head: Number of headers of GPT gpt_hidden_dim: Size of hidden dimensions of GPT dropout: Dropout rate for GPT - mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT offset_loss_weight: A constant that is multiplied to the offset loss primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss @@ -125,7 +124,6 @@ class VQBeTConfig(PreTrainedConfig): gpt_n_head: int = 8 gpt_hidden_dim: int = 512 dropout: float = 0.1 - mlp_hidden_dim: int = 1024 offset_loss_weight: float = 10000.0 primary_code_loss_weight: float = 5.0 secondary_code_loss_weight: float = 0.5 diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py index e0afe5585..44b7d5f0b 100644 --- a/src/lerobot/policies/vqbet/vqbet_utils.py +++ b/src/lerobot/policies/vqbet/vqbet_utils.py @@ -231,16 +231,6 @@ class GPT(nn.Module): torch.nn.init.zeros_(module.bias) torch.nn.init.ones_(module.weight) - def crop_block_size(self, gpt_block_size): - # model surgery to decrease the block size if necessary - # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) - # but want to use a smaller block size for some smaller, simpler model - assert gpt_block_size <= self.config.gpt_block_size - self.config.gpt_block_size = gpt_block_size - self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) - for block in self.transformer.h: - block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] - def configure_parameters(self): """ This long function is unfortunately doing something very simple and is being very defensive: diff --git a/src/lerobot/processor/delta_action_processor.py b/src/lerobot/processor/delta_action_processor.py index 949ae78d5..a8395637c 100644 --- a/src/lerobot/processor/delta_action_processor.py +++ b/src/lerobot/processor/delta_action_processor.py @@ -83,14 +83,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep): Attributes: position_scale: A factor to scale the delta position inputs. - rotation_scale: A factor to scale the delta rotation inputs (currently unused). noise_threshold: The magnitude below which delta inputs are considered noise and do not trigger an "enabled" state. """ # Scale factors for delta movements position_scale: float = 1.0 - rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise def action(self, action: RobotAction) -> RobotAction: diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 3c025a05d..54d0fba69 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -97,8 +97,6 @@ from .gym_manipulator import ( step_env_and_process_transition, ) -ACTOR_SHUTDOWN_TIMEOUT = 30 - # Main entry point diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index b7cfdb30c..d9758d3a3 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -102,8 +102,6 @@ from lerobot.utils.utils import ( from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService -LOG_PREFIX = "[LEARNER]" - @parser.wrap() def train_cli(cfg: TrainRLServerPipelineConfig): diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index baa36b560..220a29f8c 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -105,7 +105,7 @@ class HopeJrArm(Robot): def is_calibrated(self) -> bool: return self.bus.is_calibrated - def calibrate(self, limb_name: str = None) -> None: + def calibrate(self) -> None: groups = { "all": list(self.bus.motors.keys()), "shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"], diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py index 56686d447..87e832db6 100644 --- a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so100_follower/robot_kinematic_processor.py @@ -193,16 +193,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep): Attributes: end_effector_bounds: A dictionary with "min" and "max" keys for position clipping. max_ee_step_m: The maximum allowed change in position (in meters) between steps. - max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps. _last_pos: Internal state storing the last commanded position. - _last_twist: Internal state storing the last commanded orientation. """ end_effector_bounds: dict max_ee_step_m: float = 0.05 - max_ee_twist_step_rad: float = 0.20 _last_pos: np.ndarray | None = field(default=None, init=False, repr=False) - _last_twist: np.ndarray | None = field(default=None, init=False, repr=False) def action(self, action: RobotAction) -> RobotAction: x = action["ee.x"] @@ -233,7 +229,6 @@ class EEBoundsAndSafety(RobotActionProcessorStep): raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m") self._last_pos = pos - self._last_twist = twist action["ee.x"] = float(pos[0]) action["ee.y"] = float(pos[1]) @@ -246,7 +241,6 @@ class EEBoundsAndSafety(RobotActionProcessorStep): def reset(self): """Resets the last known position and orientation.""" self._last_pos = None - self._last_twist = None def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] diff --git a/src/lerobot/robots/stretch3/configuration_stretch3.py b/src/lerobot/robots/stretch3/configuration_stretch3.py index d4e217ca0..c1226bf90 100644 --- a/src/lerobot/robots/stretch3/configuration_stretch3.py +++ b/src/lerobot/robots/stretch3/configuration_stretch3.py @@ -49,5 +49,3 @@ class Stretch3RobotConfig(RobotConfig): ), } ) - - mock: bool = False diff --git a/src/lerobot/robots/stretch3/robot_stretch3.py b/src/lerobot/robots/stretch3/robot_stretch3.py index 8a0ff5c6a..73df360b2 100644 --- a/src/lerobot/robots/stretch3/robot_stretch3.py +++ b/src/lerobot/robots/stretch3/robot_stretch3.py @@ -164,10 +164,6 @@ class Stretch3Robot(Robot): # TODO(aliberts): return action_sent when motion is limited return action - def print_logs(self) -> None: - pass - # TODO(aliberts): move robot-specific logs logic here - def teleop_safety_stop(self) -> None: if self.teleop is not None: self.teleop._safety_stop(robot=self) diff --git a/src/lerobot/teleoperators/gamepad/gamepad_utils.py b/src/lerobot/teleoperators/gamepad/gamepad_utils.py index d994dadd1..9f94b6746 100644 --- a/src/lerobot/teleoperators/gamepad/gamepad_utils.py +++ b/src/lerobot/teleoperators/gamepad/gamepad_utils.py @@ -52,10 +52,6 @@ class InputController: """Get the current movement deltas (dx, dy, dz) in meters.""" return 0.0, 0.0, 0.0 - def should_quit(self): - """Return True if the user has requested to quit.""" - return not self.running - def update(self): """Update controller state - call this once per frame.""" pass @@ -198,14 +194,6 @@ class KeyboardController(InputController): return delta_x, delta_y, delta_z - def should_quit(self): - """Return True if ESC was pressed.""" - return self.key_states["quit"] - - def should_save(self): - """Return True if Enter was pressed (save episode).""" - return self.key_states["success"] or self.key_states["failure"] - class GamepadController(InputController): """Generate motion deltas from gamepad input.""" @@ -351,8 +339,6 @@ class GamepadControllerHID(InputController): # Button states self.buttons = {} - self.quit_requested = False - self.save_requested = False def find_device(self): """Look for the gamepad device by vendor and product ID.""" @@ -472,11 +458,3 @@ class GamepadControllerHID(InputController): delta_z = -self.right_y * self.z_step_size # Up/down return delta_x, delta_y, delta_z - - def should_quit(self): - """Return True if quit button was pressed.""" - return self.quit_requested - - def should_save(self): - """Return True if save button was pressed.""" - return self.save_requested diff --git a/src/lerobot/teleoperators/keyboard/configuration_keyboard.py b/src/lerobot/teleoperators/keyboard/configuration_keyboard.py index 5d5ef364f..6e070dedd 100644 --- a/src/lerobot/teleoperators/keyboard/configuration_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/configuration_keyboard.py @@ -22,8 +22,9 @@ from ..config import TeleoperatorConfig @TeleoperatorConfig.register_subclass("keyboard") @dataclass class KeyboardTeleopConfig(TeleoperatorConfig): + """KeyboardTeleopConfig""" + # TODO(Steven): Consider setting in here the keys that we want to capture/listen - mock: bool = False @TeleoperatorConfig.register_subclass("keyboard_ee") diff --git a/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py b/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py index 507a21589..3af0b5be1 100644 --- a/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py +++ b/src/lerobot/teleoperators/stretch3_gamepad/configuration_stretch3.py @@ -22,4 +22,4 @@ from ..config import TeleoperatorConfig @TeleoperatorConfig.register_subclass("stretch3") @dataclass class Stretch3GamePadConfig(TeleoperatorConfig): - mock: bool = False + """Stretch3GamePadConfig""" diff --git a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py index 94e1ca7cc..09fdfadd7 100644 --- a/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py +++ b/src/lerobot/teleoperators/stretch3_gamepad/stretch3_gamepad.py @@ -112,10 +112,6 @@ class Stretch3GamePad(Teleoperator): def send_feedback(self, feedback: np.ndarray) -> None: pass - def print_logs(self) -> None: - pass - # TODO(aliberts): move robot-specific logs logic here - def disconnect(self) -> None: self.api.stop() self.is_connected = False diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 337817908..824f74b30 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -33,7 +33,6 @@ TRUNCATED = "next.truncated" DONE = "next.done" ROBOTS = "robots" -ROBOT_TYPE = "robot_type" TELEOPERATORS = "teleoperators" # files & directories diff --git a/src/lerobot/utils/errors.py b/src/lerobot/utils/errors.py index c02d568d4..31b73eaca 100644 --- a/src/lerobot/utils/errors.py +++ b/src/lerobot/utils/errors.py @@ -30,14 +30,3 @@ class DeviceAlreadyConnectedError(ConnectionError): ): self.message = message super().__init__(self.message) - - -class InvalidActionError(ValueError): - """Exception raised when an action is already invalid.""" - - def __init__( - self, - message="The action is invalid. Check the value follows what it is expected from the action space.", - ): - self.message = message - super().__init__(self.message) diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 523a5e4d2..8777d5a9d 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -330,10 +330,6 @@ class TimerManager: def history(self) -> list[float]: return deepcopy(self._history) - @property - def fps_history(self) -> list[float]: - return [1.0 / t for t in self._history] - @property def fps_last(self) -> float: return 0.0 if self.last == 0 else 1.0 / self.last diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py index be6a8d26e..724c331ff 100644 --- a/tests/policies/test_sac_config.py +++ b/tests/policies/test_sac_config.py @@ -69,7 +69,6 @@ def test_sac_config_default_initialization(): # Training parameters assert config.online_steps == 1000000 - assert config.online_env_seed == 10000 assert config.online_buffer_capacity == 100000 assert config.offline_buffer_capacity == 100000 assert config.async_prefetch is False From 90684a9690c1a16e71dc64c7ee89847e5b82df1b Mon Sep 17 00:00:00 2001 From: Qizhi Chen Date: Mon, 29 Sep 2025 17:18:54 +0800 Subject: [PATCH 138/158] Improve V3 aggregate implementation (#2077) * fix return type * improve apply with vertorize op * Update src/lerobot/datasets/aggregate.py Co-authored-by: Michel Aractingi --- src/lerobot/datasets/aggregate.py | 45 +++++++++++-------------- src/lerobot/datasets/lerobot_dataset.py | 6 ++-- 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 43d4ee233..803645f29 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta): pd.DataFrame: Updated DataFrame with adjusted indices. """ - def _update(row): - row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] - row["index"] = row["index"] + dst_meta.info["total_frames"] - task = src_meta.tasks.iloc[row["task_index"]].name - row["task_index"] = dst_meta.tasks.loc[task].task_index.item() - return row + df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] + df["index"] = df["index"] + dst_meta.info["total_frames"] - return df.apply(_update, axis=1) + src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy()) + df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy() + + return df def update_meta_data( @@ -126,27 +125,21 @@ def update_meta_data( pd.DataFrame: Updated DataFrame with adjusted indices and timestamps. """ - def _update(row): - row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"] - row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"] - row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"] - row["data/file_index"] = row["data/file_index"] + data_idx["file"] - for key, video_idx in videos_idx.items(): - row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"] - row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"] - row[f"videos/{key}/from_timestamp"] = ( - row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] - ) - row[f"videos/{key}/to_timestamp"] = ( - row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] - ) + df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] + df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] + df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] + df["data/file_index"] = df["data/file_index"] + data_idx["file"] + for key, video_idx in videos_idx.items(): + df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"] + df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"] + df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] - row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"] - row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"] - row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] - return row + df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] + df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] + df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"] - return df.apply(_update, axis=1) + return df def aggregate_datasets( diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b8aa880da..691d86af7 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1027,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Reset episode buffer and clean up temporary images (if not already deleted during video encoding) self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) - def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None): + def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None: """ Batch save videos for multiple episodes. @@ -1153,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset): } return metadata - def _save_episode_video(self, video_key: str, episode_index: int): + def _save_episode_video(self, video_key: str, episode_index: int) -> dict: # Encode episode frames into a temporary video ep_path = self._encode_temporary_episode_video(video_key, episode_index) ep_size_in_mb = get_video_size_in_mb(ep_path) @@ -1258,7 +1258,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.image_writer is not None: self.image_writer.wait_until_done() - def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict: + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path: """ Use ffmpeg to convert frames stored as png into mp4 videos. Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, From c378a325f05a652d0d3808f76d18e9e889de93c2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 29 Sep 2025 13:28:53 +0200 Subject: [PATCH 139/158] chore: enable pyugrade ruff lint (#2084) --- pyproject.toml | 2 +- src/lerobot/datasets/lerobot_dataset.py | 2 +- src/lerobot/datasets/transforms.py | 2 +- src/lerobot/datasets/utils.py | 14 +++++++------- src/lerobot/envs/libero.py | 2 +- src/lerobot/envs/utils.py | 4 ++-- src/lerobot/motors/motors_bus.py | 8 ++++---- src/lerobot/policies/vqbet/vqbet_utils.py | 8 +++----- src/lerobot/rl/buffer.py | 4 ++-- src/lerobot/rl/wandb_utils.py | 2 +- src/lerobot/scripts/lerobot_record.py | 6 +----- .../teleoperators/homunculus/homunculus_arm.py | 3 +-- .../teleoperators/homunculus/homunculus_glove.py | 3 +-- src/lerobot/utils/transition.py | 2 +- src/lerobot/utils/visualization_utils.py | 2 +- .../policies/save_policy_to_safetensors.py | 6 ++---- tests/processor/test_pipeline.py | 4 ++-- tests/utils/test_replay_buffer.py | 2 +- 18 files changed, 33 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 44e29043b..12bb552fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] # N: pep8-naming # TODO: Uncomment rules when ready to use select = [ - "E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP" + "E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF" ] ignore = [ "E501", # Line too long diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 691d86af7..b661b21b0 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): """Keys to access image and video stream from cameras.""" keys = [] for key, feats in self.features.items(): - if isinstance(feats, (datasets.Image, VideoFrame)): + if isinstance(feats, (datasets.Image | VideoFrame)): keys.append(key) return keys diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index f992275b7..f7072c72f 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -120,7 +120,7 @@ class SharpnessJitter(Transform): self.sharpness = self._check_input(sharpness) def _check_input(self, sharpness): - if isinstance(sharpness, (int, float)): + if isinstance(sharpness, (int | float)): if sharpness < 0: raise ValueError("If sharpness is a single number, it must be non negative.") sharpness = [1.0 - sharpness, 1.0 + sharpness] diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 81b361ab6..a2f285014 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -21,7 +21,7 @@ from collections import deque from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat -from typing import Any, Deque, Generic, TypeVar +from typing import Any, Generic, TypeVar import datasets import numpy as np @@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: """ serialized_dict = {} for key, value in flatten_dict(stats).items(): - if isinstance(value, (torch.Tensor, np.ndarray)): + if isinstance(value, (torch.Tensor | np.ndarray)): serialized_dict[key] = value.tolist() - elif isinstance(value, list) and isinstance(value[0], (int, float, list)): + elif isinstance(value, list) and isinstance(value[0], (int | float | list)): serialized_dict[key] = value elif isinstance(value, np.generic): serialized_dict[key] = value.item() - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): serialized_dict[key] = value else: raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") @@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict: dict: Dictionary with all tensor-like items converted to torch.Tensor. """ for key, val in item.items(): - if isinstance(val, (np.ndarray, list)) and key not in ["task"]: + if isinstance(val, (np.ndarray | list)) and key not in ["task"]: # Convert numpy arrays and lists to torch tensors item[key] = torch.tensor(val) return item @@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]): raise ValueError("lookahead must be > 0") self._source: Iterator[T] = iter(iterable) - self._back_buf: Deque[T] = deque(maxlen=history) - self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._back_buf: deque[T] = deque(maxlen=history) + self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() self._cursor: int = 0 self._history = history self._lookahead = lookahead diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 466796975..99ec6712f 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" if isinstance(camera_name, str): cams = [c.strip() for c in camera_name.split(",") if c.strip()] - elif isinstance(camera_name, (list, tuple)): + elif isinstance(camera_name, (list | tuple)): cams = [str(c).strip() for c in camera_name if str(c).strip()] else: raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}") diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 023ceea67..b5cfc7e26 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -183,10 +183,10 @@ def _(env: Mapping) -> None: @close_envs.register def _(envs: Sequence) -> None: - if isinstance(envs, (str, bytes)): + if isinstance(envs, (str | bytes)): return for v in envs: - if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)): + if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)): close_envs(v) elif hasattr(v, "close"): _close_single_env(v) diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 8603d81a9..17eaa8063 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -342,7 +342,7 @@ class MotorsBus(abc.ABC): raise TypeError(motors) def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: - if isinstance(values, (int, float)): + if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): return {self.motors[motor].id: val for motor, val in values.items()} @@ -669,7 +669,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) @@ -697,7 +697,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) @@ -733,7 +733,7 @@ class MotorsBus(abc.ABC): """ if motors is None: motors = list(self.motors) - elif isinstance(motors, (str, int)): + elif isinstance(motors, (str | int)): motors = [motors] elif not isinstance(motors, list): raise TypeError(motors) diff --git a/src/lerobot/policies/vqbet/vqbet_utils.py b/src/lerobot/policies/vqbet/vqbet_utils.py index 44b7d5f0b..7b13577f6 100644 --- a/src/lerobot/policies/vqbet/vqbet_utils.py +++ b/src/lerobot/policies/vqbet/vqbet_utils.py @@ -260,13 +260,11 @@ class GPT(nn.Module): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( - str(inter_params) + assert len(inter_params) == 0, ( + f"parameters {str(inter_params)} made it into both decay/no_decay sets!" ) assert len(param_dict.keys() - union_params) == 0, ( - "parameters {} were not separated into either decay/no_decay set!".format( - str(param_dict.keys() - union_params), - ) + f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" ) decay = [param_dict[pn] for pn in sorted(decay)] diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index d30b65082..917e4e2cc 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -176,7 +176,7 @@ class ReplayBuffer: self.complementary_info[key] = torch.empty( (self.capacity, *value_shape), device=self.storage_device ) - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): # Handle scalar values similar to reward self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) else: @@ -223,7 +223,7 @@ class ReplayBuffer: value = complementary_info[key] if isinstance(value, torch.Tensor): self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) - elif isinstance(value, (int, float)): + elif isinstance(value, (int | float)): self.complementary_info[key][self.position] = value self.position = (self.position + 1) % self.capacity diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index b13254421..01cef9487 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -137,7 +137,7 @@ class WandBLogger: self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): - if not isinstance(v, (int, float, str)): + if not isinstance(v, (int | float | str)): logging.warning( f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d097a9d2f..ddb21e917 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -267,11 +267,7 @@ def record_loop( for t in teleop if isinstance( t, - ( - so100_leader.SO100Leader, - so101_leader.SO101Leader, - koch_leader.KochLeader, - ), + (so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader), ) ), None, diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 4eca4b9e2..21d73de2e 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -18,7 +18,6 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque import serial @@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: dict[str, Deque[int]] = { + self._buffers: dict[str, deque[int]] = { joint: deque(maxlen=n) for joint in ( "shoulder_pitch", diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 52fd19def..251ecf56d 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -18,7 +18,6 @@ import logging import threading from collections import deque from pprint import pformat -from typing import Deque import serial @@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator): self.n: int = n self.alpha: float = 2 / (n + 1) # one deque *per joint* so we can inspect raw history if needed - self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} + self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints} # running EMA value per joint – lazily initialised on first read self._ema: dict[str, float | None] = dict.fromkeys(self._buffers) diff --git a/src/lerobot/utils/transition.py b/src/lerobot/utils/transition.py index e874bd096..fe3620861 100644 --- a/src/lerobot/utils/transition.py +++ b/src/lerobot/utils/transition.py @@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr for key, val in transition["complementary_info"].items(): if isinstance(val, torch.Tensor): transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) - elif isinstance(val, (int, float, bool)): + elif isinstance(val, (int | float | bool)): transition["complementary_info"][key] = torch.tensor(val, device=device) else: raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index ae070b7c4..d0201ecbf 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -35,7 +35,7 @@ def _is_scalar(x): return ( isinstance(x, float) or isinstance(x, numbers.Real) - or isinstance(x, (np.integer, np.floating)) + or isinstance(x, (np.integer | np.floating)) or (isinstance(x, np.ndarray) and x.ndim == 0) ) diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index e130ae144..64b125cc9 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): for key, param in policy.named_parameters(): if param.requires_grad: grad_stats[f"{key}_mean"] = param.grad.mean() - grad_stats[f"{key}_std"] = ( - param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0)) - ) + grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0) optimizer.step() param_stats = {} for key, param in policy.named_parameters(): param_stats[f"{key}_mean"] = param.mean() - param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0)) + param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0) optimizer.zero_grad() policy.reset() diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 904fd6fc1..76f2b1c26 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep): # Add type validation for multiplier if isinstance(multiplier, str): raise ValueError(f"multiplier must be a number, got string '{multiplier}'") - if not isinstance(multiplier, (int, float)): + if not isinstance(multiplier, (int | float)): raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") self.multiplier = float(multiplier) self.env = env # Non-serializable parameter (like gym.Env) @@ -1623,7 +1623,7 @@ def test_override_with_callables(): # Define a transform function def double_values(x): - if isinstance(x, (int, float)): + if isinstance(x, (int | float)): return x * 2 elif isinstance(x, torch.Tensor): return x * 2 diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index ddf0771f1..b9d3a1ac0 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses): if isinstance(obj, torch.Tensor): return get_tensor_memory_consumption(obj) - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, (list | tuple)): for item in obj: total_size += get_tensors_memory_consumption(item, visited_addresses) elif isinstance(obj, dict): From bbcf66bd829e7e53e85c623c555a44102cf3cc16 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 29 Sep 2025 15:06:56 +0200 Subject: [PATCH 140/158] chore: enable simplify in ruff lint (#2085) --- pyproject.toml | 2 +- src/lerobot/datasets/video_utils.py | 4 +++- src/lerobot/policies/act/modeling_act.py | 5 +---- src/lerobot/processor/hil_processor.py | 2 +- src/lerobot/processor/normalize_processor.py | 13 ++++++------- src/lerobot/processor/observation_processor.py | 2 +- src/lerobot/utils/visualization_utils.py | 7 ++----- tests/conftest.py | 4 ++-- tests/datasets/test_datasets.py | 2 +- tests/processor/test_pipeline.py | 11 ++++------- tests/processor/test_tokenizer_processor.py | 5 +---- tests/robots/test_reachy2.py | 10 +++++----- tests/teleoperators/test_reachy2_teleoperator.py | 8 ++++---- 13 files changed, 32 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 12bb552fa..8bbd998ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] # N: pep8-naming # TODO: Uncomment rules when ready to use select = [ - "E", "W", "F", "I", "B", "C4", "T20", "N", "UP" # "SIM", "A", "S", "D", "RUF" + "E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF" ] ignore = [ "E501", # Line too long diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 5f8b207e0..2c0e116cb 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -437,7 +437,9 @@ def concatenate_video_files( tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"} ) # safe = 0 allows absolute paths as well as relative paths - tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file: + tmp_output_video_path = tmp_named_file.name + output_container = av.open( tmp_output_video_path, mode="w", options={"movflags": "faststart"} ) # faststart is to move the metadata to the beginning of the file to speed up loading diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e987f9070..4d2890ba6 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -398,10 +398,7 @@ class ACT(nn.Module): "actions must be provided when using the variational objective in training mode." ) - if OBS_IMAGES in batch: - batch_size = batch[OBS_IMAGES][0].shape[0] - else: - batch_size = batch[OBS_ENV_STATE].shape[0] + batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0] # Prepare the latent for input to the transformer encoder. if self.config.use_vae and ACTION in batch and self.training: diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 47f69a973..f0dbac9c3 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -340,7 +340,7 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): """ action = self.transition.get(TransitionKey.ACTION) - raw_joint_positions = complementary_data.get("raw_joint_positions", None) + raw_joint_positions = complementary_data.get("raw_joint_positions") if raw_joint_positions is None: return complementary_data diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index c4ded722f..ce69a103f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -119,13 +119,12 @@ class _NormalizationMixin: ) self.features = reconstructed - if self.norm_map: - # if keys are strings (JSON), rebuild enum map - if all(isinstance(k, str) for k in self.norm_map.keys()): - reconstructed = {} - for ft_type_str, norm_mode_str in self.norm_map.items(): - reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) - self.norm_map = reconstructed + # if keys are strings (JSON), rebuild enum map + if self.norm_map and all(isinstance(k, str) for k in self.norm_map): + reconstructed = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed # Convert stats to tensors and move to the target device once during initialization. self.stats = self.stats or {} diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 486218157..d22d8fb96 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -152,7 +152,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep): """ # Build a new features mapping keyed by the same FeatureType buckets # We assume callers already placed features in the correct FeatureType. - new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()} + new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features} exact_pairs = { "pixels": OBS_IMAGE, diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index d0201ecbf..95fdb178a 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -32,11 +32,8 @@ def init_rerun(session_name: str = "lerobot_control_loop") -> None: def _is_scalar(x): - return ( - isinstance(x, float) - or isinstance(x, numbers.Real) - or isinstance(x, (np.integer | np.floating)) - or (isinstance(x, np.ndarray) and x.ndim == 0) + return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or ( + isinstance(x, np.ndarray) and x.ndim == 0 ) diff --git a/tests/conftest.py b/tests/conftest.py index 245cde526..b14e9aed5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,7 +85,7 @@ def policy_feature_factory(): def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None: assert isinstance(features, dict) - assert all(isinstance(k, PipelineFeatureType) for k in features.keys()) + assert all(isinstance(k, PipelineFeatureType) for k in features) assert all(isinstance(v, dict) for v in features.values()) - assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values()) + assert all(all(isinstance(nk, str) for nk in v) for v in features.values()) assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values()) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index b9e966fe6..2bc3bea43 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -949,7 +949,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory) # Check that statistics exist for all features assert loaded_dataset.meta.stats is not None, "No statistics found" - for feature_name in features.keys(): + for feature_name in features: assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'" feature_stats = loaded_dataset.meta.stats[feature_name] diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 76f2b1c26..134228c05 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -246,7 +246,7 @@ def test_step_through(): # Ensure all results are dicts (same format as input) for result in results: assert isinstance(result, dict) - assert all(isinstance(k, TransitionKey) for k in result.keys()) + assert all(isinstance(k, TransitionKey) for k in result) def test_step_through_with_dict(): @@ -1623,9 +1623,7 @@ def test_override_with_callables(): # Define a transform function def double_values(x): - if isinstance(x, (int | float)): - return x * 2 - elif isinstance(x, torch.Tensor): + if isinstance(x, (int | float | torch.Tensor)): return x * 2 return x @@ -1797,10 +1795,9 @@ def test_from_pretrained_nonexistent_path(): ) # Test with a local directory that exists but has no config files - with tempfile.TemporaryDirectory() as tmp_dir: + with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(FileNotFoundError): # Since the directory exists but has no config, it will raise FileNotFoundError - with pytest.raises(FileNotFoundError): - DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json") + DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json") def test_save_load_with_custom_converter_functions(): diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 503f2e036..b81710db1 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -32,10 +32,7 @@ class MockTokenizer: **kwargs, ) -> dict[str, torch.Tensor]: """Mock tokenization that returns deterministic tokens based on text.""" - if isinstance(text, str): - texts = [text] - else: - texts = text + texts = [text] if isinstance(text, str) else text batch_size = len(texts) diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index c93fbeced..94152ea38 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -245,14 +245,14 @@ def test_get_observation(reachy2): obs = reachy2.get_observation() expected_keys = set(reachy2.joints_dict) - expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base) + expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base) expected_keys.update(reachy2.cameras.keys()) assert set(obs.keys()) == expected_keys - for motor in reachy2.joints_dict.keys(): + for motor in reachy2.joints_dict: assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position if reachy2.config.with_mobile_base: - for vel in REACHY2_VEL.keys(): + for vel in REACHY2_VEL: assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] if reachy2.config.with_left_teleop_camera: assert obs["teleop_left"].shape == ( @@ -282,7 +282,7 @@ def test_send_action(reachy2): action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)}) previous_present_position = { - k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys() + k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict } returned = reachy2.send_action(action) @@ -290,7 +290,7 @@ def test_send_action(reachy2): assert returned == action assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict) - for motor in reachy2.joints_dict.keys(): + for motor in reachy2.joints_dict: expected_pos = action[motor] real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position if reachy2.config.max_relative_target is None: diff --git a/tests/teleoperators/test_reachy2_teleoperator.py b/tests/teleoperators/test_reachy2_teleoperator.py index 5130de87d..dd8c5904c 100644 --- a/tests/teleoperators/test_reachy2_teleoperator.py +++ b/tests/teleoperators/test_reachy2_teleoperator.py @@ -121,20 +121,20 @@ def test_get_action(reachy2): action = reachy2.get_action() expected_keys = set(reachy2.joints_dict) - expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base) + expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base) assert set(action.keys()) == expected_keys - for motor in reachy2.joints_dict.keys(): + for motor in reachy2.joints_dict: if reachy2.config.use_present_position: assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position else: assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position if reachy2.config.with_mobile_base: if reachy2.config.use_present_position: - for vel in REACHY2_VEL.keys(): + for vel in REACHY2_VEL: assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]] else: - for vel in REACHY2_VEL.keys(): + for vel in REACHY2_VEL: assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]] From f173265354166547414ff922faa7b014de761481 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 29 Sep 2025 16:02:15 +0200 Subject: [PATCH 141/158] feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087) * feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep * refactor(normalization): streamline feature reconstruction logic in _NormalizationMixin * refactor(tests): remove unused preprocessor initialization in test_act_backbone_lr --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/processor/normalize_processor.py | 20 +++++++++++--------- tests/policies/test_policies.py | 1 - tests/processor/test_normalize_processor.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index ce69a103f..885911ff0 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -108,16 +108,18 @@ class _NormalizationMixin: """ # Track if stats were explicitly provided (not None and not empty) self._stats_explicitly_provided = self.stats is not None and bool(self.stats) + # Check if self.features is not empty + if not self.features: + raise ValueError("Normalization features cannot be empty") # Robust JSON deserialization handling (guard empty maps). - if self.features: - first_val = next(iter(self.features.values())) - if isinstance(first_val, dict): - reconstructed = {} - for key, ft_dict in self.features.items(): - reconstructed[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed # if keys are strings (JSON), rebuild enum map if self.norm_map and all(isinstance(k, str) for k in self.norm_map): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 34fa89390..07e80d59f 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -234,7 +234,6 @@ def test_act_backbone_lr(): assert cfg.policy.optimizer_lr_backbone == 0.001 dataset = make_dataset(cfg) - preprocessor, _ = make_pre_post_processors(cfg.policy, None) policy = make_policy(cfg.policy, ds_meta=dataset.meta) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 98c9e0b23..80ac58dfc 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -534,6 +534,18 @@ def test_empty_observation(): assert result == transition +def test_empty_features_raises_error(): + """Test that empty features dict raises ValueError.""" + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} + + with pytest.raises(ValueError, match="Normalization features cannot be empty"): + NormalizerProcessorStep(features={}, norm_map=norm_map, stats=stats) + + with pytest.raises(ValueError, match="Normalization features cannot be empty"): + UnnormalizerProcessorStep(features={}, norm_map=norm_map, stats=stats) + + def test_empty_stats(): features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} From 2d3a605b3c9e70f833e3cca791152efce052a321 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 29 Sep 2025 16:55:52 +0200 Subject: [PATCH 142/158] Revert feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087) Revert "feat(normalization): add validation for empty features in NormalizerProcessorStep and UnnormalizerProcessorStep (#2087)" This reverts commit f173265354166547414ff922faa7b014de761481. --- src/lerobot/processor/normalize_processor.py | 20 +++++++++----------- tests/policies/test_policies.py | 1 + tests/processor/test_normalize_processor.py | 12 ------------ 3 files changed, 10 insertions(+), 23 deletions(-) diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 885911ff0..ce69a103f 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -108,18 +108,16 @@ class _NormalizationMixin: """ # Track if stats were explicitly provided (not None and not empty) self._stats_explicitly_provided = self.stats is not None and bool(self.stats) - # Check if self.features is not empty - if not self.features: - raise ValueError("Normalization features cannot be empty") # Robust JSON deserialization handling (guard empty maps). - first_val = next(iter(self.features.values())) - if isinstance(first_val, dict): - reconstructed = {} - for key, ft_dict in self.features.items(): - reconstructed[key] = PolicyFeature( - type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) - ) - self.features = reconstructed + if self.features: + first_val = next(iter(self.features.values())) + if isinstance(first_val, dict): + reconstructed = {} + for key, ft_dict in self.features.items(): + reconstructed[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed # if keys are strings (JSON), rebuild enum map if self.norm_map and all(isinstance(k, str) for k in self.norm_map): diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 07e80d59f..34fa89390 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -234,6 +234,7 @@ def test_act_backbone_lr(): assert cfg.policy.optimizer_lr_backbone == 0.001 dataset = make_dataset(cfg) + preprocessor, _ = make_pre_post_processors(cfg.policy, None) policy = make_policy(cfg.policy, ds_meta=dataset.meta) optimizer, _ = make_optimizer_and_scheduler(cfg, policy) assert len(optimizer.param_groups) == 2 diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 80ac58dfc..98c9e0b23 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -534,18 +534,6 @@ def test_empty_observation(): assert result == transition -def test_empty_features_raises_error(): - """Test that empty features dict raises ValueError.""" - norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} - stats = {OBS_IMAGE: {"mean": [0.5], "std": [0.2]}} - - with pytest.raises(ValueError, match="Normalization features cannot be empty"): - NormalizerProcessorStep(features={}, norm_map=norm_map, stats=stats) - - with pytest.raises(ValueError, match="Normalization features cannot be empty"): - UnnormalizerProcessorStep(features={}, norm_map=norm_map, stats=stats) - - def test_empty_stats(): features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} From 1ad2da403d5526c5d8933d805c726acc98ec561b Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 29 Sep 2025 17:02:19 +0200 Subject: [PATCH 143/158] feat(policies): add noise parameter to action prediction methods (#2063) * feat(policies): add noise parameter to action prediction methods - Introduced `ActionSelectKwargs` TypedDict for better type hinting. - Updated `predict_action_chunk` and `select_action` methods in `PreTrainedPolicy` and its subclasses to accept a `noise` parameter. - Modified `generate_actions` and `conditional_sample` methods in `DiffusionModel` to utilize the new noise parameter for action generation. * refactor(policies): make ActionSelectKwargs TypedDict fields optional - Updated `ActionSelectKwargs` to inherit with `total=False`, allowing for optional fields. --- .../policies/diffusion/modeling_diffusion.py | 32 ++++++++++++------- src/lerobot/policies/pi0/modeling_pi0.py | 2 +- src/lerobot/policies/pretrained.py | 11 +++++-- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index ad808d7c7..3ab6719cb 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -90,16 +90,16 @@ class DiffusionPolicy(PreTrainedPolicy): self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Predict a chunk of actions given environment observations.""" # stack n latest observations from the queue batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.diffusion.generate_actions(batch) + actions = self.diffusion.generate_actions(batch, noise=noise) return actions @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]) -> Tensor: + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. This method handles caching a history of observations and an action trajectory generated by the @@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy): self._queues = populate_queues(self._queues, batch) if len(self._queues[ACTION]) == 0: - actions = self.predict_action_chunk(batch) + actions = self.predict_action_chunk(batch, noise=noise) self._queues[ACTION].extend(actions.transpose(0, 1)) action = self._queues[ACTION].popleft() @@ -199,17 +199,25 @@ class DiffusionModel(nn.Module): # ========= inference ============ def conditional_sample( - self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + self, + batch_size: int, + global_cond: Tensor | None = None, + generator: torch.Generator | None = None, + noise: Tensor | None = None, ) -> Tensor: device = get_device_from_parameters(self) dtype = get_dtype_from_parameters(self) # Sample prior. - sample = torch.randn( - size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), - dtype=dtype, - device=device, - generator=generator, + sample = ( + noise + if noise is not None + else torch.randn( + size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), + dtype=dtype, + device=device, + generator=generator, + ) ) self.noise_scheduler.set_timesteps(self.num_inference_steps) @@ -264,7 +272,7 @@ class DiffusionModel(nn.Module): # Concatenate features then flatten to (B, global_cond_dim). return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) - def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """ This function expects `batch` to have: { @@ -282,7 +290,7 @@ class DiffusionModel(nn.Module): global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # run sampling - actions = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise) # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 4d3f4ffa1..8406f94fe 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -253,7 +253,7 @@ class PI0Policy(PreTrainedPolicy): return super().from_pretrained(*args, **kwargs) @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Predict a chunk of actions given environment observations.""" raise NotImplementedError("Currently not implemented for PI0") diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index b770c980b..3f5d89ec5 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -18,7 +18,7 @@ import os from importlib.resources import files from pathlib import Path from tempfile import TemporaryDirectory -from typing import TypeVar +from typing import TypedDict, TypeVar import packaging import safetensors @@ -27,6 +27,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn +from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig @@ -36,6 +37,10 @@ from lerobot.utils.hub import HubMixin T = TypeVar("T", bound="PreTrainedPolicy") +class ActionSelectKwargs(TypedDict, total=False): + noise: Tensor | None + + class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ Base class for policy models. @@ -181,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): raise NotImplementedError @abc.abstractmethod - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode. Child classes using action chunking should use this method within `select_action` to form the action chunk @@ -190,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): raise NotImplementedError @abc.abstractmethod - def select_action(self, batch: dict[str, Tensor]) -> Tensor: + def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). When the model uses a history of observations, or outputs a sequence of actions, this method deals From a0d7627d81c57f9ca826029e581deca313e0d548 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 29 Sep 2025 17:37:26 +0200 Subject: [PATCH 144/158] feat(train): include input and output features in processor overrides for normalization (#2088) (#2090) Signed-off-by: AdilZouitine --- src/lerobot/scripts/lerobot_train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 86b2bbae5..12a1f53c7 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -187,10 +187,16 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.pretrained_path is not None: processor_kwargs["preprocessor_overrides"] = { "device_processor": {"device": device.type}, - "normalizer_processor": {"stats": dataset.meta.stats}, + "normalizer_processor": { + "stats": dataset.meta.stats, + "features": {**policy.config.input_features, **policy.config.output_features}, + }, } processor_kwargs["postprocessor_overrides"] = { - "unnormalizer_processor": {"stats": dataset.meta.stats}, + "unnormalizer_processor": { + "stats": dataset.meta.stats, + "features": {**policy.config.input_features, **policy.config.output_features}, + }, } preprocessor, postprocessor = make_pre_post_processors( From 50977a2c280b9115f1f1d3003b850c15952876f0 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 1 Oct 2025 11:03:52 +0200 Subject: [PATCH 145/158] fix(video_path): setting video_path to None during conversion for images datasets (#2095) --- src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index ac9d41cf7..03d135d7c 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -413,7 +413,7 @@ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb): info["data_files_size_in_mb"] = data_file_size_in_mb info["video_files_size_in_mb"] = video_file_size_in_mb info["data_path"] = DEFAULT_DATA_PATH - info["video_path"] = DEFAULT_VIDEO_PATH + info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None info["fps"] = int(info["fps"]) logging.info(f"Converting info from {root} to {new_root}") for key in info["features"]: From 5dfdec92887c90f2cb9213a7dada046a5291c33e Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 1 Oct 2025 13:19:51 +0200 Subject: [PATCH 146/158] feat(mypy): enable type checking for envs module and configure mypy settings in pyproject.toml (#2099) * feat(mypy): enable type checking for envs module and configure mypy settings in pyproject.toml * Add mypy configuration to check only the envs module. * Exclude examples, benchmarks, and tests from type checking. * Set ignore_missing_imports to true and follow_imports to skip. * chore: comment out mypy configuration in pyproject.toml and pre-commit-config.yaml * Comment out mypy settings to disable type checking for the envs module. * Update pre-commit configuration to reflect changes in mypy settings. --- .pre-commit-config.yaml | 3 ++- pyproject.toml | 45 +++++++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f09017991..d15ecb8c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -90,7 +90,8 @@ repos: # rev: v1.16.0 # hooks: # - id: mypy - # args: [--python-version=3.10] + # args: [--config-file=pyproject.toml] + # exclude: ^(examples|benchmarks|tests)/ ##### Docstring Checks ##### # - repo: https://github.com/akaihola/darglint2 diff --git a/pyproject.toml b/pyproject.toml index 8bbd998ab..e96555607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -272,9 +272,19 @@ default.extend-ignore-identifiers-re = [ # [tool.mypy] # python_version = "3.10" +# Exclude examples, benchmarks, and tests from type checking +# exclude = [ +# "examples/", +# "benchmarks/", +# "tests/", +# ] +# Ignore missing imports for third-party libraries without stubs +# ignore_missing_imports = true +# Don't follow imports - only check files explicitly passed to mypy +# follow_imports = "skip" +# Gradual typing - start lenient, enable these as code improves: # warn_return_any = true # warn_unused_configs = true -# ignore_missing_imports = false # strict = true # disallow_untyped_defs = true # disallow_incomplete_defs = true @@ -282,68 +292,69 @@ default.extend-ignore-identifiers-re = [ # [[tool.mypy.overrides]] # module = "lerobot.utils.*" -# # include = "src/lerobot/utils/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.configs.*" -# # include = "src/lerobot/configs/**/*.py" +# follow_imports = "normal" # # Data processing modules # [[tool.mypy.overrides]] # module = "lerobot.processor.*" -# # include = "src/lerobot/processor/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.datasets.*" -# # include = "src/lerobot/datasets/**/*.py" +# follow_imports = "normal" # # Core machine learning modules # [[tool.mypy.overrides]] # module = "lerobot.optim.*" -# # include = "src/lerobot/optim/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.model.*" -# # include = "src/lerobot/model/**/*.py" +# follow_imports = "normal" # # Hardware interfaces # [[tool.mypy.overrides]] # module = "lerobot.cameras.*" -# # include = "src/lerobot/cameras/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.motors.*" -# # include = "src/lerobot/motors/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.robots.*" -# # include = "src/lerobot/robots/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.teleoperators.*" -# # include = "src/lerobot/teleoperators/**/*.py" +# follow_imports = "normal" # # Complex modules (enable these last) # [[tool.mypy.overrides]] # module = "lerobot.policies.*" -# # include = "src/lerobot/policies/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.rl.*" -# # include = "src/lerobot/rl/**/*.py" +# follow_imports = "normal" +# Currently checking only the envs module # [[tool.mypy.overrides]] # module = "lerobot.envs.*" -# # include = "src/lerobot/envs/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.async_inference.*" -# # include = "src/lerobot/async_inference/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.transport.*" -# # include = "src/lerobot/transport/**/*.py" +# follow_imports = "normal" # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" -# # include = "src/lerobot/scripts/**/*.py" +# follow_imports = "normal" From 6d331310ab43e67de96d05d301ba47eb5b8e0af0 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Wed, 1 Oct 2025 15:14:41 +0200 Subject: [PATCH 147/158] feat(mypy): configure mypy settings and add module overrides for gradual typing (#2101) --- pyproject.toml | 71 +++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e96555607..bc28511ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -272,17 +272,8 @@ default.extend-ignore-identifiers-re = [ # [tool.mypy] # python_version = "3.10" -# Exclude examples, benchmarks, and tests from type checking -# exclude = [ -# "examples/", -# "benchmarks/", -# "tests/", -# ] -# Ignore missing imports for third-party libraries without stubs # ignore_missing_imports = true -# Don't follow imports - only check files explicitly passed to mypy # follow_imports = "skip" -# Gradual typing - start lenient, enable these as code improves: # warn_return_any = true # warn_unused_configs = true # strict = true @@ -290,71 +281,73 @@ default.extend-ignore-identifiers-re = [ # disallow_incomplete_defs = true # check_untyped_defs = true +# [[tool.mypy.overrides]] +# module = "lerobot.*" +# ignore_errors = true + +# [[tool.mypy.overrides]] +# module = "lerobot.envs.*" +# # Enable type checking only for the envs module +# ignore_errors = false + + # [[tool.mypy.overrides]] # module = "lerobot.utils.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.configs.*" -# follow_imports = "normal" +# ignore_errors = false -# # Data processing modules -# [[tool.mypy.overrides]] -# module = "lerobot.processor.*" -# follow_imports = "normal" - -# [[tool.mypy.overrides]] -# module = "lerobot.datasets.*" -# follow_imports = "normal" - -# # Core machine learning modules +# PHASE 2: Core modules # [[tool.mypy.overrides]] # module = "lerobot.optim.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.model.*" -# follow_imports = "normal" +# ignore_errors = false + +# [[tool.mypy.overrides]] +# module = "lerobot.processor.*" +# ignore_errors = false + +# [[tool.mypy.overrides]] +# module = "lerobot.datasets.*" +# ignore_errors = false -# # Hardware interfaces # [[tool.mypy.overrides]] # module = "lerobot.cameras.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.motors.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.teleoperators.*" -# follow_imports = "normal" +# ignore_errors = false -# # Complex modules (enable these last) # [[tool.mypy.overrides]] # module = "lerobot.policies.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.rl.*" -# follow_imports = "normal" - -# Currently checking only the envs module -# [[tool.mypy.overrides]] -# module = "lerobot.envs.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.async_inference.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.transport.*" -# follow_imports = "normal" +# ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.scripts.*" -# follow_imports = "normal" +# ignore_errors = false From b6c528a4384d28cffe5fd70726515ca7f6cb7e3b Mon Sep 17 00:00:00 2001 From: Akhil Ivaturi <52360071+iakhil@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:11:48 -0500 Subject: [PATCH 148/158] Making Envs module pass MyPy checks (#2048) * Fix configs.py None MyPy error * Use img_tensor instead of img in utils.py * Add type assertion in factory.py * Resolve merge conflict * Uncomment envs moodule for mypy checks in pyproject.toml --------- Signed-off-by: Adil Zouitine Co-authored-by: Adil Zouitine --- .pre-commit-config.yaml | 12 ++++++------ pyproject.toml | 24 ++++++++++++------------ src/lerobot/envs/configs.py | 2 +- src/lerobot/envs/factory.py | 3 +++ src/lerobot/envs/utils.py | 20 ++++++++++---------- 5 files changed, 32 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d15ecb8c6..7f5beff80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,12 +86,12 @@ repos: # TODO(Steven): Uncomment when ready to use ##### Static Analysis & Typing ##### - # - repo: https://github.com/pre-commit/mirrors-mypy - # rev: v1.16.0 - # hooks: - # - id: mypy - # args: [--config-file=pyproject.toml] - # exclude: ^(examples|benchmarks|tests)/ + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.16.0 + hooks: + - id: mypy + args: [--config-file=pyproject.toml] + exclude: ^(examples|benchmarks|tests)/ ##### Docstring Checks ##### # - repo: https://github.com/akaihola/darglint2 diff --git a/pyproject.toml b/pyproject.toml index bc28511ae..2dd4dac39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -270,10 +270,10 @@ default.extend-ignore-identifiers-re = [ # TODO: Enable mypy gradually module by module across multiple PRs # Uncomment [tool.mypy] first, then uncomment individual module overrides as they get proper type annotations -# [tool.mypy] -# python_version = "3.10" -# ignore_missing_imports = true -# follow_imports = "skip" +[tool.mypy] +python_version = "3.10" +ignore_missing_imports = true +follow_imports = "skip" # warn_return_any = true # warn_unused_configs = true # strict = true @@ -281,14 +281,14 @@ default.extend-ignore-identifiers-re = [ # disallow_incomplete_defs = true # check_untyped_defs = true -# [[tool.mypy.overrides]] -# module = "lerobot.*" -# ignore_errors = true +[[tool.mypy.overrides]] +module = "lerobot.*" +ignore_errors = true -# [[tool.mypy.overrides]] -# module = "lerobot.envs.*" -# # Enable type checking only for the envs module -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.envs.*" +# Enable type checking only for the envs module +ignore_errors = false # [[tool.mypy.overrides]] @@ -299,7 +299,6 @@ default.extend-ignore-identifiers-re = [ # module = "lerobot.configs.*" # ignore_errors = false -# PHASE 2: Core modules # [[tool.mypy.overrides]] # module = "lerobot.optim.*" # ignore_errors = false @@ -340,6 +339,7 @@ default.extend-ignore-identifiers-re = [ # module = "lerobot.rl.*" # ignore_errors = false + # [[tool.mypy.overrides]] # module = "lerobot.async_inference.*" # ignore_errors = false diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 8c0c8b3ab..0daaaf9fd 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -254,7 +254,7 @@ class LiberoEnv(EnvConfig): render_mode: str = "rgb_array" camera_name: str = "agentview_image,robot0_eye_in_hand_image" init_states: bool = True - camera_name_mapping: dict[str, str] | None = (None,) + camera_name_mapping: dict[str, str] | None = None features: dict[str, PolicyFeature] = field( default_factory=lambda: { ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(7,)), diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 9b172854c..c27f01b65 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -63,6 +63,9 @@ def make_env( if "libero" in cfg.type: from lerobot.envs.libero import create_libero_envs + if cfg.task is None: + raise ValueError("LiberoEnv requires a task to be specified") + return create_libero_envs( task=cfg.task, n_envs=n_envs, diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index b5cfc7e26..5584e0bff 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -48,25 +48,25 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten for imgkey, img in imgs.items(): # TODO(aliberts, rcadene): use transforms.ToTensor()? - img = torch.from_numpy(img) + img_tensor = torch.from_numpy(img) # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension. # This is the case for human-in-the-loop RL where there is only one environment. - if img.ndim == 3: - img = img.unsqueeze(0) + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) # sanity check that images are channel last - _, h, w, c = img.shape - assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" + _, h, w, c = img_tensor.shape + assert c < h and c < w, f"expect channel last images, but instead got {img_tensor.shape=}" # sanity check that images are uint8 - assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" + assert img_tensor.dtype == torch.uint8, f"expect torch.uint8, but instead {img_tensor.dtype=}" # convert to channel first of type float32 in range [0,1] - img = einops.rearrange(img, "b h w c -> b c h w").contiguous() - img = img.type(torch.float32) - img /= 255 + img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() + img_tensor = img_tensor.type(torch.float32) + img_tensor /= 255 - return_observations[imgkey] = img + return_observations[imgkey] = img_tensor if "environment_state" in observations: env_state = torch.from_numpy(observations["environment_state"]).float() From abde7be3b3b78757982bc3d8f80b0a92ecd13965 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 2 Oct 2025 13:14:45 +0200 Subject: [PATCH 149/158] Add OpenPi, Pi0 and Pi0.5 (#1910) * initial commit * change device in test * do detailed import * adhere to python 3.11 syntax * fix autodocstring * additionally * do same in other files * add model. prefix to all keys in state dict * use dummy stats * add pi05 * also shorten action_steps * fix test * all test pass! and fix tokenizer max length between 05 and 0 * remove test * fix transformer dependency * fix test * split pi0 and pi05 policy in seperate files * fix test * fix push to hub test * add some comments, license and readme * remove warning in config * add pi05 to factory * remove check * rename action_horizon to chunk_size * clean up padding of state and action (more in line with lerobot pi0) * add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0 * fix key match from pytorch state dict (similar keys to openpi implementation now) * also for pi05 * update to python 3.11 * revert to openpi transformer replace python 3.11 * fix(modeling pi0): nit warning message * use safeauto_docstring * fix: remove unused param * fix from pretrained * add preprocess tests * also compile forward method * Do not add model prefix to normalization * use same name for action and state dim as lerobot pi0 and remove fixed image keys * load from pretrained_path * temp: hardcode base model * fix override self.pretrained_path = None overwrite * rename to loss * remove additional image augmentations, lerobot dataset already does this * Add docs * put tests in test folder * Add test to instatiate all base models * go back to python 3.10 * update docs * adapt docs pi05 * change docs: finetune base model options * minor docs fixes and dependencies * remove todo * cast float64 to float32 for mps * skip if no transformers * fix tests * add new models to modelcard * add back init * fix circular input * feat: only run pi test on GPU * remove require_nightly_gpu * replace decorator test_pi0_openpi * rename action_dim, state_dim to max_action_dim, max_state_dim * fix doc and constants * cleanup tests * fix from pretrained * fix tests * add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests * fix, state is included in language not in flow head * Move test to specific folder * and paligemma task with newline * remove add_special_tokens, not needed * feedback pr * Remove previous pi0 and rename pi0_openpi and pi05_openpi * Add Quantile stats to LeRobotDataset (#1985) * - Add RunningQuantileStats class for efficient histogram-based quantile computation - Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset - Support quantile computation during episode collection and aggregation - Add comprehensive function-based test suite (24 tests) for quantile functionality - Maintain full backward compatibility with existing stats computation - Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization * style fixes, make quantiles computation by default to new datasets * fix tests * - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each features instead of being chosen by the user - Fortified tests. * - add helper functions to reshape stats - add missing test for quantiles * - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles. - Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles. * style fixes * Added missing lisence * Simplify compute_stats * - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that dont have quatniles - modified quantile computation instead of using the edge for the value, interpolate the values in the bin * rename pi0/pi05 files * Remove open pi patch and use custom transformer branch for now * renaming * fix * Revert "fix" This reverts commit 1ea65730ac2cbca6e5869df734fbd4392561b3c6. * fix naming * feet(pi0/pi0.5): add pipeline (#2009) * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * feat(processor): convert openpi model with processor * TODO: Make test works * fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests - Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`. - Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`. - Enhanced task handling in tests to ensure proper formatting and batch size consistency. - Cleaned up commented-out test code for clarity. * refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy - Updated imports and references throughout the codebase to reflect the new naming convention. - Introduced a new processor file for PI0 to handle pre-processing and post-processing steps. - Adjusted tests to utilize the renamed classes, ensuring consistency and functionality. - Enhanced clarity and maintainability by removing outdated naming conventions. * refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration - Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions. - Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`. - Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter. - Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability. - Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility. * refactor(pi05): update imports and rename configuration classes - Changed imports to reflect the new naming convention for PI05 configuration and policy classes. - Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency. - Introduced a new processor file for PI05, implementing pre-processing and post-processing steps. - Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase. * update(pi05): increase tokenizer_max_length for improved processing - Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences. - This adjustment aims to improve the overall performance and flexibility of the PI05 configuration. * add default for state (max_state_dim) * correct naming * fix import * cleanup code * remove unused test * us quantiles for action * move to device * remove discrete state assert * fix pi05 test * move pi05 to device * use base models in comparison tests * small renames for tests * change number of tokens pi05 test * fix openpi tokenization in test * fix hub test * fix test * assert lerobot vs openpi tests --------- Co-authored-by: Pepijn * add headers * add back previously removed imports * update if statement load processor with dataset stats * remove to avoid circular import * inject dataset stats for pretrained models * check normalization before applying * add link to quantile augument script * fix(policies): transformers import for ci in PI0 & PI05 (#2039) * fix(policies): transformers import for ci in PI0 * fix(policies): transformers import for ci in PI05 * test(processor): fix expected raise when normalization types are missing (#2040) * switch normalization order pipeline for pi05 * Fix/quantiles script (#2064) * refactor augment stats with quantiles script add parallelization for faster processing shift the quantile normalization between -1 1 * fix replay buffer tests * fix comment * overwrite the pipeline normalization features with the policy features * remove double normalization overwrite * cleanup from pretrained * remove typo * also set norm_map * fix(augment_quantiles) images incorrectly divided by 255 * clamp quantiles * link to lerobot base models * rename tests * encorperate PR feedback * update docstring for RunningQuantileStats * update doc links * Revert "clamp quantiles" This reverts commit 172207471c8f2cb62958e9a9e6a0535ba3ff67d4. * fix self.paligemma * fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1] * fix libero doc and use different transformer branch * use fix branch instead of feat * update results libero * add new line * fix formatting * precommit * update results libero * update libero doc * update title * final changes * add quantiles to test * run pre commit --------- Signed-off-by: Steven Palma Co-authored-by: Michel Aractingi Co-authored-by: Adil Zouitine Co-authored-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/_toctree.yml | 7 +- docs/source/libero.mdx | 39 + docs/source/pi0.mdx | 79 + docs/source/pi05.mdx | 98 ++ docs/source/smolvla.mdx | 2 +- pyproject.toml | 6 +- src/lerobot/async_inference/constants.py | 2 +- src/lerobot/async_inference/helpers.py | 9 +- src/lerobot/configs/policies.py | 4 +- src/lerobot/configs/types.py | 2 + src/lerobot/datasets/compute_stats.py | 512 ++++++- .../v30/augment_dataset_quantile_stats.py | 225 +++ src/lerobot/policies/__init__.py | 3 +- src/lerobot/policies/factory.py | 33 +- src/lerobot/policies/pi0/README.md | 49 + src/lerobot/policies/pi0/__init__.py | 21 + src/lerobot/policies/pi0/configuration_pi0.py | 133 +- .../pi0/conversion_scripts/benchmark.py | 82 - .../conversion_scripts/compare_with_jax.py | 132 -- .../conversion_scripts/conversion_utils.py | 84 -- .../convert_pi0_to_hf_lerobot.py | 437 ------ src/lerobot/policies/pi0/flex_attention.py | 141 -- src/lerobot/policies/pi0/modeling_pi0.py | 1331 +++++++++++------ .../policies/pi0/paligemma_with_expert.py | 420 ------ src/lerobot/policies/pi05/README.md | 49 + src/lerobot/policies/pi05/__init__.py | 21 + .../policies/pi05/configuration_pi05.py | 153 ++ src/lerobot/policies/pi05/modeling_pi05.py | 1163 ++++++++++++++ src/lerobot/policies/pi05/processor_pi05.py | 171 +++ .../policies/pi0fast/configuration_pi0fast.py | 16 + .../processor/migrate_policy_normalization.py | 156 +- src/lerobot/processor/normalize_processor.py | 69 +- src/lerobot/scripts/lerobot_train.py | 14 +- .../templates/lerobot_modelcard_template.md | 22 +- src/lerobot/utils/constants.py | 3 + tests/datasets/test_compute_stats.py | 524 +++++++ .../test_quantiles_dataset_integration.py | 212 +++ tests/policies/pi0_pi05/test_pi0.py | 117 ++ tests/policies/pi0_pi05/test_pi05.py | 154 ++ .../pi0_pi05/test_pi05_original_vs_lerobot.py | 419 ++++++ .../pi0_pi05/test_pi0_original_vs_lerobot.py | 410 +++++ tests/processor/test_normalize_processor.py | 226 ++- tests/processor/test_pi0_processor.py | 424 ------ 43 files changed, 5886 insertions(+), 2288 deletions(-) create mode 100644 docs/source/pi0.mdx create mode 100644 docs/source/pi05.mdx create mode 100644 src/lerobot/datasets/v30/augment_dataset_quantile_stats.py create mode 100644 src/lerobot/policies/pi0/README.md create mode 100644 src/lerobot/policies/pi0/__init__.py delete mode 100644 src/lerobot/policies/pi0/conversion_scripts/benchmark.py delete mode 100644 src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py delete mode 100644 src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py delete mode 100644 src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py delete mode 100644 src/lerobot/policies/pi0/flex_attention.py delete mode 100644 src/lerobot/policies/pi0/paligemma_with_expert.py create mode 100644 src/lerobot/policies/pi05/README.md create mode 100644 src/lerobot/policies/pi05/__init__.py create mode 100644 src/lerobot/policies/pi05/configuration_pi05.py create mode 100644 src/lerobot/policies/pi05/modeling_pi05.py create mode 100644 src/lerobot/policies/pi05/processor_pi05.py create mode 100644 tests/datasets/test_quantiles_dataset_integration.py create mode 100644 tests/policies/pi0_pi05/test_pi0.py create mode 100644 tests/policies/pi0_pi05/test_pi05.py create mode 100644 tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py create mode 100644 tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py delete mode 100644 tests/processor/test_pi0_processor.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7f4c07944..36eaea165 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -28,11 +28,14 @@ title: "Datasets" - sections: - local: smolvla - title: Finetune SmolVLA + title: SmolVLA + - local: pi0 + title: π₀ (Pi0) + - local: pi05 + title: π₀.₅ (Pi05) - local: libero title: Using Libero title: "Policies" - - sections: - local: introduction_processors title: Introduction to Robot Processors diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index eafe3e78b..3f2b92406 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -125,3 +125,42 @@ lerobot-train \ LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation: - `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud) + +## Reproducing π₀.₅ results + +We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune for an additional 6k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero). + +The finetuned model can be found here: + +- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned) + +We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command: + +```bash +python src/lerobot/scripts/eval.py \ + --output_dir=/logs/ \ + --env.type=libero \ + --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ + --eval.batch_size=1 \ + --eval.n_episodes=10 \ + --policy.path=pi05_libero_finetuned \ + --policy.n_action_steps=10 \ + --output_dir=./eval_logs/ \ + --env.max_parallel_tasks=1 +``` + +**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation. + +### Results + +We obtain the following results on the LIBERO benchmark: + +| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average | +| -------- | -------------- | ------------- | ----------- | --------- | -------- | +| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** | + +These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence: + +| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average | +| -------- | -------------- | ------------- | ----------- | --------- | --------- | +| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** | diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx new file mode 100644 index 000000000..10260ee72 --- /dev/null +++ b/docs/source/pi0.mdx @@ -0,0 +1,79 @@ +# π₀ (Pi0) + +π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository. + +## Model Overview + +π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks. + +### The Vision for Physical Intelligence + +As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models. + +### Architecture and Approach + +π₀ combines several key innovations: + +- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models) +- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom +- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model +- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install Pi0 dependencies by running: + + ```bash + pip install -e ".[pi]" + ``` + +## Training Data and Capabilities + +π₀ is trained on the largest robot interaction dataset to date, combining three key data sources: + +1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding +2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets +3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots + +## Usage + +To use π₀ in LeRobot, specify the policy type as: + +```python +policy.type=pi0 +``` + +## Training + +For training π₀, you can use the standard LeRobot training script with the appropriate configuration: + +```bash +python src/lerobot/scripts/train.py \ + --dataset.repo_id=your_dataset \ + --policy.type=pi0 \ + --output_dir=./outputs/pi0_training \ + --job_name=pi0_training \ + --policy.pretrained_path=lerobot/pi0_base \ + --policy.repo_id=your_repo_id \ + --policy.compile_model=true \ + --policy.gradient_checkpointing=true \ + --policy.dtype=bfloat16 \ + --steps=3000 \ + --policy.device=cuda \ + --batch_size=32 +``` + +### Key Training Parameters + +- **`--policy.compile_model=true`**: Enables model compilation for faster training +- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training +- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency +- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory +- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are: + - [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base) + - [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset) + +## License + +This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx new file mode 100644 index 000000000..b777fcd58 --- /dev/null +++ b/docs/source/pi05.mdx @@ -0,0 +1,98 @@ +# π₀.₅ (Pi05) Policy + +π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository. + +## Model Overview + +π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training. + +### The Generalization Challenge + +As Physical Intelligence explains, the fundamental challenge isn't performing tasks of agility or dexterity, but generalization, the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels: + +- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments +- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills +- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals + +### Co-Training on Heterogeneous Data + +The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from: + +1. **Multimodal Web Data**: Image captioning, visual question answering, object detection +2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step +3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed) +4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities +5. **Multi-Environment Data**: Static robots deployed across many different homes +6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations + +This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install Pi0.5 dependencies by running: + + ```bash + pip install -e ".[pi]" + ``` + +## Usage + +To use π₀.₅ in your LeRobot configuration, specify the policy type as: + +```python +policy.type=pi05 +``` + +## Training + +### Training Command Example + +Here's a complete training command for finetuning the base π₀.₅ model on your own dataset: + +```bash +python src/lerobot/scripts/train.py \ + --dataset.repo_id=your_dataset \ + --policy.type=pi05 \ + --output_dir=./outputs/pi0_training \ + --job_name=pi0_training \ + --policy.repo_id=lerobot/pi05_base \ + --policy.pretrained_path=your_repo_id \ + --policy.compile_model=true \ + --policy.gradient_checkpointing=true \ + --wandb.enable=true \ + --policy.dtype=bfloat16 \ + --steps=3000 \ + --policy.device=cuda \ + --batch_size=32 +``` + +### Key Training Parameters + +- **`--policy.compile_model=true`**: Enables model compilation for faster training +- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training +- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency +- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory +- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are: + - [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base) + - [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset) + +## Performance Results + +### Libero Benchmark Results + +π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results. + +| Benchmark | LeRobot Implementation | OpenPI Reference | +| ------------------ | ---------------------- | ---------------- | +| **Libero Spatial** | 97.0% | 98.8% | +| **Libero Object** | 99.0% | 98.2% | +| **Libero Goal** | 98.0% | 98.0% | +| **Libero 10** | 96.0% | 92.4% | +| **Average** | 97.5% | 96.85% | + +These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section. + +## License + +This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index a28e7cb44..a56298b5e 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -1,4 +1,4 @@ -# Finetune SmolVLA +# SmolVLA SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development! diff --git a/pyproject.toml b/pyproject.toml index 2dd4dac39..f350fac0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ dependencies = [ # Common pygame-dep = ["pygame>=2.5.1"] placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.52.0"] +transformers-dep = ["transformers>=4.53.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors @@ -119,7 +119,7 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # ] # TODO: Currently not supported # Policies -pi0 = ["lerobot[transformers-dep]"] +pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] @@ -147,7 +147,7 @@ all = [ "lerobot[reachy2]", "lerobot[kinematics]", "lerobot[intelrealsense]", - "lerobot[pi0]", + "lerobot[pi]", "lerobot[smolvla]", "lerobot[hilserl]", "lerobot[async]", diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index af983a800..5ebf3780c 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS DEFAULT_OBS_QUEUE_TIMEOUT = 2 # All action chunking policies -SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"] +SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] # TODO: Add all other robots SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"] diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index a336d3a63..88fb00a3f 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -25,7 +25,14 @@ from lerobot.configs.types import PolicyFeature from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config -from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401 +from lerobot.policies import ( # noqa: F401 + ACTConfig, + DiffusionConfig, + PI0Config, + PI05Config, + SmolVLAConfig, + VQBeTConfig, +) from lerobot.robots.robot import Robot from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import init_logging diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 06c220cb8..98dd4df3f 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -71,9 +71,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): tags: list[str] | None = None # Add tags to your policy on the hub. license: str | None = None + # Either the repo ID of a model hosted on the Hub or a path to a directory containing weights + # saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch. + pretrained_path: str | None = None def __post_init__(self): - self.pretrained_path = None if not self.device or not is_torch_device_available(self.device): auto_device = auto_select_torch_device() logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.") diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py index 754aca1ab..cb578060e 100644 --- a/src/lerobot/configs/types.py +++ b/src/lerobot/configs/types.py @@ -35,6 +35,8 @@ class NormalizationMode(str, Enum): MIN_MAX = "MIN_MAX" MEAN_STD = "MEAN_STD" IDENTITY = "IDENTITY" + QUANTILES = "QUANTILES" + QUANTILE10 = "QUANTILE10" @dataclass diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index bfe7b18b4..61e174d5c 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -17,6 +17,179 @@ import numpy as np from lerobot.datasets.utils import load_image_as_numpy +DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99] + + +class RunningQuantileStats: + """ + Maintains running statistics for batches of vectors, including mean, + standard deviation, min, max, and approximate quantiles. + + Statistics are computed per feature dimension and updated incrementally + as new batches are observed. Quantiles are estimated using histograms, + which adapt dynamically if the observed data range expands. + """ + + def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000): + self._count = 0 + self._mean = None + self._mean_of_squares = None + self._min = None + self._max = None + self._histograms = None + self._bin_edges = None + self._num_quantile_bins = num_quantile_bins + + self._quantile_list = quantile_list + if self._quantile_list is None: + self._quantile_list = DEFAULT_QUANTILES + self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list] + + def update(self, batch: np.ndarray) -> None: + """Update the running statistics with a batch of vectors. + + Args: + batch: An array where all dimensions except the last are batch dimensions. + """ + batch = batch.reshape(-1, batch.shape[-1]) + num_elements, vector_length = batch.shape + + if self._count == 0: + self._mean = np.mean(batch, axis=0) + self._mean_of_squares = np.mean(batch**2, axis=0) + self._min = np.min(batch, axis=0) + self._max = np.max(batch, axis=0) + self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)] + self._bin_edges = [ + np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1) + for i in range(vector_length) + ] + else: + if vector_length != self._mean.size: + raise ValueError("The length of new vectors does not match the initialized vector length.") + + new_max = np.max(batch, axis=0) + new_min = np.min(batch, axis=0) + max_changed = np.any(new_max > self._max) + min_changed = np.any(new_min < self._min) + self._max = np.maximum(self._max, new_max) + self._min = np.minimum(self._min, new_min) + + if max_changed or min_changed: + self._adjust_histograms() + + self._count += num_elements + + batch_mean = np.mean(batch, axis=0) + batch_mean_of_squares = np.mean(batch**2, axis=0) + + # Update running mean and mean of squares + self._mean += (batch_mean - self._mean) * (num_elements / self._count) + self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * ( + num_elements / self._count + ) + + self._update_histograms(batch) + + def get_statistics(self) -> dict[str, np.ndarray]: + """Compute and return the statistics of the vectors processed so far. + + Args: + quantiles: List of quantiles to compute (e.g., [0.01, 0.10, 0.50, 0.90, 0.99]). If None, no quantiles computed. + + Returns: + Dictionary containing the computed statistics. + """ + if self._count < 2: + raise ValueError("Cannot compute statistics for less than 2 vectors.") + + variance = self._mean_of_squares - self._mean**2 + + stddev = np.sqrt(np.maximum(0, variance)) + + stats = { + "min": self._min.copy(), + "max": self._max.copy(), + "mean": self._mean.copy(), + "std": stddev, + "count": np.array([self._count]), + } + + quantile_results = self._compute_quantiles() + for i, q in enumerate(self._quantile_keys): + stats[q] = quantile_results[i] + + return stats + + def _adjust_histograms(self): + """Adjust histograms when min or max changes.""" + for i in range(len(self._histograms)): + old_edges = self._bin_edges[i] + old_hist = self._histograms[i] + + # Create new edges with small padding to ensure range coverage + padding = (self._max[i] - self._min[i]) * 1e-10 + new_edges = np.linspace( + self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1 + ) + + # Redistribute existing histogram counts to new bins + # We need to map each old bin center to the new bins + old_centers = (old_edges[:-1] + old_edges[1:]) / 2 + new_hist = np.zeros(self._num_quantile_bins) + + for old_center, count in zip(old_centers, old_hist, strict=False): + if count > 0: + # Find which new bin this old center belongs to + bin_idx = np.searchsorted(new_edges, old_center) - 1 + bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1)) + new_hist[bin_idx] += count + + self._histograms[i] = new_hist + self._bin_edges[i] = new_edges + + def _update_histograms(self, batch: np.ndarray) -> None: + """Update histograms with new vectors.""" + for i in range(batch.shape[1]): + hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i]) + self._histograms[i] += hist + + def _compute_quantiles(self) -> list[np.ndarray]: + """Compute quantiles based on histograms.""" + results = [] + for q in self._quantile_list: + target_count = q * self._count + q_values = [] + + for hist, edges in zip(self._histograms, self._bin_edges, strict=True): + q_value = self._compute_single_quantile(hist, edges, target_count) + q_values.append(q_value) + + results.append(np.array(q_values)) + return results + + def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float: + """Compute a single quantile value from histogram and bin edges.""" + cumsum = np.cumsum(hist) + idx = np.searchsorted(cumsum, target_count) + + if idx == 0: + return edges[0] + if idx >= len(cumsum): + return edges[-1] + + # If not edge case, interpolate within the bin + count_before = cumsum[idx - 1] + count_in_bin = cumsum[idx] - count_before + + # If no samples in this bin, use the bin edge + if count_in_bin == 0: + return edges[idx] + + # Linear interpolation within the bin + fraction = (target_count - count_before) / count_in_bin + return edges[idx] + fraction * (edges[idx + 1] - edges[idx]) + def estimate_num_samples( dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 @@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray: return images -def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: - return { - "min": np.min(array, axis=axis, keepdims=keepdims), - "max": np.max(array, axis=axis, keepdims=keepdims), - "mean": np.mean(array, axis=axis, keepdims=keepdims), - "std": np.std(array, axis=axis, keepdims=keepdims), - "count": np.array([len(array)]), +def _reshape_stats_by_axis( + stats: dict[str, np.ndarray], + axis: int | tuple[int, ...] | None, + keepdims: bool, + original_shape: tuple[int, ...], +) -> dict[str, np.ndarray]: + """Reshape all statistics to match NumPy's output conventions. + + Applies consistent reshaping to all statistics (except 'count') based on the + axis and keepdims parameters. This ensures statistics have the correct shape + for broadcasting with the original data. + + Args: + stats: Dictionary of computed statistics + axis: Axis or axes along which statistics were computed + keepdims: Whether to keep reduced dimensions as size-1 dimensions + original_shape: Shape of the original array + + Returns: + Dictionary with reshaped statistics + + Note: + The 'count' statistic is never reshaped as it represents metadata + rather than per-feature statistics. + """ + if axis == (1,) and not keepdims: + return stats + + result = {} + for key, value in stats.items(): + if key == "count": + result[key] = value + else: + result[key] = _reshape_single_stat(value, axis, keepdims, original_shape) + + return result + + +def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: + """Reshape statistics for image data (axis=(0,2,3)).""" + if keepdims and value.ndim == 1: + return value.reshape(1, -1, 1, 1) + return value + + +def _reshape_for_vector_stats( + value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] +) -> np.ndarray: + """Reshape statistics for vector data (axis=0 or axis=(0,)).""" + if not keepdims: + return value + + if len(original_shape) == 1 and value.ndim > 0: + return value.reshape(1) + elif len(original_shape) >= 2 and value.ndim == 1: + return value.reshape(1, -1) + return value + + +def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray: + """Reshape statistics for feature-wise computation (axis=(1,)).""" + if not keepdims: + return value + + if value.ndim == 0: + return value.reshape(1, 1) + elif value.ndim == 1: + return value.reshape(-1, 1) + return value + + +def _reshape_for_global_stats( + value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...] +) -> np.ndarray | float: + """Reshape statistics for global reduction (axis=None).""" + if keepdims: + target_shape = tuple(1 for _ in original_shape) + return value.reshape(target_shape) + # Keep at least 1-D arrays to satisfy validator + return np.atleast_1d(value) + + +def _reshape_single_stat( + value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...] +) -> np.ndarray | float: + """Apply appropriate reshaping to a single statistic array. + + This function transforms statistic arrays to match expected output shapes + based on the axis configuration and keepdims parameter. + + Args: + value: The statistic array to reshape + axis: Axis or axes that were reduced during computation + keepdims: Whether to maintain reduced dimensions as size-1 dimensions + original_shape: Shape of the original data before reduction + + Returns: + Reshaped array following NumPy broadcasting conventions + + """ + if axis == (0, 2, 3): + return _reshape_for_image_stats(value, keepdims) + + if axis in [0, (0,)]: + return _reshape_for_vector_stats(value, keepdims, original_shape) + + if axis == (1,): + return _reshape_for_feature_stats(value, keepdims) + + if axis is None: + return _reshape_for_global_stats(value, keepdims, original_shape) + + return value + + +def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]: + """Prepare array for statistics computation by reshaping according to axis. + + Args: + array: Input data array + axis: Axis or axes along which to compute statistics + + Returns: + Tuple of (reshaped_array, sample_count) + """ + if axis == (0, 2, 3): # Image data + batch_size, channels, height, width = array.shape + reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels) + return reshaped, batch_size + + if axis == 0 or axis == (0,): # Vector data + reshaped = array + if array.ndim == 1: + reshaped = array.reshape(-1, 1) + return reshaped, array.shape[0] + + if axis == (1,): # Feature-wise statistics + return array.T, array.shape[1] + + if axis is None: # Global statistics + reshaped = array.reshape(-1, 1) + # For backward compatibility, count represents the first dimension size + return reshaped, array.shape[0] if array.ndim > 0 else 1 + + raise ValueError(f"Unsupported axis configuration: {axis}") + + +def _compute_basic_stats( + array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None +) -> dict[str, np.ndarray]: + """Compute basic statistics for arrays with insufficient samples for quantiles. + + Args: + array: Reshaped array ready for statistics computation + sample_count: Number of samples represented in the data + + Returns: + Dictionary with basic statistics and quantiles set to mean values + """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES + quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list] + + stats = { + "min": np.min(array, axis=0), + "max": np.max(array, axis=0), + "mean": np.mean(array, axis=0), + "std": np.std(array, axis=0), + "count": np.array([sample_count]), } + for q in quantile_list_keys: + stats[q] = stats["mean"].copy() + + return stats + + +def get_feature_stats( + array: np.ndarray, + axis: int | tuple[int, ...] | None, + keepdims: bool, + quantile_list: list[float] | None = None, +) -> dict[str, np.ndarray]: + """Compute comprehensive statistics for array features along specified axes. + + This function calculates min, max, mean, std, and quantiles (1%, 10%, 50%, 90%, 99%) + for the input array along the specified axes. It handles different data layouts: + - Image data: axis=(0,2,3) computes per-channel statistics + - Vector data: axis=0 computes per-feature statistics + - Feature-wise: axis=1 computes statistics across features + - Global: axis=None computes statistics over entire array + + Args: + array: Input data array with shape appropriate for the specified axis + axis: Axis or axes along which to compute statistics + - (0, 2, 3): For image data (batch, channels, height, width) + - 0 or (0,): For vector/tabular data (samples, features) + - (1,): For computing across features + - None: For global statistics over entire array + keepdims: If True, reduced axes are kept as dimensions with size 1 + + Returns: + Dictionary containing: + - 'min': Minimum values + - 'max': Maximum values + - 'mean': Mean values + - 'std': Standard deviation + - 'count': Number of samples (always shape (1,)) + - 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values + + """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES + + original_shape = array.shape + reshaped, sample_count = _prepare_array_for_stats(array, axis) + + if reshaped.shape[0] < 2: + stats = _compute_basic_stats(reshaped, sample_count, quantile_list) + else: + running_stats = RunningQuantileStats() + running_stats.update(reshaped) + stats = running_stats.get_statistics() + stats["count"] = np.array([sample_count]) + + stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape) + return stats + + +def compute_episode_stats( + episode_data: dict[str, list[str] | np.ndarray], + features: dict, + quantile_list: list[float] | None = None, +) -> dict: + """Compute comprehensive statistics for all features in an episode. + + Processes different data types appropriately: + - Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1] + - Numerical arrays: Computes per-feature statistics + - Strings: Skipped (no statistics computed) + + Args: + episode_data: Dictionary mapping feature names to data + - For images/videos: list of file paths + - For numerical data: numpy arrays + features: Dictionary describing each feature's dtype and shape + + Returns: + Dictionary mapping feature names to their statistics dictionaries. + Each statistics dictionary contains min, max, mean, std, count, and quantiles. + + Note: + Image statistics are normalized to [0,1] range and have shape (3,1,1) for + per-channel values when dtype is 'image' or 'video'. + """ + if quantile_list is None: + quantile_list = DEFAULT_QUANTILES -def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": - continue # HACK: we should receive np.arrays of strings - elif features[key]["dtype"] in ["image", "video"]: - ep_ft_array = sample_images(data) # data is a list of image paths - axes_to_reduce = (0, 2, 3) # keep channel dim + continue + + if features[key]["dtype"] in ["image", "video"]: + ep_ft_array = sample_images(data) + axes_to_reduce = (0, 2, 3) keepdims = True else: - ep_ft_array = data # data is already a np.ndarray - axes_to_reduce = 0 # compute stats over the first axis - keepdims = data.ndim == 1 # keep as np.array + ep_ft_array = data + axes_to_reduce = 0 + keepdims = data.ndim == 1 - ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + ep_stats[key] = get_feature_stats( + ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list + ) - # finally, we normalize and remove batch dim for images if features[key]["dtype"] in ["image", "video"]: ep_stats[key] = { k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() @@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu return ep_stats +def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None: + """Validate a single statistic value.""" + if not isinstance(value, np.ndarray): + raise ValueError( + f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' " + f"is of type '{type(value)}' instead." + ) + + if value.ndim == 0: + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + + if key == "count" and value.shape != (1,): + raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.") + + if "image" in feature_key and key != "count" and value.shape != (3, 1, 1): + raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.") + + def _assert_type_and_shape(stats_list: list[dict[str, dict]]): - for i in range(len(stats_list)): - for fkey in stats_list[i]: - for k, v in stats_list[i][fkey].items(): - if not isinstance(v, np.ndarray): - raise ValueError( - f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." - ) - if v.ndim == 0: - raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") - if k == "count" and v.shape != (1,): - raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") - if "image" in fkey and k != "count" and v.shape != (3, 1, 1): - raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") + """Validate that all statistics have correct types and shapes. + + Args: + stats_list: List of statistics dictionaries to validate + + Raises: + ValueError: If any statistic has incorrect type or shape + """ + for stats in stats_list: + for feature_key, feature_stats in stats.items(): + for stat_key, stat_value in feature_stats.items(): + _validate_stat_value(stat_value, stat_key, feature_key) def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: @@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d weighted_variances = (variances + delta_means**2) * counts total_variance = weighted_variances.sum(axis=0) / total_count - return { + aggregated = { "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), "mean": total_mean, @@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d "count": total_count, } + if stats_ft_list: + quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()] + + for q_key in quantile_keys: + if all(q_key in s for s in stats_ft_list): + quantile_values = np.stack([s[q_key] for s in stats_ft_list]) + weighted_quantiles = quantile_values * counts + aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count + + return aggregated + def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: """Aggregate stats from multiple compute_stats outputs into a single set of stats. diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py new file mode 100644 index 000000000..ff4689efa --- /dev/null +++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py @@ -0,0 +1,225 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script augments existing LeRobot datasets with quantile statistics. + +Most datasets created before the quantile feature was added do not contain +quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script: + +1. Loads an existing LeRobot dataset in v3.0 format +2. Checks if it already contains quantile statistics +3. If missing, computes quantile statistics for all features +4. Updates the dataset metadata with the new quantile statistics + +Usage: + +```bash +python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ + --repo-id=lerobot/pusht \ +``` +""" + +import argparse +import concurrent.futures +import logging +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import write_stats +from lerobot.utils.utils import init_logging + + +def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool: + """Check if dataset statistics already contain quantile information. + + Args: + stats: Dataset statistics dictionary + + Returns: + True if quantile statistics are present, False otherwise + """ + if quantile_list_keys is None: + quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES] + + if stats is None: + return False + + for feature_stats in stats.values(): + if any(q_key in feature_stats for q_key in quantile_list_keys): + return True + + return False + + +def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict: + """Process a single episode and return its statistics. + + Args: + dataset: The LeRobot dataset + episode_idx: Index of the episode to process + + Returns: + Dictionary containing episode statistics + """ + logging.info(f"Computing stats for episode {episode_idx}") + + start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"] + end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"] + + ep_stats = {} + for key, data in dataset.hf_dataset[start_idx:end_idx].items(): + if dataset.features[key]["dtype"] == "string": + continue + + data = torch.stack(data).cpu().numpy() + if dataset.features[key]["dtype"] in ["image", "video"]: + axes_to_reduce = (0, 2, 3) + keepdims = True + else: + axes_to_reduce = 0 + keepdims = data.ndim == 1 + + ep_stats[key] = get_feature_stats( + data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES + ) + + if dataset.features[key]["dtype"] in ["image", "video"]: + for k, v in ep_stats[key].items(): + if dataset.features[key]["dtype"] == "video": + v = v / 255.0 + if k != "count": + v = np.squeeze(v, axis=0) + ep_stats[key][k] = v + + return ep_stats + + +def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]: + """Compute quantile statistics for all episodes in the dataset. + + Args: + dataset: The LeRobot dataset to compute statistics for + + Returns: + Dictionary containing aggregated statistics with quantiles + """ + logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes") + + episode_stats_list = [] + max_workers = min(dataset.num_episodes, 16) + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_episode = { + executor.submit(process_single_episode, dataset, episode_idx): episode_idx + for episode_idx in range(dataset.num_episodes) + } + + episode_results = {} + with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar: + for future in concurrent.futures.as_completed(future_to_episode): + episode_idx = future_to_episode[future] + ep_stats = future.result() + episode_results[episode_idx] = ep_stats + pbar.update(1) + + for episode_idx in range(dataset.num_episodes): + if episode_idx in episode_results: + episode_stats_list.append(episode_results[episode_idx]) + + if not episode_stats_list: + raise ValueError("No episode data found for computing statistics") + + logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes") + return aggregate_stats(episode_stats_list) + + +def augment_dataset_with_quantile_stats( + repo_id: str, + root: str | Path | None = None, + overwrite: bool = False, +) -> None: + """Augment a dataset with quantile statistics if they are missing. + + Args: + repo_id: Repository ID of the dataset + root: Local root directory for the dataset + overwrite: Overwrite existing quantile statistics if they already exist + """ + logging.info(f"Loading dataset: {repo_id}") + dataset = LeRobotDataset( + repo_id=repo_id, + root=root, + ) + + if not overwrite and has_quantile_stats(dataset.meta.stats): + logging.info("Dataset already contains quantile statistics. No action needed.") + return + + logging.info("Dataset does not contain quantile statistics. Computing them now...") + + new_stats = compute_quantile_stats_for_dataset(dataset) + + logging.info("Updating dataset metadata with new quantile statistics") + dataset.meta.stats = new_stats + + write_stats(new_stats, dataset.meta.root) + + logging.info("Successfully updated dataset with quantile statistics") + dataset.push_to_hub() + + +def main(): + """Main function to run the augmentation script.""" + parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics") + + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository ID of the dataset (e.g., 'lerobot/pusht')", + ) + + parser.add_argument( + "--root", + type=str, + help="Local root directory for the dataset", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing quantile statistics if they already exist", + ) + + args = parser.parse_args() + root = Path(args.root) if args.root else None + + init_logging() + + augment_dataset_with_quantile_stats( + repo_id=args.repo_id, + root=root, + overwrite=args.overwrite, + ) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 9b9de9931..49f1e0f95 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -15,7 +15,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config -from .pi0.processor_pi0 import Pi0NewLineProcessor +from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig @@ -25,6 +25,7 @@ __all__ = [ "ACTConfig", "DiffusionConfig", "PI0Config", + "PI05Config", "SmolVLAConfig", "TDMPCConfig", "VQBeTConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 60c05240e..ac76baf9f 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -32,6 +32,7 @@ from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig @@ -81,14 +82,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy - elif name == "pi0": - from lerobot.policies.pi0.modeling_pi0 import PI0Policy - - return PI0Policy elif name == "pi0fast": from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "pi0": + from lerobot.policies.pi0.modeling_pi0 import PI0Policy + + return PI0Policy + elif name == "pi05": + from lerobot.policies.pi05.modeling_pi05 import PI05Policy + + return PI05Policy elif name == "sac": from lerobot.policies.sac.modeling_sac import SACPolicy @@ -132,10 +137,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return ACTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) - elif policy_type == "pi0": - return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "pi0": + return PI0Config(**kwargs) + elif policy_type == "pi05": + return PI05Config(**kwargs) elif policy_type == "sac": return SACConfig(**kwargs) elif policy_type == "smolvla": @@ -253,6 +260,14 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, PI0FASTConfig): + from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors + + processors = make_pi0fast_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + elif isinstance(policy_cfg, PI0Config): from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors @@ -261,10 +276,10 @@ def make_pre_post_processors( dataset_stats=kwargs.get("dataset_stats"), ) - elif isinstance(policy_cfg, PI0FASTConfig): - from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors + elif isinstance(policy_cfg, PI05Config): + from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors - processors = make_pi0fast_pre_post_processors( + processors = make_pi05_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) diff --git a/src/lerobot/policies/pi0/README.md b/src/lerobot/policies/pi0/README.md new file mode 100644 index 000000000..65b331e51 --- /dev/null +++ b/src/lerobot/policies/pi0/README.md @@ -0,0 +1,49 @@ +# π₀ (pi0) + +This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence. +It is designed as a **Vision-Language-Action model for general robot control**. + +--- + +## Model Overview + +| Feature | π₀ | π₀.₅ | +| -------------------- | ------------------------------------------------------ | ----------------------------------------- | +| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning | +| AdaRMS | Not used | Used in action expert | +| Tokenizer Length | 48 tokens | 200 tokens | +| Discrete State Input | False (Uses `state_proj` layer) | True | +| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) | + +--- + +## Citation + +If you use this work, please cite both **OpenPI** and the π₀ paper: + +```bibtex +@misc{openpi2024, + author = {Physical Intelligence Lab}, + title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies}, + year = {2024}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/Physical-Intelligence/openpi}}, + license = {Apache-2.0} +} + +@misc{black2024pi0visionlanguageactionflowmodel, + title = {π₀: A Vision-Language-Action Flow Model for General Robot Control}, + author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky}, + year = {2024}, + eprint = {2410.24164}, + archivePrefix= {arXiv}, + primaryClass = {cs.LG}, + url = {https://arxiv.org/abs/2410.24164}, +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/src/lerobot/policies/pi0/__init__.py b/src/lerobot/policies/pi0/__init__.py new file mode 100644 index 000000000..ea3095b4e --- /dev/null +++ b/src/lerobot/policies/pi0/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_pi0 import PI0Config +from .modeling_pi0 import PI0Policy +from .processor_pi0 import make_pi0_pre_post_processors + +__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"] diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index bd5bbf7ee..cc1cda9d8 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,20 +19,40 @@ from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig -from lerobot.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.utils.constants import OBS_IMAGES @PreTrainedConfig.register_subclass("pi0") @dataclass class PI0Config(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 50 - n_action_steps: int = 50 + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + dtype: str = "float32" # Options: "bfloat16", "float32" + n_obs_steps: int = 1 + chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" + n_action_steps: int = 50 # Number of action steps to execute + + # Shorter state and action vectors will be padded to these dimensions + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Flow matching parameters: see openpi `PI0Pytorch` + num_inference_steps: int = 10 # Number of denoising steps during inference + time_sampling_beta_alpha: float = 1.5 + time_sampling_beta_beta: float = 1.0 + time_sampling_scale: float = 0.999 + time_sampling_offset: float = 0.001 + min_period: float = 4e-3 + max_period: float = 4.0 + + image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` + + # Add empty images. Used to add empty cameras when no image features are present. + empty_cameras: int = 0 + + # Normalization normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, @@ -39,94 +61,75 @@ class PI0Config(PreTrainedConfig): } ) - # Shorter state and action vectors will be padded - max_state_dim: int = 32 - max_action_dim: int = 32 + # Training settings + gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization + compile_model: bool = False # Whether to use torch.compile for model optimization + compile_mode: str = "max-autotune" # Torch compile mode + device: str | None = None # Device to use for the model (None = auto-detect) - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (224, 224) - - # Add empty images. Used by pi0_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - - # Projector - proj_width: int = 1024 - - # Decoding - num_steps: int = 10 - - # Attention utils - use_cache: bool = True - attention_implementation: str = "eager" # or fa2, flex - - # Finetuning settings - freeze_vision_encoder: bool = True - train_expert_only: bool = False - train_state_proj: bool = True - - # Training presets - optimizer_lr: float = 2.5e-5 + # Optimizer settings: see openpi `AdamW`` + optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_betas: tuple[float, float] = (0.9, 0.95) optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-10 + optimizer_weight_decay: float = 0.01 + optimizer_grad_clip_norm: float = 1.0 + # Scheduler settings: see openpi `CosineDecaySchedule` scheduler_warmup_steps: int = 1_000 scheduler_decay_steps: int = 30_000 scheduler_decay_lr: float = 2.5e-6 - # TODO: Add EMA + tokenizer_max_length: int = 48 # see openpi `__post_init__` def __post_init__(self): super().__post_init__() - # TODO(Steven): Validate device and amp? in all policy configs? - """Input validation (not exhaustive).""" + # Validate configuration if self.n_action_steps > self.chunk_size: raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - if self.n_obs_steps != 1: - raise ValueError( - f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})" ) - if self.use_delta_joint_actions_aloha: - raise NotImplementedError( - "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot." - ) + if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}") + + if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}") + + if self.dtype not in ["bfloat16", "float32"]: + raise ValueError(f"Invalid dtype: {self.dtype}") def validate_features(self) -> None: - # TODO: implement value error - # if not self.image_features and not self.env_state_feature: - # raise ValueError("You must provide at least one image or the environment state among the inputs.") - + """Validate and set up input/output features.""" for i in range(self.empty_cameras): key = f"{OBS_IMAGES}.empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, - shape=(3, 480, 640), + shape=(3, *self.image_resolution), # Use configured image resolution ) self.input_features[key] = empty_camera + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( lr=self.optimizer_lr, betas=self.optimizer_betas, eps=self.optimizer_eps, weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, ) def get_scheduler_preset(self): diff --git a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py b/src/lerobot/policies/pi0/conversion_scripts/benchmark.py deleted file mode 100644 index c1a488244..000000000 --- a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.policies.factory import make_policy - -torch.backends.cudnn.benchmark = True - - -def main(): - device = "cuda" - dataset_repo_id = "danaaubakirova/koch_test" - # model_name = "pi0_base" - # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" - ckpt_torch_dir = "lerobot/pi0" - - dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) - - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=0, - batch_size=1, - ) - - batch = next(iter(dataloader)) - - # To device - for k in batch: - if isinstance(batch[k], torch.Tensor): - batch[k] = batch[k].to(device=device, dtype=torch.float32) - - cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) - cfg.pretrained_path = ckpt_torch_dir - policy = make_policy(cfg, ds_meta=dataset.meta) - - # policy = torch.compile(policy, mode="reduce-overhead") - - warmup_iters = 10 - benchmark_iters = 30 - - # Warmup - for _ in range(warmup_iters): - torch.cuda.synchronize() - policy.select_action(batch) - policy.reset() - torch.cuda.synchronize() - - # Benchmark - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() - for _ in range(benchmark_iters): - policy.select_action(batch) - policy.reset() - end_event.record() - - # Synchronize and measure time - torch.cuda.synchronize() - elapsed_time_ms = start_event.elapsed_time(end_event) - - avg_time_per_iter = elapsed_time_ms / benchmark_iters - print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms") - - -if __name__ == "__main__": - with torch.inference_mode(): - main() diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py deleted file mode 100644 index dad7d002e..000000000 --- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import pickle -from pathlib import Path - -import torch - -from lerobot.configs.policies import PreTrainedConfig -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.policies.factory import make_policy -from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE - - -def display(tensor: torch.Tensor): - if tensor.dtype == torch.bool: - tensor = tensor.float() - print(f"Shape: {tensor.shape}") - print(f"Mean: {tensor.mean().item()}") - print(f"Std: {tensor.std().item()}") - print(f"Min: {tensor.min().item()}") - print(f"Max: {tensor.max().item()}") - - -def main(): - num_motors = 14 - device = "cuda" - # model_name = "pi0_aloha_towel" - model_name = "pi0_aloha_sim" - - if model_name == "pi0_aloha_towel": - dataset_repo_id = "lerobot/aloha_static_towel" - else: - dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" - - ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" - ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" - save_dir = Path(f"../openpi/data/{model_name}/save") - - with open(save_dir / "example.pkl", "rb") as f: - example = pickle.load(f) - with open(save_dir / "outputs.pkl", "rb") as f: - outputs = pickle.load(f) - with open(save_dir / "noise.pkl", "rb") as f: - noise = pickle.load(f) - - with open(ckpt_jax_dir / "assets/norm_stats.json") as f: - norm_stats = json.load(f) - - # Override stats - dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) - dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor( - norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 - ) - dataset_meta.stats[OBS_STATE]["std"] = torch.tensor( - norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 - ) - - # Create LeRobot batch from Jax - batch = {} - for cam_key, uint_chw_array in example["images"].items(): - batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 - batch[OBS_STATE] = torch.from_numpy(example["state"]) - batch[ACTION] = torch.from_numpy(outputs["actions"]) - batch["task"] = example["prompt"] - - if model_name == "pi0_aloha_towel": - del batch[f"{OBS_IMAGES}.cam_low"] - elif model_name == "pi0_aloha_sim": - batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"] - del batch[f"{OBS_IMAGES}.cam_high"] - - # Batchify - for key in batch: - if isinstance(batch[key], torch.Tensor): - batch[key] = batch[key].unsqueeze(0) - elif isinstance(batch[key], str): - batch[key] = [batch[key]] - else: - raise ValueError(f"{key}, {batch[key]}") - - # To device - for k in batch: - if isinstance(batch[k], torch.Tensor): - batch[k] = batch[k].to(device=device, dtype=torch.float32) - - noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32) - - from lerobot import policies # noqa - - cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) - cfg.pretrained_path = ckpt_torch_dir - policy = make_policy(cfg, dataset_meta) - - # loss_dict = policy.forward(batch, noise=noise, time=time_beta) - # loss_dict["loss"].backward() - # print("losses") - # display(loss_dict["losses_after_forward"]) - # print("pi_losses") - # display(pi_losses) - - actions = [] - for _ in range(50): - action = policy.select_action(batch, noise=noise) - actions.append(action) - - actions = torch.stack(actions, dim=1) - pi_actions = batch[ACTION] - print("actions") - display(actions) - print() - print("pi_actions") - display(pi_actions) - print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2)) - print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2)) - print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2)) - - -if __name__ == "__main__": - main() diff --git a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py b/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py deleted file mode 100644 index 8835da31e..000000000 --- a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from transformers import GemmaConfig, PaliGemmaConfig - - -def get_paligemma_config(precision: str): - config = { - "image_token_index": None, - "pad_token_id": 0, - "bos_token_id": 2, - "eos_token_id": 1, - } - - # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} - - image_size = 224 # image_sizes[variant] - patch_size = 14 - num_image_tokens = (image_size**2) // (patch_size**2) - - config["image_token_index"] = 257152 - text_config = { - "vocab_size": 257152, - "num_hidden_layers": 18, - "num_key_value_heads": 1, - "head_dim": 256, - "torch_dtype": precision, - "hidden_size": 2048, - "hidden_activation": "gelu_pytorch_tanh", - "num_attention_heads": 8, - "intermediate_size": 16384, - "is_encoder_decoder": False, - } - vision_config = { - "torch_dtype": precision, - "image_size": image_size, - "patch_size": patch_size, - "num_image_tokens": num_image_tokens, - "hidden_size": 1152, - "intermediate_size": 4304, - "num_hidden_layers": 27, - "num_attention_heads": 16, - "projector_hidden_act": "gelu_fast", - "vision_use_head": False, - } - final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) - return final_config - - -def get_gemma_config(precision: str): - config = { - "image_token_index": None, - "pad_token_id": 0, - "bos_token_id": 2, - "eos_token_id": 1, - } - - config["image_token_index"] = 257152 - text_config = { - "vocab_size": 257152, - "num_hidden_layers": 18, - "num_key_value_heads": 1, - "head_dim": 256, - "torch_dtype": precision, - "hidden_size": 1024, - "hidden_activation": "gelu_pytorch_tanh", - "num_attention_heads": 8, - "intermediate_size": 4096, - "is_encoder_decoder": False, - } - final_config = GemmaConfig() - final_config.update(text_config) - return final_config diff --git a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py deleted file mode 100644 index 742c9ab3f..000000000 --- a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py +++ /dev/null @@ -1,437 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Convert pi0 parameters from Jax to Pytorch - -Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment -and install the required libraries. - -```bash -cd ~/code/openpi -source .venv/bin/activate -``` - -Example downloading parameters: -```bash -python ->>> import openpi.shared.download as download ->>> path='s3://openpi-assets/checkpoints/pi0_base/params' ->>> download.maybe_download(path) -``` - -Converting pi0_base: -```python -python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \ - --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \ - --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch -``` - -```python -python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \ - --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \ - --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch -``` -""" - -import argparse -import pathlib - -import jax -import numpy as np -import orbax.checkpoint as ocp -import torch -from jax.sharding import SingleDeviceSharding - -from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0.conversion_scripts.conversion_utils import ( - get_gemma_config, - get_paligemma_config, -) -from lerobot.policies.pi0.modeling_pi0 import PI0Policy - -PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} - - -def slice_paligemma_state_dict(state_dict, config): - suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" - - # fmt: off - # patch embeddings - state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( - 3, 2, 0, 1 - ) - state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") - # positional embeddings - state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape( - -1, config.vision_config.hidden_size - ) - - # extract vision layers to be sliced at index 0. There are 27 layers in the base model. - encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") - encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") - encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") - encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") - - encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") - encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") - encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") - encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") - - encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") - encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") - encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") - encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") - encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") - encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") - encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") - encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}") - - for i in range(config.vision_config.num_hidden_layers): - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] - - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - - state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() - state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") - - # multimodal projector - - state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() - state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") - - # text decoder (gemma) - embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") - state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector - - # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. - - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") - # TODO verify correctness of layer norm loading - - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") - - for i in range(config.text_config.num_hidden_layers): - # llm_attention_q_einsum[i].shape = (8, 2048, 256) - q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) - - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped - - # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped - # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped - - # output projection. - - # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) - o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) - - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped - # mlp layers - gate_proj_weight = llm_mlp_gating_einsum[i, 0] - state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() - up_proj_weight = llm_mlp_gating_einsum[i, 1] - state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] - state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] - - state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") - state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied. - - # fmt: on - expert_dict = {} - final_state_dict = {} - for key, value in state_dict.items(): - if key not in [ - f"llm/final_norm_1/scale{suffix}", - f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", - f"llm/layers/attn/kv_einsum_1/w{suffix}", - f"llm/layers/attn/q_einsum_1/w{suffix}", - f"llm/layers/mlp_1/gating_einsum{suffix}", - f"llm/layers/mlp_1/linear{suffix}", - f"llm/layers/pre_attention_norm_1/scale{suffix}", - f"llm/layers/pre_ffw_norm_1/scale{suffix}", - ]: - final_state_dict[key] = torch.from_numpy(value) - else: - expert_dict[key] = value - - return final_state_dict, expert_dict - - -def slice_gemma_state_dict(state_dict, config, num_expert=1): - # fmt: off - # text decoder (gemma) - # no embedding vector, the expert just has the decoder layers - - embedding_vector = torch.zeros([config.vocab_size, config.hidden_size]) - state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector - - # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. - - suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" - - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") - # TODO verify correctness of layer norm loading - - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") - - for i in range(config.num_hidden_layers): - q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) - - state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped - - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped - - # output projection. - - # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) - o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) - - state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped - # mlp layers - gate_proj_weight = llm_mlp_gating_einsum[i, 0] - state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() - up_proj_weight = llm_mlp_gating_einsum[i, 1] - state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() - state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() - state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] - state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] - - state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") - state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here) - - # fmt: on - final_state_dict = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - final_state_dict[key] = torch.from_numpy(value) - else: - final_state_dict[key] = value - return final_state_dict - - -def flatten_for_memory(tree, parent_key=""): - out = {} - for k, v in tree.items(): - new_key = f"{parent_key}/{k}" if parent_key else k - if isinstance(v, dict): - out.update(flatten_for_memory(v, new_key)) - else: - out[new_key] = np.array(v) # Ensure conversion to np.array for consistency - return out - - -def flatten_for_npz(tree, parent_key=""): - out = {} - for k, v in tree.items(): - new_key = f"{parent_key}/{k}" if parent_key else k - if isinstance(v, dict): - out.update(flatten_for_npz(v, new_key)) - else: - # bf16/f32 here? - out[new_key] = np.array(v) - return out - - -def slice_initial_orbax_checkpoint(checkpoint_dir: str): - params_path = pathlib.Path(checkpoint_dir).resolve() - checkpointer = ocp.PyTreeCheckpointer() - - metadata = checkpointer.metadata(params_path) - print("Metadata keys:", list(metadata.keys())) - - params_name = "params" - - item = {params_name: metadata[params_name]} - device = jax.local_devices()[0] # Use the first local device - sharding = SingleDeviceSharding(device) - restored = checkpointer.restore( - params_path, - ocp.args.PyTreeRestore( - item=item, - restore_args=jax.tree_util.tree_map( - lambda _: ocp.ArrayRestoreArgs( - restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it - sharding=sharding, - ), - item, - ), - transforms={}, - ), - ) - params = restored[params_name] - - # get params for PaliGemma - pali_params = params["PaliGemma"] - del params["PaliGemma"] - pali_params_flat = flatten_for_npz(pali_params) - return {"paligemma_params": pali_params_flat, "projection_params": params} - - -def update_keys_with_prefix(d: dict, prefix: str) -> dict: - """Update dictionary keys by adding a prefix.""" - return {f"{prefix}{key}": value for key, value in d.items()} - - -def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): - # Break down orbax ckpts - they are in OCDBT - initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) - # process projection params - keys = [ - "state_proj", - "action_in_proj", - "action_out_proj", - "action_time_mlp_in", - "action_time_mlp_out", - ] - - projection_params = {} - for key in keys: - kernel_params = initial_params["projection_params"][key]["kernel"] - bias_params = initial_params["projection_params"][key]["bias"] - if isinstance(kernel_params, dict): - weight = kernel_params["value"] - bias = bias_params["value"] - else: - weight = kernel_params - bias = bias_params - projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T - projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias)) - - # Process PaliGemma weights - paligemma_config = get_paligemma_config(precision) - paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict( - initial_params["paligemma_params"], paligemma_config - ) - - # Process Gemma weights (at this stage they are unused) - gemma_config = get_gemma_config(precision) - gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config) - - # Instantiate model from configs - - if "pi0_aloha_sim" in checkpoint_dir: - pi0_config = PI0Config( - empty_cameras=2, - adapt_to_pi_aloha=True, - use_delta_joint_actions_aloha=False, - ) - elif "pi0_aloha_towel" in checkpoint_dir: - pi0_config = PI0Config( - adapt_to_pi_aloha=True, - use_delta_joint_actions_aloha=True, - ) - elif "pi0_base" in checkpoint_dir: - pi0_config = PI0Config( - empty_cameras=0, - adapt_to_pi_aloha=False, - use_delta_joint_actions_aloha=False, - ) - else: - raise ValueError() - - # gemma_config=gemma_config, paligemma_config=paligemma_config) - pi0_model = PI0Policy(pi0_config) - - paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.") - gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") - projection_params = update_keys_with_prefix(projection_params, "model.") - - # load state dict - torch_dtype = PRECISIONS[precision] - pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params}) - pi0_model = pi0_model.to(torch_dtype) - # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) - - pi0_model.save_pretrained(output_path, safe_serialization=True) - # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype) - - # assert that model loads properly - del pi0_model - PI0Policy.from_pretrained(output_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--checkpoint_dir", - default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params", - type=str, - help="Path to the ocdbt checkpoint", - ) - - parser.add_argument( - "--precision", - choices=["float32", "bfloat16", "float16"], - default="float32", - type=str, - help="Precision identifier for model conversion - should match the base checkpoint precision.", - ) - # tokenizer is identical to paligemma, it appears - - parser.add_argument( - "--tokenizer_hub_id", - default="google/paligemma-3b-pt-224", - type=str, - help="Hub path to the tokenizer to save", - ) - - parser.add_argument( - "--output_path", - required=True, - type=str, - help="Path to save converted weights to", - ) - - args = parser.parse_args() - convert_pi0_checkpoint( - checkpoint_dir=args.checkpoint_dir, - precision=args.precision, - tokenizer_id=args.tokenizer_hub_id, - output_path=args.output_path, - ) diff --git a/src/lerobot/policies/pi0/flex_attention.py b/src/lerobot/policies/pi0/flex_attention.py deleted file mode 100644 index 35628cddb..000000000 --- a/src/lerobot/policies/pi0/flex_attention.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn.functional as F # noqa: N812 -from packaging.version import Version - -if Version(torch.__version__) > Version("2.5.0"): - # Ffex attention is only available from torch 2.5 onwards - from torch.nn.attention.flex_attention import ( - _mask_mod_signature, - _round_up_to_multiple, - create_block_mask, - create_mask, - flex_attention, - ) - - -# @torch.compile(dynamic=False) -def flex_attention_forward( - attention_mask: torch.Tensor, - batch_size: int, - head_dim: int, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - scaling=None, -): - """ - This is defined out of classes to make compile happy. - """ - - original_dtype = query_states.dtype - num_att_heads = 8 - num_key_value_heads = 1 - num_key_value_groups = num_att_heads // num_key_value_heads - - key_states = key_states[:, :, :, None, :] - key_states = key_states.expand( - batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :] - value_states = value_states.expand( - batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - query_states = query_states.to(torch.float32) - key_states = key_states.to(torch.float32) - value_states = value_states.to(torch.float32) - - causal_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, None, :, : key_states.shape[2]] - - if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: - causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1) - - def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: - def mask_mod(b, h, q_idx, kv_idx): - # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs. - return precomputed_mask[b][h][q_idx][kv_idx] - - return mask_mod - - b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask - - block_size = 128 - q_len_rounded = _round_up_to_multiple(q_len, block_size) - kv_len_rounded = _round_up_to_multiple(kv_len, block_size) - - # *CRITICAL* we do need to expand here, else we get a CUDA index error - - pad_q = q_len_rounded - q_len - pad_k = kv_len_rounded - kv_len - - padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) - mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) - - mask_4d = create_mask( - mod_fn=mask_mod_fn_orig, - B=b_mask, - H=h_mask, - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - device=causal_mask.device, - _compile=False, - ) - - mask_mod_fn_padded = precomputed_mask_factory(mask_4d) - block_mask = create_block_mask( - mask_mod=mask_mod_fn_padded, - B=b_mask, - H=h_mask, - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - BLOCK_SIZE=block_size, - device=causal_mask.device, - _compile=False, - ) - - # mask is applied inside the kernel, ideally more efficiently than score_mod. - attn_output, attention_weights = flex_attention( - query_states, - key_states, - value_states, - block_mask=block_mask, - enable_gqa=True, # because we shaped query/key states for GQA - scale=head_dim**-0.5 if scaling is None else scaling, - return_lse=True, - ) - - attn_output = attn_output.to(dtype=original_dtype) - attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim] - attn_output = attn_output.reshape( - batch_size, - -1, - attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] - ) - return attn_output diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 8406f94fe..a2dcdaea3 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -14,61 +14,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -π0: A Vision-Language-Action Flow Model for General Robot Control - -[Paper](https://www.physicalintelligence.company/download/pi0.pdf) -[Jax code](https://github.com/Physical-Intelligence/openpi) - -Designed by Physical Intelligence. Ported from Jax by Hugging Face. -Disclaimer: It is not expected to perform as well as the original implementation. - -Install pi0 extra dependencies: -```bash -pip install -e ".[pi0]" -``` - -Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): -```bash -lerobot-train \ ---policy.path=lerobot/pi0 \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of finetuning the pi0 neural network with PaliGemma and expert Gemma -pretrained with VLM default parameters before pi0 finetuning: -```bash -lerobot-train \ ---policy.type=pi0 \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of using the pi0 pretrained model outside LeRobot training framework: -```python -policy = Pi0Policy.from_pretrained("lerobot/pi0") -``` - -""" - +import builtins +import logging import math from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None + +from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0.paligemma_with_expert import ( - PaliGemmaWithExpertConfig, - PaliGemmaWithExpertModel, +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, + OPENPI_ATTENTION_MASK_VALUE, ) -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE -from lerobot.utils.utils import get_safe_dtype -def create_sinusoidal_pos_embedding( - time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" ) -> Tensor: """Computes sine-cosine positional embedding vectors for scalar positions.""" if dimension % 2 != 0: @@ -84,11 +81,17 @@ def create_sinusoidal_pos_embedding( # Compute the outer product scaling_factor = 1.0 / period * 2 * math.pi sin_input = scaling_factor[None, :] * time[:, None] - pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) - return pos_emb + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) -def make_att_2d_masks(pad_masks, att_masks): +def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) """Copied from big_vision. Tokens can attend to valid inputs tokens which have a cumulative mask_ar @@ -117,413 +120,514 @@ def make_att_2d_masks(pad_masks, att_masks): cumsum = torch.cumsum(att_masks, dim=1) att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] - att_2d_masks = att_2d_masks & pad_2d_masks - return att_2d_masks - - -def resize_with_pad(img, width, height, pad_value=-1): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img + return att_2d_masks & pad_2d_masks def pad_vector(vector, new_dim): - """Can be (batch_size x sequence_length x features_dimension) + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) or (batch_size x features_dimension) """ - if vector.shape[-1] == new_dim: + if vector.shape[-1] >= new_dim: return vector - shape = list(vector.shape) - current_dim = shape[-1] - shape[-1] = new_dim - new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) - new_vector[..., :current_dim] = vector - return new_vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val +# Define the complete layer computation function for gradient checkpointing +def compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) +class GemmaConfig: # see openpi `gemma.py: Config` + """Configuration for Gemma model variants.""" + + def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) +def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` + """Returns config for specified gemma variant.""" + if variant == "gemma_300m": + return GemmaConfig( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + elif variant == "gemma_2b": + return GemmaConfig( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + else: + raise ValueError(f"Unknown variant: {variant}") -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class PI0Policy(PreTrainedPolicy): - """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.""" - - config_class = PI0Config - name = "pi0" +class PaliGemmaWithExpertModel( + nn.Module +): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi + """PaliGemma model with action expert for PI0.""" def __init__( self, - config: PI0Config, + vlm_config, + action_expert_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - """ + if use_adarms is None: + use_adarms = [False, False] + super().__init__() - super().__init__(config) - config.validate_features() - self.config = config + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" - self.model = PI0FlowMatching(config) - - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - self._action_queue = deque([], maxlen=self.config.n_action_steps) - - def get_optim_params(self) -> dict: - return self.parameters() - - @classmethod - def from_pretrained(cls, *args, **kwargs): - """Override the from_pretrained method to display important disclaimer.""" - print( - "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n" - " It is not expected to perform as well as the original implementation. \n" - " Original implementation: https://github.com/Physical-Intelligence/openpi" + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, ) - return super().from_pretrained(*args, **kwargs) - @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Predict a chunk of actions given environment observations.""" - raise NotImplementedError("Currently not implemented for PI0") + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None - @torch.no_grad() - def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Select a single action given environment observations. + self.to_bfloat16_for_selected_params(precision) - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. - if len(self._action_queue) == 0: - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] - lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) - actions = self.model.sample_actions( - images, img_masks, lang_tokens, lang_masks, state, noise=noise + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, ) - - # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() - - def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: - """Do a full training forward pass to compute the loss""" - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] - lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - actions = self.prepare_action(batch) - actions_is_pad = batch.get("action_is_pad") - - loss_dict = {} - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() - - if actions_is_pad is not None: - in_episode_bound = ~actions_is_pad - losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() - - # Remove padding - losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() - - # For backward pass - loss = losses.mean() - # For logging - loss_dict["l2_loss"] = loss.item() - - return loss, loss_dict - - def prepare_images(self, batch): - """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and - convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. - """ - images = [] - img_masks = [] - - present_img_keys = [key for key in self.config.image_features if key in batch] - missing_img_keys = [key for key in self.config.image_features if key not in batch] - - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers - # Preprocess image features present in the batch - for key in present_img_keys: - img = batch[key] + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) - # Normalize from range [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - images.append(img) - img_masks.append(mask) + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) - # Create image features not present in the batch - # as fully 0 padded images. - for num_empty_cameras in range(len(missing_img_keys)): - if num_empty_cameras >= self.config.empty_cameras: - break - img = torch.ones_like(img) * -1 - mask = torch.zeros_like(mask) - images.append(img) - img_masks.append(mask) + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None - return images, img_masks - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - def prepare_state(self, batch): - """Pad state""" - state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) - return state - - def prepare_action(self, batch): - """Pad action""" - actions = pad_vector(batch[ACTION], self.config.max_action_dim) - return actions + return [prefix_output, suffix_output], prefix_past_key_values -class PI0FlowMatching(nn.Module): - """ - π0: A Vision-Language-Action Flow Model for General Robot Control - - [Paper](https://www.physicalintelligence.company/download/pi0.pdf) - [Jax code](https://github.com/Physical-Intelligence/openpi) - - Designed by Physical Intelligence. Ported from Jax by Hugging Face. - ┌──────────────────────────────┐ - │ actions │ - │ ▲ │ - │ ┌┴─────┐ │ - │ kv cache │Gemma │ │ - │ ┌──────────►│Expert│ │ - │ │ │ │ │ - │ ┌┴────────┐ │x 10 │ │ - │ │ │ └▲──▲──┘ │ - │ │PaliGemma│ │ │ │ - │ │ │ │ robot state │ - │ │ │ noise │ - │ └▲──▲─────┘ │ - │ │ │ │ - │ │ image(s) │ - │ language tokens │ - └──────────────────────────────┘ - """ +class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI0 PyTorch model.""" def __init__(self, config: PI0Config): super().__init__() self.config = config - paligemma_with_export_config = PaliGemmaWithExpertConfig( - freeze_vision_encoder=self.config.freeze_vision_encoder, - train_expert_only=self.config.train_expert_only, - attention_implementation=self.config.attention_implementation, + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, False], + precision=config.dtype, ) - self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) - # Projections are float32 - self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) - self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) - self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) + self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) - self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width) - self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) + self.state_proj = nn.Linear(config.max_state_dim, action_expert_config.width) + self.action_time_mlp_in = nn.Linear(2 * action_expert_config.width, action_expert_config.width) + self.action_time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) - self.set_requires_grad() + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False - def set_requires_grad(self): - for params in self.state_proj.parameters(): - params.requires_grad = self.config.train_state_proj + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + # Also compile the main forward pass used during training + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) def sample_noise(self, shape, device): - noise = torch.normal( + return torch.normal( mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device, ) - return noise def sample_time(self, bsize, device): - beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) - time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) - time = time_beta * 0.999 + 0.001 - return time + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset + return time.to(dtype=torch.float32, device=device) def embed_prefix( self, images, img_masks, lang_tokens, lang_masks ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Embed images with SigLIP and language tokens with embedding layer to prepare - for PaliGemma transformer processing. - """ - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + """Embed images with SigLIP and language tokens with embedding layer.""" embs = [] pad_masks = [] att_masks = [] - # TODO: remove for loop - for ( - img, - img_mask, - ) in zip(images, img_masks, strict=False): - img_emb = self.paligemma_with_expert.embed_image(img) - img_emb = img_emb.to(dtype=torch.bfloat16) + # Process images + for img, img_mask in zip(images, img_masks, strict=True): - # Normalize image embeddings - img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + img_emb = self._apply_checkpoint(image_embed_func, img) bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) embs.append(img_emb) - pad_masks.append(img_mask) - - # Create attention masks so that image tokens attend to each other + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) att_masks += [0] * num_img_embs - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - - # Normalize language embeddings - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) + # Process language tokens + def lang_embed_func(lang_tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) embs.append(lang_emb) pad_masks.append(lang_masks) - # full attention between image and language inputs num_lang_embs = lang_emb.shape[1] att_masks += [0] * num_lang_embs embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + bsize = pad_masks.shape[0] att_masks = att_masks[None, :].expand(bsize, len(att_masks)) return embs, pad_masks, att_masks @@ -534,57 +638,67 @@ class PI0FlowMatching(nn.Module): pad_masks = [] att_masks = [] - # Embed state - state_emb = self.state_proj(state) - state_emb = state_emb.to(dtype=torch.bfloat16) + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + def state_proj_func(state): + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) embs.append(state_emb[:, None, :]) bsize = state_emb.shape[0] - dtype = state_emb.dtype device = state_emb.device state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions att_masks += [1] - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + # Embed timestep using sine-cosine positional encoding time_emb = create_sinusoidal_pos_embedding( - timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, ) - time_emb = time_emb.type(dtype=dtype) + time_emb = time_emb.type(dtype=timestep.dtype) # Fuse timestep + action information using an MLP - action_emb = self.action_in_proj(noisy_actions) + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) time_emb = time_emb[:, None, :].expand_as(action_emb) action_time_emb = torch.cat([action_emb, time_emb], dim=2) - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) + def mlp_func(action_time_emb): + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) + return self.action_time_mlp_out(x) + + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None - # Add to input tokens embs.append(action_time_emb) - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) pad_masks.append(action_time_mask) # Set attention masks so that image, language and state inputs do not attend to action tokens - att_masks += [1] + ([0] * (self.config.n_action_steps - 1)) + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) embs = torch.cat(embs, dim=1) pad_masks = torch.cat(pad_masks, dim=1) att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - return embs, pad_masks, att_masks + return embs, pad_masks, att_masks, adarms_cond def forward( self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None ) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + """Do a full training forward pass and compute the loss.""" if noise is None: noise = self.sample_noise(actions.shape, actions.device) @@ -598,7 +712,14 @@ class PI0FlowMatching(nn.Module): prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, time) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) @@ -606,29 +727,51 @@ class PI0FlowMatching(nn.Module): att_2d_masks = make_att_2d_masks(pad_masks, att_masks) position_ids = torch.cumsum(pad_masks, dim=1) - 1 - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond ) - suffix_out = suffix_out[:, -self.config.n_action_steps :] - # Original openpi code, upcast attention output + + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - losses = F.mse_loss(u_t, v_t, reduction="none") - return losses + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() # see openpi `sample_actions` (slightly adapted) + def sample_actions( + self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None + ) -> Tensor: + """Do a full inference forward and compute the action.""" + if num_steps is None: + num_steps = self.config.num_inference_steps - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device if noise is None: - actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.max_action_dim, + ) # Use config max_action_dim for internal processing noise = self.sample_noise(actions_shape, device) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( @@ -637,17 +780,18 @@ class PI0FlowMatching(nn.Module): prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - # Compute image and language key value cache + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + _, past_key_values = self.paligemma_with_expert.forward( - attention_mask=prefix_att_2d_masks, + attention_mask=prefix_att_2d_masks_4d, position_ids=prefix_position_ids, past_key_values=None, inputs_embeds=[prefix_embs, None], - use_cache=self.config.use_cache, - fill_kv_cache=True, + use_cache=True, ) - dt = -1.0 / self.config.num_steps + dt = -1.0 / num_steps dt = torch.tensor(dt, dtype=torch.float32, device=device) x_t = noise @@ -661,10 +805,9 @@ class PI0FlowMatching(nn.Module): x_t, expanded_time, ) - - # Euler step - x_t += dt * v_t + x_t = x_t + dt * v_t time += dt + return x_t def denoise_step( @@ -676,30 +819,374 @@ class PI0FlowMatching(nn.Module): timestep, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(state, x_t, timestep) suffix_len = suffix_pad_masks.shape[1] batch_size = prefix_pad_masks.shape[0] prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + outputs_embeds, _ = self.paligemma_with_expert.forward( - attention_mask=full_att_2d_masks, + attention_mask=full_att_2d_masks_4d, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=[None, suffix_embs], - use_cache=self.config.use_cache, - fill_kv_cache=False, + use_cache=False, + adarms_cond=[None, adarms_cond], ) + suffix_out = outputs_embeds[1] - suffix_out = suffix_out[:, -self.config.n_action_steps :] + suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - return v_t + return self.action_out_proj(suffix_out) + + +class PI0Policy(PreTrainedPolicy): + """PI0 OpenPI Policy for LeRobot.""" + + config_class = PI0Config + name = "pi0" + + def __init__( + self, + config: PI0Config, + ): + """ + Args: + config: Policy configuration class instance. + """ + super().__init__(config) + config.validate_features() + self.config = config + + # Initialize the core PI0 model + self.model = PI0Pytorch(config) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) + + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + try: + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias + # For gemma expert layers + if re.match( + r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue + + if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue + + # Handle MLP naming changes for pi0 + # non-pi05 model expects action_time_mlp_*, but checkpoint might have time_mlp_* + if key.startswith("time_mlp_in."): + new_key = key.replace("time_mlp_in.", "action_time_mlp_in.") + elif key.startswith("time_mlp_out."): + new_key = key.replace("time_mlp_out.", "action_time_mlp_out.") + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when environment resets.""" + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess images for the model. + + Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. + PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + """ + images = [] + img_masks = [] + + # Get device from model parameters + device = next(self.parameters()).device + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) + + for key in present_img_keys: + img = batch[key] + + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) + + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) + + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # padded with -1 for SigLIP + mask = torch.zeros_like(mask) # mask is zero for empty cameras + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_state(self, batch): + """Pad state""" + state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) + return state + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + self.eval() + + # Action queue logic for n_action_steps > 1 + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + # Transpose to get shape (n_action_steps, batch_size, action_dim) + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + state = self.prepare_state(batch) + + # Sample actions using the model + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) + + # Unpad actions to actual action dimension + original_action_dim = self.config.output_features[ACTION].shape[0] + actions = actions[:, :, :original_action_dim] + + return actions + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training.""" + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + state = self.prepare_state(batch) + actions = self.prepare_action(batch) + + # Compute loss + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) + + # Truncate losses to actual action dimensions + original_action_dim = self.config.output_features[ACTION].shape[0] + losses = losses[:, :, :original_action_dim] + + loss = losses.mean() + + loss_dict = { + "loss": loss.item(), + "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), + } + + return loss, loss_dict diff --git a/src/lerobot/policies/pi0/paligemma_with_expert.py b/src/lerobot/policies/pi0/paligemma_with_expert.py deleted file mode 100644 index edc34b7c5..000000000 --- a/src/lerobot/policies/pi0/paligemma_with_expert.py +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.version -from pytest import Cache -from torch import nn -from transformers import ( - AutoConfig, - GemmaForCausalLM, - PaliGemmaForConditionalGeneration, - PretrainedConfig, - PreTrainedModel, -) -from transformers.models.auto import CONFIG_MAPPING - -from lerobot.policies.pi0.flex_attention import flex_attention_forward - - -def apply_rope(x, positions, max_wavelength=10_000): - """ - Applies RoPE positions [B, L] to x [B, L, H, D]. - """ - d_half = x.shape[-1] // 2 - device = x.device - dtype = x.dtype - x = x.to(torch.float32) - - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) - timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) - - radians = radians[..., None, :] - - sin = torch.sin(radians) # .to(dtype=dtype) - cos = torch.cos(radians) # .to(dtype=dtype) - - x1, x2 = x.split(d_half, dim=-1) - res = torch.empty_like(x) - res[..., :d_half] = x1 * cos - x2 * sin - res[..., d_half:] = x2 * cos + x1 * sin - - return res.to(dtype) - - -class PaliGemmaWithExpertConfig(PretrainedConfig): - model_type = "PaliGemmaWithExpertModel" - sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig} - - def __init__( - self, - paligemma_config: dict | None = None, - gemma_expert_config: dict | None = None, - freeze_vision_encoder: bool = True, - train_expert_only: bool = True, - attention_implementation: str = "eager", - **kwargs, - ): - self.freeze_vision_encoder = freeze_vision_encoder - self.train_expert_only = train_expert_only - self.attention_implementation = attention_implementation - - if paligemma_config is None: - # Default config from Pi0 - self.paligemma_config = CONFIG_MAPPING["paligemma"]( - transformers_version="4.48.1", - _vocab_size=257152, - bos_token_id=2, - eos_token_id=1, - hidden_size=2048, - image_token_index=257152, - model_type="paligemma", - pad_token_id=0, - projection_dim=2048, - text_config={ - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2048, - "intermediate_size": 16384, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_image_tokens": 256, - "num_key_value_heads": 1, - "torch_dtype": "float32", - "vocab_size": 257152, - }, - vision_config={ - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_image_tokens": 256, - "patch_size": 14, - "projection_dim": 2048, - "projector_hidden_act": "gelu_fast", - "torch_dtype": "float32", - "vision_use_head": False, - }, - ) - elif isinstance(self.paligemma_config, dict): - # Override Pi0 default config for PaliGemma - if "model_type" not in gemma_expert_config: - paligemma_config["model_type"] = "paligemma" - - cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] - self.paligemma_config = cfg_cls(**paligemma_config) - - if gemma_expert_config is None: - # Default config from Pi0 - self.gemma_expert_config = CONFIG_MAPPING["gemma"]( - attention_bias=False, - attention_dropout=0.0, - bos_token_id=2, - eos_token_id=1, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation="gelu_pytorch_tanh", - hidden_size=1024, - initializer_range=0.02, - intermediate_size=4096, - max_position_embeddings=8192, - model_type="gemma", - num_attention_heads=8, - num_hidden_layers=18, - num_key_value_heads=1, - pad_token_id=0, - rms_norm_eps=1e-06, - rope_theta=10000.0, - torch_dtype="float32", - transformers_version="4.48.1", - use_cache=True, - vocab_size=257152, - ) - elif isinstance(self.gemma_expert_config, dict): - # Override Pi0 default config for Gemma Expert - if "model_type" not in gemma_expert_config: - gemma_expert_config["model_type"] = "gemma" - - cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] - self.gemma_expert_config = cfg_cls(**gemma_expert_config) - - super().__init__(**kwargs) - - def __post_init__(self): - super().__post_init__() - if self.train_expert_only and not self.freeze_vision_encoder: - raise ValueError( - "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." - ) - - if self.attention_implementation not in ["eager", "fa2", "flex"]: - raise ValueError( - f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." - ) - - -class PaliGemmaWithExpertModel(PreTrainedModel): - config_class = PaliGemmaWithExpertConfig - - def __init__(self, config: PaliGemmaWithExpertConfig): - super().__init__(config=config) - self.config = config - self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) - self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) - # Remove unused embed_tokens - self.gemma_expert.model.embed_tokens = None - - self.to_bfloat16_like_physical_intelligence() - self.set_requires_grad() - - def set_requires_grad(self): - if self.config.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - for params in self.paligemma.vision_tower.parameters(): - params.requires_grad = False - - if self.config.train_expert_only: - self.paligemma.eval() - for params in self.paligemma.parameters(): - params.requires_grad = False - - def train(self, mode: bool = True): - super().train(mode) - - if self.config.freeze_vision_encoder: - self.paligemma.vision_tower.eval() - - if self.config.train_expert_only: - self.paligemma.eval() - - def to_bfloat16_like_physical_intelligence(self): - self.paligemma = self.paligemma.to(dtype=torch.bfloat16) - - params_to_change_dtype = [ - "language_model.model.layers", - "gemma_expert.model.layers", - "vision_tower", - "multi_modal", - ] - for name, param in self.named_parameters(): - if any(selector in name for selector in params_to_change_dtype): - param.data = param.data.to(dtype=torch.bfloat16) - - def embed_image(self, image: torch.Tensor): - # Handle different transformers versions - if hasattr(self.paligemma, "get_image_features"): - return self.paligemma.get_image_features(image) - else: - return self.paligemma.model.get_image_features(image) - - def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.embed_tokens(tokens) - - # TODO: break down this huge forward into modules or functions - def forward( - self, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | Cache | None = None, - inputs_embeds: list[torch.FloatTensor] = None, - use_cache: bool | None = None, - fill_kv_cache: bool | None = None, - ): - models = [self.paligemma.language_model, self.gemma_expert.model] - - for hidden_states in inputs_embeds: - # TODO this is very inefficient - # dtype is always the same, batch size too (if > 1 len) - # device could be trickier in multi gpu edge cases but that's it - if hidden_states is None: - continue - batch_size = hidden_states.shape[0] - - # RMSNorm - num_layers = self.paligemma.config.text_config.num_hidden_layers - head_dim = self.paligemma.config.text_config.head_dim - for layer_idx in range(num_layers): - query_states = [] - key_states = [] - value_states = [] - for i, hidden_states in enumerate(inputs_embeds): - if hidden_states is None: - continue - layer = models[i].layers[layer_idx] - # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) - # hidden_states = hidden_states * normalizer - hidden_states = layer.input_layernorm(hidden_states) - - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) - - hidden_states = hidden_states.to(dtype=torch.bfloat16) - query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) - key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) - value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) - - query_states.append(query_state) - key_states.append(key_state) - value_states.append(value_state) - - # B,L,H,D with L sequence length, H number of heads, D head dim - # concatenate on the number of embeddings/tokens - query_states = torch.cat(query_states, dim=1) - key_states = torch.cat(key_states, dim=1) - value_states = torch.cat(value_states, dim=1) - - query_states = apply_rope(query_states, position_ids) - key_states = apply_rope(key_states, position_ids) - - if use_cache and past_key_values is None: - past_key_values = {} - - if use_cache: - if fill_kv_cache: - past_key_values[layer_idx] = { - "key_states": key_states, - "value_states": value_states, - } - else: - # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. - # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach - # the max len, then we (for instance) double the cache size. This implementation already exists - # in `transformers`. (molbap) - key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) - value_states = torch.cat( - [past_key_values[layer_idx]["value_states"], value_states], dim=1 - ) - - attention_interface = self.get_attention_interface() - att_output = attention_interface( - attention_mask, batch_size, head_dim, query_states, key_states, value_states - ) - att_output = att_output.to(dtype=torch.bfloat16) - - # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) - outputs_embeds = [] - start = 0 - for i, hidden_states in enumerate(inputs_embeds): - layer = models[i].layers[layer_idx] - - if hidden_states is not None: - end = start + hidden_states.shape[1] - - if att_output.dtype != layer.self_attn.o_proj.weight.dtype: - att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) - out_emb = layer.self_attn.o_proj(att_output[:, start:end]) - - # TODO: first dropout (by default 0.0) - - # first residual - out_emb += hidden_states - after_first_residual = out_emb.clone() - - out_emb = layer.post_attention_layernorm(out_emb) - out_emb = layer.mlp(out_emb) - - # TODO: second dropout (by default 0.0) - - # second residual - out_emb += after_first_residual - - outputs_embeds.append(out_emb) - - start = end - else: - outputs_embeds.append(None) - - inputs_embeds = outputs_embeds - - # final norm - outputs_embeds = [] - for i, hidden_states in enumerate(inputs_embeds): - if hidden_states is not None: - out_emb = models[i].norm(hidden_states) - outputs_embeds.append(out_emb) - else: - outputs_embeds.append(None) - - return outputs_embeds, past_key_values - - def get_attention_interface(self): - if self.config.attention_implementation == "fa2": - attention_interface = self.flash_attention_forward - elif self.config.attention_implementation == "flex": - attention_interface = flex_attention_forward - else: - attention_interface = self.eager_attention_forward - return attention_interface - - def flash_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states - ): - raise NotImplementedError("FA2 is not implemented (yet)") - - def eager_attention_forward( - self, attention_mask, batch_size, head_dim, query_states, key_states, value_states - ): - num_att_heads = self.config.paligemma_config.text_config.num_attention_heads - num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads - num_key_value_groups = num_att_heads // num_key_value_heads - - # query_states: batch_size, sequence_length, num_att_head, head_dim - # key_states: batch_size, sequence_length, num_key_value_head, head_dim - # value_states: batch_size, sequence_length, num_key_value_head, head_dim - sequence_length = key_states.shape[1] - - key_states = key_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :].expand( - batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim - ) - - # Attention here is upcasted to float32 to match the original eager implementation. - - query_states = query_states.to(dtype=torch.float32) - key_states = key_states.to(dtype=torch.float32) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - - att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) - att_weights *= head_dim**-0.5 - big_neg = -2.3819763e38 # See gemma/modules.py - - masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) - - probs = nn.functional.softmax(masked_att_weights, dim=-1) - probs = probs.to(dtype=value_states.dtype) - - # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length - # value_states: batch_size, sequence_length, num_att_heads, head_dim - - att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) - - att_output = att_output.permute(0, 2, 1, 3) - # we use -1 because sequence length can change - att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) - - return att_output diff --git a/src/lerobot/policies/pi05/README.md b/src/lerobot/policies/pi05/README.md new file mode 100644 index 000000000..2ae69d978 --- /dev/null +++ b/src/lerobot/policies/pi05/README.md @@ -0,0 +1,49 @@ +# π₀.₅ (pi05) + +This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by the Physical Intelligence. +It is designed as a **Vision-Language-Action model with open-world generalization**. + +--- + +## Model Overview + +| Feature | π₀ | π₀.₅ | +| -------------------- | ------------------------------------------------------ | ----------------------------------------- | +| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning | +| AdaRMS | Not used | Used in action expert | +| Tokenizer Length | 48 tokens | 200 tokens | +| Discrete State Input | False (Uses `state_proj` layer) | True | +| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) | + +--- + +## Citation + +If you use this work, please cite both **OpenPI** and the π₀.₅ paper: + +```bibtex +@misc{openpi2024, + author = {Physical Intelligence Lab}, + title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies}, + year = {2024}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/Physical-Intelligence/openpi}}, + license = {Apache-2.0} +} + +@misc{intelligence2025pi05visionlanguageactionmodelopenworld, + title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization}, + author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky}, + year = {2025}, + eprint = {2504.16054}, + archivePrefix= {arXiv}, + primaryClass = {cs.LG}, + url = {https://arxiv.org/abs/2504.16054}, +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/src/lerobot/policies/pi05/__init__.py b/src/lerobot/policies/pi05/__init__.py new file mode 100644 index 000000000..4f9a9de4a --- /dev/null +++ b/src/lerobot/policies/pi05/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_pi05 import PI05Config +from .modeling_pi05 import PI05Policy +from .processor_pi05 import make_pi05_pre_post_processors + +__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"] diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py new file mode 100644 index 000000000..7c1e950b0 --- /dev/null +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig + + +@PreTrainedConfig.register_subclass("pi05") +@dataclass +class PI05Config(PreTrainedConfig): + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + dtype: str = "float32" # Options: "bfloat16", "float32" + + n_obs_steps: int = 1 + chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" + n_action_steps: int = 50 # Number of action steps to execute + + # Shorter state and action vectors will be padded to these dimensions + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Flow matching parameters: see openpi `PI0Pytorch` + num_inference_steps: int = 10 + time_sampling_beta_alpha: float = 1.5 + time_sampling_beta_beta: float = 1.0 + time_sampling_scale: float = 0.999 + time_sampling_offset: float = 0.001 + min_period: float = 4e-3 + max_period: float = 4.0 + + image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py` + + # Add empty images. Used to add empty cameras when no image features are present. + empty_cameras: int = 0 + + tokenizer_max_length: int = 200 # see openpi `__post_init__` + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state + "ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action + } + ) + + # Training settings + gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization + compile_model: bool = False # Whether to use torch.compile for model optimization + compile_mode: str = "max-autotune" # Torch compile mode + device: str | None = None # Device to use for the model (None = auto-detect) + + # Optimizer settings: see openpi `AdamW` + optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.01 + optimizer_grad_clip_norm: float = 1.0 + + # Scheduler settings: see openpi `CosineDecaySchedule` + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + tokenizer_max_length: int = 200 # see openpi `__post_init__` + + def __post_init__(self): + super().__post_init__() + + # Validate configuration + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})" + ) + + if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}") + + if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}") + + if self.dtype not in ["bfloat16", "float32"]: + raise ValueError(f"Invalid dtype: {self.dtype}") + + def validate_features(self) -> None: + """Validate and set up input/output features.""" + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, *self.image_resolution), # Use configured image resolution + ) + self.input_features[key] = empty_camera + + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py new file mode 100644 index 000000000..93ca5fa82 --- /dev/null +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -0,0 +1,1163 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn + +from lerobot.utils.import_utils import _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _transformers_available: + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + modeling_gemma = None + GemmaForCausalLM = None + PaliGemmaForConditionalGeneration = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.utils.constants import ( + ACTION, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + + +def get_safe_dtype(target_dtype, device_type): + """Get a safe dtype for the given device type.""" + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 + if device_type == "cpu": + # CPU doesn't support bfloat16, use float32 instead + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedding` (exact copy) + time: torch.Tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + +def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +def pad_vector(vector, new_dim): + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images + + +# Define the complete layer computation function for gradient checkpointing +def compute_layer_complete( + layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond, paligemma, gemma_expert +): + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i]) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + # Concatenate and process attention + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + # Attention computation + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + # Get head_dim from the current layer, not from the model + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + att_output = att_output.reshape(batch_size, -1, 1 * 8 * head_dim) + # Process layer outputs + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + # first residual + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) # noqa: SLF001 + after_first_residual = out_emb.clone() + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + # Convert to bfloat16 if the next layer (mlp) uses bfloat16 + if layer.mlp.up_proj.weight.dtype == torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + # second residual + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) # noqa: SLF001 + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds + + +class GemmaConfig: # see openpi `gemma.py: Config` + """Configuration for Gemma model variants.""" + + def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + +def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` + """Returns config for specified gemma variant.""" + if variant == "gemma_300m": + return GemmaConfig( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + elif variant == "gemma_2b": + return GemmaConfig( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + else: + raise ValueError(f"Unknown variant: {variant}") + + +class PaliGemmaWithExpertModel( + nn.Module +): # see openpi `gemma_pytorch.py: PaliGemmaWithExpertModel` this class is almost a exact copy of PaliGemmaWithExpertModel in openpi + """PaliGemma model with action expert for PI05.""" + + def __init__( + self, + vlm_config, + action_expert_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + action_expert_config_hf = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_hf) + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[1] if adarms_cond is not None else None, + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + + # Check if gradient checkpointing is enabled for any of the models + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or (hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training) + + # Process all layers with gradient checkpointing if enabled + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_layer_complete, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_layer_complete( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms_cond, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + + # final norm + def compute_final_norms(inputs_embeds, adarms_cond): + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + # Apply gradient checkpointing to final norm if enabled + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms_cond, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms_cond) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI05 PyTorch model.""" + + def __init__(self, config: PI05Config): + super().__init__() + self.config = config + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) + + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=[False, True], + precision=config.dtype, + ) + + self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) + self.action_out_proj = nn.Linear(action_expert_config.width, config.max_action_dim) + + self.time_mlp_in = nn.Linear(action_expert_config.width, action_expert_config.width) + self.time_mlp_out = nn.Linear(action_expert_config.width, action_expert_config.width) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = True + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI05Pytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = False + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI05Pytorch model") + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + + def sample_noise(self, shape, device): + return torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + + def sample_time(self, bsize, device): + time_beta = sample_beta( + self.config.time_sampling_beta_alpha, self.config.time_sampling_beta_beta, bsize, device + ) + time = time_beta * self.config.time_sampling_scale + self.config.time_sampling_offset + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, tokens, masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * num_img_embs + + # Process language tokens + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) + embs.append(lang_emb) + pad_masks.append(masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + + bsize = pad_masks.shape[0] + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions, timestep): + """Embed noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Embed timestep using sine-cosine positional encoding + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + # Fuse timestep + action information using an MLP + def action_proj_func(noisy_actions): + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + + def time_mlp_func(time_emb): + x = self.time_mlp_in(time_emb) + x = F.silu(x) + x = self.time_mlp_out(x) + return F.silu(x) + + time_emb = self._apply_checkpoint(time_mlp_func, time_emb) + action_time_emb = action_emb + adarms_cond = time_emb + + embs.append(action_time_emb) + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=timestep.device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks, adarms_cond + + def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: + """Do a full training forward pass and compute the loss.""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, time) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + def forward_func(prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond): + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, prefix_embs, suffix_embs, att_2d_masks_4d, position_ids, adarms_cond + ) + + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + + def action_out_proj_func(suffix_out): + return self.action_out_proj(suffix_out) + + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() # see openpi `sample_actions` (slightly adapted) + def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor: + """Do a full inference forward and compute the action.""" + if num_steps is None: + num_steps = self.config.num_inference_steps + + bsize = tokens.shape[0] + device = tokens.device + + if noise is None: + # Sample noise with padded dimension as expected by action_in_proj + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.max_action_dim, + ) # Use config max_action_dim for internal processing + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(images, img_masks, tokens, masks) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + self.paligemma_with_expert.paligemma.language_model.config._attn_implementation = "eager" # noqa: SLF001 + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + x_t = x_t + dt * v_t + time += dt + + return x_t + + def denoise_step( + self, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = self.embed_suffix(x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + self.paligemma_with_expert.gemma_expert.model.config._attn_implementation = "eager" # noqa: SLF001 + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + return self.action_out_proj(suffix_out) + + +class PI05Policy(PreTrainedPolicy): + """PI05 Policy for LeRobot.""" + + config_class = PI05Config + name = "pi05" + + def __init__( + self, + config: PI05Config, + ): + """ + Args: + config: Policy configuration class instance. + """ + super().__init__(config) + config.validate_features() + self.config = config + + # Initialize the core PI05 model + self.model = PI05Pytorch(config) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) + + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI05 model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. \n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + try: + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + + # Then add "model." prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + import re + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle layer norm structure changes: .weight -> .dense.weight + .dense.bias + # For gemma expert layers + if re.match( + r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\.(input_layernorm|post_attention_layernorm)\.weight", + key, + ): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping layer norm key (adaRMS mismatch): {key}") + continue + + if re.match(r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key): + # Check if the model actually has adaRMS enabled for the expert + expert_uses_adarms = getattr( + self.model.paligemma_with_expert.gemma_expert.config, "use_adarms", False + ) + if expert_uses_adarms: + logging.warning(f"Skipping norm key (adaRMS mismatch): {key}") + continue + + # Handle MLP naming changes for pi05 + # pi05 model expects time_mlp_*, but checkpoint might have action_time_mlp_* + if key.startswith("action_time_mlp_in."): + new_key = key.replace("action_time_mlp_in.", "time_mlp_in.") + elif key.startswith("action_time_mlp_out."): + new_key = key.replace("action_time_mlp_out.", "time_mlp_out.") + # Also handle state_proj which shouldn't exist in pi05 + if key.startswith("state_proj."): + logging.warning(f"Skipping state_proj key in pi05 mode: {key}") + continue + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when environment resets.""" + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess images for the model. + + Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. + PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + """ + images = [] + img_masks = [] + + # Get device from model parameters + device = next(self.parameters()).device + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) + + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) + + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP + mask = torch.zeros_like(mask) # Mask is zero for empty cameras + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + self.eval() + + # Action queue logic for n_action_steps > 1 + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + # Transpose to get shape (n_action_steps, batch_size, action_dim) + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + # Sample actions using the model (no separate state needed for PI05) + actions = self.model.sample_actions(images, img_masks, tokens, masks) + + # Unpad actions to actual action dimension + original_action_dim = self.config.output_features[ACTION].shape[0] + actions = actions[:, :, :original_action_dim] + + return actions + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training.""" + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + actions = self.prepare_action(batch) + + # Compute loss (no separate state needed for PI05) + losses = self.model.forward(images, img_masks, tokens, masks, actions) + + # Truncate losses to actual action dimensions + original_action_dim = self.config.output_features[ACTION].shape[0] + losses = losses[:, :, :original_action_dim] + + loss = losses.mean() + + loss_dict = { + "loss": loss.item(), + "loss_per_dim": losses.mean(dim=[0, 1]).detach().cpu().numpy().tolist(), + } + + return loss, loss_dict diff --git a/src/lerobot/policies/pi05/processor_pi05.py b/src/lerobot/policies/pi05/processor_pi05.py new file mode 100644 index 000000000..e29bc4c23 --- /dev/null +++ b/src/lerobot/policies/pi05/processor_pi05.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.pi05.configuration_pi05 import PI05Config +from lerobot.policies.pi05.modeling_pi05 import pad_vector +from lerobot.processor import ( + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step") +@dataclass +class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep): + """ + Processor step to prepare the state and tokenize the language input. + """ + + max_state_dim: int = 32 + task_key: str = "task" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + + state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) + if state is None: + raise ValueError("State is required for PI05") + tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) + if tasks is None: + raise ValueError("No task found in complementary data") + + # TODO: check if this necessary + state = deepcopy(state) + + # Prepare state (pad to max_state_dim) + state = pad_vector(state, self.max_state_dim) + + # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts + # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. + """ + return features + + +def make_pi05_pre_post_processors( + config: PI05Config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0 policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0 policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. + """ + + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep + # because the tokenizer step expects normalized state in [-1, 1] range for discretization + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), + TokenizerProcessorStep( + tokenizer_name="google/paligemma-3b-pt-224", + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/pi0fast/configuration_pi0fast.py b/src/lerobot/policies/pi0fast/configuration_pi0fast.py index 705b61ea8..cefd4e688 100644 --- a/src/lerobot/policies/pi0fast/configuration_pi0fast.py +++ b/src/lerobot/policies/pi0fast/configuration_pi0fast.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig diff --git a/src/lerobot/processor/migrate_policy_normalization.py b/src/lerobot/processor/migrate_policy_normalization.py index 319145d1a..525b7431c 100644 --- a/src/lerobot/processor/migrate_policy_normalization.py +++ b/src/lerobot/processor/migrate_policy_normalization.py @@ -303,6 +303,65 @@ def clean_state_dict( return new_state_dict +def load_state_dict_with_missing_key_handling( + policy: torch.nn.Module, + state_dict: dict[str, torch.Tensor], + policy_type: str, + known_missing_keys_whitelist: dict[str, list[str]], +) -> list[str]: + """ + Load state dict into policy with graceful handling of missing keys. + + This function loads the state dict with strict=False, filters out whitelisted + missing keys, and provides detailed reporting about any issues found. + + Args: + policy: The policy model to load the state dict into. + state_dict: The cleaned state dictionary to load. + policy_type: The type of policy (used for whitelist lookup). + known_missing_keys_whitelist: Dictionary mapping policy types to lists of + known acceptable missing keys. + + Returns: + List of problematic missing keys that weren't in the whitelist. + """ + # Load the cleaned state dict with strict=False to capture missing/unexpected keys + load_result = policy.load_state_dict(state_dict, strict=False) + + # Check for missing keys + missing_keys = load_result.missing_keys + unexpected_keys = load_result.unexpected_keys + + # Filter out whitelisted missing keys + policy_type_lower = policy_type.lower() + whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, []) + problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys] + + if missing_keys: + if problematic_missing_keys: + print(f"WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:") + for key in problematic_missing_keys: + print(f" - {key}") + + if len(missing_keys) > len(problematic_missing_keys): + whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys] + print(f"INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):") + for key in whitelisted_missing: + print(f" - {key}") + + if unexpected_keys: + print(f"WARNING: Found {len(unexpected_keys)} unexpected keys:") + for key in unexpected_keys: + print(f" - {key}") + + if not missing_keys and not unexpected_keys: + print("Successfully loaded cleaned state dict into policy model (all keys matched)") + else: + print("State dict loaded with some missing/unexpected keys (see details above)") + + return problematic_missing_keys + + def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]: """ Converts a feature dictionary from the old config format to the new `PolicyFeature` format. @@ -336,9 +395,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[ return converted_features +def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None: + """ + Display final migration summary with warnings about problematic missing keys. + + Args: + problematic_missing_keys: List of missing keys that weren't in the whitelist. + """ + if not problematic_missing_keys: + return + + print("\n" + "=" * 60) + print("IMPORTANT: MIGRATION COMPLETED WITH WARNINGS") + print("=" * 60) + print( + f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:" + ) + print() + for key in problematic_missing_keys: + print(f" - {key}") + print() + print("These missing keys may indicate:") + print(" • The model architecture has changed") + print(" • Some components were not properly saved in the original model") + print(" • The migration script needs to be updated for this policy type") + print() + print("What to do next:") + print(" 1. Test your migrated model carefully to ensure it works as expected") + print(" 2. If you encounter issues, please open an issue at:") + print(" https://github.com/huggingface/lerobot/issues") + print(" 3. Include this migration log and the missing keys listed above") + print() + print("If the model works correctly despite these warnings, the missing keys") + print("might be expected for your policy type and can be added to the whitelist.") + print("=" * 60) + + def load_model_from_hub( repo_id: str, revision: str | None = None -) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]: +) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]: """ Downloads and loads a model's state_dict and configs from the Hugging Face Hub. @@ -348,13 +443,12 @@ def load_model_from_hub( Returns: A tuple containing the model's state dictionary, the policy configuration, - and the training configuration. + and the training configuration (None if train_config.json is not found). """ # Download files. safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision) config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision) - train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) # Load state_dict state_dict = load_safetensors(safetensors_path) @@ -363,8 +457,14 @@ def load_model_from_hub( with open(config_path) as f: config = json.load(f) - with open(train_config_path) as f: - train_config = json.load(f) + # Try to load train_config (optional) + train_config = None + try: + train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision) + with open(train_config_path) as f: + train_config = json.load(f) + except FileNotFoundError: + print("train_config.json not found - continuing without training configuration") return state_dict, config, train_config @@ -410,8 +510,15 @@ def main(): state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors")) with open(os.path.join(args.pretrained_path, "config.json")) as f: config = json.load(f) - with open(os.path.join(args.pretrained_path, "train_config.json")) as f: - train_config = json.load(f) + + # Try to load train_config (optional) + train_config = None + train_config_path = os.path.join(args.pretrained_path, "train_config.json") + if os.path.exists(train_config_path): + with open(train_config_path) as f: + train_config = json.load(f) + else: + print("train_config.json not found - continuing without training configuration") else: # Hub repository state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision) @@ -488,10 +595,20 @@ def main(): policy_class = get_policy_class(policy_type) policy = policy_class(policy_config) - # Load the cleaned state dict - policy.load_state_dict(new_state_dict, strict=True) - print("Successfully loaded cleaned state dict into policy model") + # Define whitelist of known missing keys that are acceptable (for example weight tie) for certain policy types + known_missing_keys_whitelist = { + "pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"], + # Add other policy types and their known missing keys here as needed + } + # Load state dict with graceful missing key handling + problematic_missing_keys = load_state_dict_with_missing_key_handling( + policy=policy, + state_dict=new_state_dict, + policy_type=policy_type, + known_missing_keys_whitelist=known_missing_keys_whitelist, + ) + policy.to(torch.float32) # Create preprocessor and postprocessor using the factory print("Creating preprocessor and postprocessor using make_pre_post_processors...") preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats) @@ -521,7 +638,9 @@ def main(): # Generate and save model card print("Generating model card...") # Get metadata from original config - dataset_repo_id = train_config.get("repo_id", "unknown") + dataset_repo_id = "unknown" + if train_config is not None: + dataset_repo_id = train_config.get("repo_id", "unknown") license = config.get("license", "apache-2.0") tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type] @@ -552,25 +671,25 @@ def main(): if create_pr: # Separate commit description for PR body - commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline** + commit_description = """**Automated Policy Migration to PolicyProcessorPipeline** This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture. ## What Changed -### ✨ **New Architecture - PolicyProcessorPipeline** +### **New Architecture - PolicyProcessorPipeline** Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides: - **Modularity**: Separate preprocessing and postprocessing pipelines - **Flexibility**: Easy to swap, configure, and debug processing steps - **Compatibility**: Works with the latest LeRobot ecosystem -### 🔧 **Normalization Extraction** +### **Normalization Extraction** We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers: - **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*` - **Statistics preserved**: Mean, std, min, max values for all features - **Clean model**: State dict now contains only core model weights -### 📦 **Files Added** +### **Files Added** - **preprocessor_config.json**: Configuration for input preprocessing pipeline - **postprocessor_config.json**: Configuration for output postprocessing pipeline - **model.safetensors**: Clean model weights without normalization layers @@ -578,13 +697,13 @@ We've extracted normalization statistics from your model's state_dict and remove - **train_config.json**: Training configuration - **README.md**: Updated model card with migration information -### 🚀 **Benefits** +### **Benefits** - **Backward Compatible**: Your model behavior remains identical - **Future Ready**: Compatible with latest LeRobot features and updates - **Debuggable**: Easy to inspect and modify processing steps - **Portable**: Processors can be shared and reused across models -### 💻 **Usage** +### **Usage** ```python # Load your migrated model from lerobot.policies import get_policy_class @@ -642,6 +761,9 @@ final_action = postprocessor(action) else: print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}") + # Display final summary about any problematic missing keys + display_migration_summary_with_warnings(problematic_missing_keys) + if __name__ == "__main__": main() diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index ce69a103f..368c9b270 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -281,8 +281,14 @@ class _NormalizationMixin: """ Core logic to apply a normalization or unnormalization transformation to a tensor. - This method selects the appropriate normalization mode (e.g., mean/std, min/max) - based on the feature type and applies the corresponding mathematical operation. + This method selects the appropriate normalization mode based on the feature type + and applies the corresponding mathematical operation. + + Normalization Modes: + - MEAN_STD: Centers data around zero with unit variance. + - MIN_MAX: Scales data to [-1, 1] range using actual min/max values. + - QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99). + - QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90). Args: tensor: The input tensor to transform. @@ -300,7 +306,12 @@ class _NormalizationMixin: if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats: return tensor - if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): + if norm_mode not in ( + NormalizationMode.MEAN_STD, + NormalizationMode.MIN_MAX, + NormalizationMode.QUANTILES, + NormalizationMode.QUANTILE10, + ): raise ValueError(f"Unsupported normalization mode: {norm_mode}") # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor @@ -311,7 +322,14 @@ class _NormalizationMixin: stats = self._tensor_stats[key] - if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: + if norm_mode == NormalizationMode.MEAN_STD: + mean = stats.get("mean", None) + std = stats.get("std", None) + if mean is None or std is None: + raise ValueError( + "MEAN_STD normalization mode requires mean and std stats, please update the dataset with the correct stats" + ) + mean, std = stats["mean"], stats["std"] # Avoid division by zero by adding a small epsilon. denom = std + self.eps @@ -319,7 +337,14 @@ class _NormalizationMixin: return tensor * std + mean return (tensor - mean) / denom - if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats: + if norm_mode == NormalizationMode.MIN_MAX: + min_val = stats.get("min", None) + max_val = stats.get("max", None) + if min_val is None or max_val is None: + raise ValueError( + "MIN_MAX normalization mode requires min and max stats, please update the dataset with the correct stats" + ) + min_val, max_val = stats["min"], stats["max"] denom = max_val - min_val # When min_val == max_val, substitute the denominator with a small epsilon @@ -334,6 +359,40 @@ class _NormalizationMixin: # Map from [min, max] to [-1, 1] return 2 * (tensor - min_val) / denom - 1 + if norm_mode == NormalizationMode.QUANTILES: + q01 = stats.get("q01", None) + q99 = stats.get("q99", None) + if q01 is None or q99 is None: + raise ValueError( + "QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script" + ) + + denom = q99 - q01 + # Avoid division by zero by adding epsilon when quantiles are identical + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + return (tensor + 1.0) * denom / 2.0 + q01 + return 2.0 * (tensor - q01) / denom - 1.0 + + if norm_mode == NormalizationMode.QUANTILE10: + q10 = stats.get("q10", None) + q90 = stats.get("q90", None) + if q10 is None or q90 is None: + raise ValueError( + "QUANTILE10 normalization mode requires q10 and q90 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script" + ) + + denom = q90 - q10 + # Avoid division by zero by adding epsilon when quantiles are identical + denom = torch.where( + denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom + ) + if inverse: + return (tensor + 1.0) * denom / 2.0 + q10 + return 2.0 * (tensor - q10) / denom - 1.0 + # If necessary stats are missing, return input unchanged. return tensor diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 12a1f53c7..bc66618ca 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -180,7 +180,8 @@ def train(cfg: TrainPipelineConfig): # Create processors - only provide dataset_stats if not resuming from saved processors processor_kwargs = {} - if not (cfg.resume and cfg.policy.pretrained_path): + postprocessor_kwargs = {} + if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path: # Only provide dataset_stats when not resuming from saved processor state processor_kwargs["dataset_stats"] = dataset.meta.stats @@ -190,17 +191,22 @@ def train(cfg: TrainPipelineConfig): "normalizer_processor": { "stats": dataset.meta.stats, "features": {**policy.config.input_features, **policy.config.output_features}, + "norm_map": policy.config.normalization_mapping, }, } - processor_kwargs["postprocessor_overrides"] = { + postprocessor_kwargs["postprocessor_overrides"] = { "unnormalizer_processor": { "stats": dataset.meta.stats, - "features": {**policy.config.input_features, **policy.config.output_features}, + "features": policy.config.output_features, + "norm_map": policy.config.normalization_mapping, }, } preprocessor, postprocessor = make_pre_post_processors( - policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + **processor_kwargs, + **postprocessor_kwargs, ) logging.info("Creating optimizer and scheduler") diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 9293d6ba7..34af282b0 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -19,10 +19,28 @@ [Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation. {% elif model_name == "vqbet" %} [VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills. -{% elif model_name == "pi0" %} -[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer. {% elif model_name == "pi0fast" %} [Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time. +{% elif model_name == "pi0" %} +**π₀ (Pi0)** + +π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository. + +**Model Overview** + +π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks. + +For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0). +{% elif model_name == "pi05" %} +**π₀.₅ (Pi05) Policy** + +π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository. + +**Model Overview** + +π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training. + +For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05). {% elif model_name == "sac" %} [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments. {% elif model_name == "reward_classifier" %} diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 824f74b30..6847666eb 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -67,3 +67,6 @@ HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibr # streaming datasets LOOKBACK_BACKTRACKTABLE = 100 LOOKAHEAD_BACKTRACKTABLE = 100 + +# openpi +OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38 # TODO(pepijn): Modify this when extending support to fp8 models diff --git a/tests/datasets/test_compute_stats.py b/tests/datasets/test_compute_stats.py index 982f35c3f..973c80bd8 100644 --- a/tests/datasets/test_compute_stats.py +++ b/tests/datasets/test_compute_stats.py @@ -19,6 +19,7 @@ import numpy as np import pytest from lerobot.datasets.compute_stats import ( + RunningQuantileStats, _assert_type_and_shape, aggregate_feature_stats, aggregate_stats, @@ -102,6 +103,9 @@ def test_get_feature_stats_axis_1(sample_array): "count": np.array([3]), } result = get_feature_stats(sample_array, axis=(1,), keepdims=False) + + # Check that basic stats are correct (quantiles are also included now) + assert set(expected.keys()).issubset(set(result.keys())) for key in expected: np.testing.assert_allclose(result[key], expected[key]) @@ -115,6 +119,9 @@ def test_get_feature_stats_no_axis(sample_array): "count": np.array([3]), } result = get_feature_stats(sample_array, axis=None, keepdims=False) + + # Check that basic stats are correct (quantiles are also included now) + assert set(expected.keys()).issubset(set(result.keys())) for key in expected: np.testing.assert_allclose(result[key], expected[key]) @@ -308,3 +315,520 @@ def test_aggregate_stats(): results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 ) np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) + + +def test_running_quantile_stats_initialization(): + """Test proper initialization of RunningQuantileStats.""" + running_stats = RunningQuantileStats() + assert running_stats._count == 0 + assert running_stats._mean is None + assert running_stats._num_quantile_bins == 5000 + + # Test custom bin size + running_stats_custom = RunningQuantileStats(num_quantile_bins=1000) + assert running_stats_custom._num_quantile_bins == 1000 + + +def test_running_quantile_stats_single_batch_update(): + """Test updating with a single batch.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 3)) + + running_stats = RunningQuantileStats() + running_stats.update(data) + + assert running_stats._count == 100 + assert running_stats._mean.shape == (3,) + assert len(running_stats._histograms) == 3 + assert len(running_stats._bin_edges) == 3 + + # Verify basic statistics are reasonable + np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10) + + +def test_running_quantile_stats_multiple_batch_updates(): + """Test updating with multiple batches.""" + np.random.seed(42) + data1 = np.random.normal(0, 1, (100, 2)) + data2 = np.random.normal(1, 1, (150, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data1) + running_stats.update(data2) + + assert running_stats._count == 250 + + # Verify running mean is correct + combined_data = np.vstack([data1, data2]) + expected_mean = np.mean(combined_data, axis=0) + np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10) + + +def test_running_quantile_stats_get_statistics_basic(): + """Test getting basic statistics without quantiles.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # Should have basic stats + expected_keys = {"min", "max", "mean", "std", "count"} + assert expected_keys.issubset(set(stats.keys())) + + # Verify values + np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10) + np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6) + np.testing.assert_equal(stats["count"], np.array([100])) + + +def test_running_quantile_stats_get_statistics_with_quantiles(): + """Test getting statistics with quantiles.""" + np.random.seed(42) + data = np.random.normal(0, 1, (1000, 2)) + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # Should have basic stats plus quantiles + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert expected_keys.issubset(set(stats.keys())) + + # Verify quantile values are reasonable + from lerobot.datasets.compute_stats import DEFAULT_QUANTILES + + for i, q in enumerate(DEFAULT_QUANTILES): + q_key = f"q{int(q * 100):02d}" + assert q_key in stats + assert stats[q_key].shape == (2,) + + # Check that quantiles are in reasonable order + if i > 0: + prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}" + assert np.all(stats[prev_q_key] <= stats[q_key]) + + +def test_running_quantile_stats_histogram_adjustment(): + """Test that histograms adjust when min/max change.""" + running_stats = RunningQuantileStats() + + # Initial data with small range + data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]]) + running_stats.update(data1) + + initial_edges_0 = running_stats._bin_edges[0].copy() + initial_edges_1 = running_stats._bin_edges[1].copy() + + # Add data with much larger range + data2 = np.array([[10.0, -10.0], [11.0, -11.0]]) + running_stats.update(data2) + + # Bin edges should have changed + assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0]) + assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1]) + + # New edges should cover the expanded range + # First dimension: min should still be ~0.0, max should be ~11.0 + assert running_stats._bin_edges[0][0] <= 0.0 + assert running_stats._bin_edges[0][-1] >= 11.0 + + # Second dimension: min should be ~-11.0, max should be ~1.2 + assert running_stats._bin_edges[1][0] <= -11.0 + assert running_stats._bin_edges[1][-1] >= 1.2 + + +def test_running_quantile_stats_insufficient_data_error(): + """Test error when trying to get stats with insufficient data.""" + running_stats = RunningQuantileStats() + + with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"): + running_stats.get_statistics() + + # Single vector should also fail + running_stats.update(np.array([[1.0]])) + with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"): + running_stats.get_statistics() + + +def test_running_quantile_stats_vector_length_consistency(): + """Test error when vector lengths don't match.""" + running_stats = RunningQuantileStats() + running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]])) + + with pytest.raises(ValueError, match="The length of new vectors does not match"): + running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length + + +def test_running_quantile_stats_reshape_handling(): + """Test that various input shapes are handled correctly.""" + running_stats = RunningQuantileStats() + + # Test 3D input (e.g., images) + data_3d = np.random.normal(0, 1, (10, 32, 32)) + running_stats.update(data_3d) + + assert running_stats._count == 10 * 32 + assert running_stats._mean.shape == (32,) + + # Test 1D input + running_stats_1d = RunningQuantileStats() + data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1) + running_stats_1d.update(data_1d) + + assert running_stats_1d._count == 5 + assert running_stats_1d._mean.shape == (1,) + + +def test_get_feature_stats_quantiles_enabled_by_default(): + """Test that quantiles are computed by default.""" + data = np.random.normal(0, 1, (100, 5)) + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + +def test_get_feature_stats_quantiles_with_vector_data(): + """Test quantile computation with vector data.""" + np.random.seed(42) + data = np.random.normal(0, 1, (100, 5)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + # Verify shapes + assert stats["q01"].shape == (5,) + assert stats["q99"].shape == (5,) + + # Verify quantiles are reasonable + assert np.all(stats["q01"] < stats["q99"]) + + +def test_get_feature_stats_quantiles_with_image_data(): + """Test quantile computation with image data.""" + np.random.seed(42) + data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width + + stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) + + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats.keys()) == expected_keys + + # Verify shapes for images (should be (1, channels, 1, 1)) + assert stats["q01"].shape == (1, 3, 1, 1) + assert stats["q50"].shape == (1, 3, 1, 1) + assert stats["q99"].shape == (1, 3, 1, 1) + + +def test_get_feature_stats_fixed_quantiles(): + """Test that fixed quantiles are always computed.""" + data = np.random.normal(0, 1, (200, 3)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"} + assert expected_quantile_keys.issubset(set(stats.keys())) + + +def test_get_feature_stats_unsupported_axis_error(): + """Test error for unsupported axis configuration.""" + data = np.random.normal(0, 1, (10, 5)) + + with pytest.raises(ValueError, match="Unsupported axis configuration"): + get_feature_stats( + data, + axis=(1, 2), # Unsupported axis + keepdims=False, + ) + + +def test_compute_episode_stats_backward_compatibility(): + """Test that existing functionality is preserved.""" + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(0, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + stats = compute_episode_stats(episode_data, features) + + for key in ["action", "observation.state"]: + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats[key].keys()) == expected_keys + + +def test_compute_episode_stats_with_custom_quantiles(): + """Test quantile computation with custom quantile values.""" + np.random.seed(42) + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(2, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + stats = compute_episode_stats(episode_data, features) + + # Should have quantiles + for key in ["action", "observation.state"]: + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(stats[key].keys()) == expected_keys + + # Verify shapes + assert stats[key]["q01"].shape == (features[key]["shape"][0],) + assert stats[key]["q99"].shape == (features[key]["shape"][0],) + + +def test_compute_episode_stats_with_image_data(): + """Test quantile computation with image features.""" + image_paths = [f"image_{i}.jpg" for i in range(50)] + episode_data = { + "observation.image": image_paths, + "action": np.random.normal(0, 1, (50, 5)), + } + features = { + "observation.image": {"dtype": "image"}, + "action": {"dtype": "float32", "shape": (5,)}, + } + + with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy): + stats = compute_episode_stats(episode_data, features) + + # Image quantiles should be normalized and have correct shape + assert "q01" in stats["observation.image"] + assert "q50" in stats["observation.image"] + assert "q99" in stats["observation.image"] + assert stats["observation.image"]["q01"].shape == (3, 1, 1) + assert stats["observation.image"]["q50"].shape == (3, 1, 1) + assert stats["observation.image"]["q99"].shape == (3, 1, 1) + + # Action quantiles should have correct shape + assert stats["action"]["q01"].shape == (5,) + assert stats["action"]["q50"].shape == (5,) + assert stats["action"]["q99"].shape == (5,) + + +def test_compute_episode_stats_string_features_skipped(): + """Test that string features are properly skipped.""" + episode_data = { + "task": ["pick_apple"] * 100, # String feature + "action": np.random.normal(0, 1, (100, 5)), + } + features = { + "task": {"dtype": "string"}, + "action": {"dtype": "float32", "shape": (5,)}, + } + + stats = compute_episode_stats( + episode_data, + features, + ) + + # String features should be skipped + assert "task" not in stats + assert "action" in stats + assert "q01" in stats["action"] + + +def test_aggregate_feature_stats_with_quantiles(): + """Test aggregating feature stats that include quantiles.""" + stats_ft_list = [ + { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([100]), + "q01": np.array([1.5]), + "q99": np.array([9.5]), + }, + { + "min": np.array([2.0]), + "max": np.array([12.0]), + "mean": np.array([6.0]), + "std": np.array([2.5]), + "count": np.array([150]), + "q01": np.array([2.5]), + "q99": np.array([11.5]), + }, + ] + + result = aggregate_feature_stats(stats_ft_list) + + # Should preserve quantiles + assert "q01" in result + assert "q99" in result + + # Verify quantile aggregation (weighted average) + expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1 + expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7 + + np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6) + np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6) + + +def test_aggregate_stats_mixed_quantiles(): + """Test aggregating stats where some have quantiles and some don't.""" + stats_with_quantiles = { + "feature1": { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([100]), + "q01": np.array([1.5]), + "q99": np.array([9.5]), + } + } + + stats_without_quantiles = { + "feature2": { + "min": np.array([0.0]), + "max": np.array([5.0]), + "mean": np.array([2.5]), + "std": np.array([1.5]), + "count": np.array([50]), + } + } + + all_stats = [stats_with_quantiles, stats_without_quantiles] + result = aggregate_stats(all_stats) + + # Feature1 should keep its quantiles + assert "q01" in result["feature1"] + assert "q99" in result["feature1"] + + # Feature2 should not have quantiles + assert "q01" not in result["feature2"] + assert "q99" not in result["feature2"] + + +def test_assert_type_and_shape_with_quantiles(): + """Test validation works correctly with quantile keys.""" + # Valid stats with quantiles + valid_stats = [ + { + "observation.image": { + "min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1), + "max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1), + "mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1), + "std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1), + "count": np.array([100]), + "q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1), + "q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1), + } + } + ] + + # Should not raise error + _assert_type_and_shape(valid_stats) + + # Invalid shape for quantile + invalid_stats = [ + { + "observation.image": { + "count": np.array([100]), + "q01": np.array([0.1, 0.2]), # Wrong shape for image quantile + } + } + ] + + with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"): + _assert_type_and_shape(invalid_stats) + + +def test_quantile_integration_single_value_quantiles(): + """Test quantile computation with single repeated value.""" + data = np.ones((100, 3)) # All ones + + running_stats = RunningQuantileStats() + running_stats.update(data) + + stats = running_stats.get_statistics() + + # All quantiles should be approximately 1.0 + np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6) + + +def test_quantile_integration_fixed_quantiles(): + """Test that fixed quantiles are computed.""" + np.random.seed(42) + data = np.random.normal(0, 1, (1000, 2)) + + stats = get_feature_stats(data, axis=0, keepdims=False) + + # Check all fixed quantiles are present + assert "q01" in stats + assert "q10" in stats + assert "q50" in stats + assert "q90" in stats + assert "q99" in stats + + +def test_quantile_integration_large_dataset_quantiles(): + """Test quantile computation efficiency with large datasets.""" + np.random.seed(42) + large_data = np.random.normal(0, 1, (10000, 5)) + + running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed + running_stats.update(large_data) + + stats = running_stats.get_statistics() + + # Should complete without issues and produce reasonable results + assert stats["count"][0] == 10000 + assert len(stats["q01"]) == 5 + + +def test_fixed_quantiles_always_computed(): + """Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed.""" + np.random.seed(42) + # Test with vector data + vector_data = np.random.normal(0, 1, (100, 5)) + vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False) + + # Check all fixed quantiles are present + expected_quantiles = ["q01", "q10", "q50", "q90", "q99"] + for q_key in expected_quantiles: + assert q_key in vector_stats + assert vector_stats[q_key].shape == (5,) + + # Test with image data + image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8) + image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True) + + # Check all fixed quantiles are present for images + for q_key in expected_quantiles: + assert q_key in image_stats + assert image_stats[q_key].shape == (1, 3, 1, 1) + + # Test with episode data + episode_data = { + "action": np.random.normal(0, 1, (100, 7)), + "observation.state": np.random.normal(0, 1, (100, 10)), + } + features = { + "action": {"dtype": "float32", "shape": (7,)}, + "observation.state": {"dtype": "float32", "shape": (10,)}, + } + + episode_stats = compute_episode_stats(episode_data, features) + + # Check all fixed quantiles are present in episode stats + for key in ["action", "observation.state"]: + for q_key in expected_quantiles: + assert q_key in episode_stats[key] + assert episode_stats[key][q_key].shape == (features[key]["shape"][0],) diff --git a/tests/datasets/test_quantiles_dataset_integration.py b/tests/datasets/test_quantiles_dataset_integration.py new file mode 100644 index 000000000..4df7fab06 --- /dev/null +++ b/tests/datasets/test_quantiles_dataset_integration.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for quantile functionality in LeRobotDataset.""" + +import numpy as np +import pytest + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def mock_load_image_as_numpy(path, dtype, channel_first): + """Mock image loading for consistent test results.""" + return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + + +@pytest.fixture +def simple_features(): + """Simple feature configuration for testing.""" + return { + "action": { + "dtype": "float32", + "shape": (4,), + "names": ["arm_x", "arm_y", "arm_z", "gripper"], + }, + "observation.state": { + "dtype": "float32", + "shape": (10,), + "names": [f"joint_{i}" for i in range(10)], + }, + } + + +def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features): + """Test creating dataset with fixed quantiles.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_fixed_quantiles", + fps=30, + features=simple_features, + root=tmp_path / "create_fixed_quantiles", + ) + + # Dataset should be created successfully + assert dataset is not None + + +def test_save_episode_computes_all_quantiles(tmp_path, simple_features): + """Test that all fixed quantiles are computed when saving an episode.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_save_episode", + fps=30, + features=simple_features, + root=tmp_path / "save_episode_quantiles", + ) + + # Add some frames + for _ in range(10): + dataset.add_frame( + { + "action": np.random.randn(4).astype(np.float32), # Correct shape for action + "observation.state": np.random.randn(10).astype(np.float32), + "task": "test_task", + } + ) + + dataset.save_episode() + + # Check that all fixed quantiles were computed + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + assert "q01" in stats[key] + assert "q10" in stats[key] + assert "q50" in stats[key] + assert "q90" in stats[key] + assert "q99" in stats[key] + + +def test_quantile_values_ordering(tmp_path, simple_features): + """Test that quantile values are properly ordered.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_quantile_ordering", + fps=30, + features=simple_features, + root=tmp_path / "quantile_ordering", + ) + + # Add data with known distribution + np.random.seed(42) + for _ in range(100): + dataset.add_frame( + { + "action": np.random.randn(4).astype(np.float32), # Correct shape for action + "observation.state": np.random.randn(10).astype(np.float32), + "task": "test_task", + } + ) + + dataset.save_episode() + stats = dataset.meta.stats + + # Verify quantile ordering + for key in ["action", "observation.state"]: + assert np.all(stats[key]["q01"] <= stats[key]["q10"]) + assert np.all(stats[key]["q10"] <= stats[key]["q50"]) + assert np.all(stats[key]["q50"] <= stats[key]["q90"]) + assert np.all(stats[key]["q90"] <= stats[key]["q99"]) + + +def test_save_episode_with_fixed_quantiles(tmp_path, simple_features): + """Test saving episode always computes fixed quantiles.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_save_fixed", + fps=30, + features=simple_features, + root=tmp_path / "save_fixed_quantiles", + ) + + # Add frames to episode + np.random.seed(42) + for _ in range(50): + frame = { + "action": np.random.normal(0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(0, 1, (10,)).astype(np.float32), + "task": "test_task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Check that all fixed quantiles are included + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(feature_stats.keys()) == expected_keys + + +def test_quantile_aggregation_across_episodes(tmp_path, simple_features): + """Test quantile aggregation across multiple episodes.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_aggregation", + fps=30, + features=simple_features, + root=tmp_path / "quantile_aggregation", + ) + + # Add frames to episode + np.random.seed(42) + for _ in range(100): + frame = { + "action": np.random.normal(0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(2, 1, (10,)).astype(np.float32), + "task": "test_task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Check stats include all fixed quantiles + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"} + assert set(feature_stats.keys()) == expected_keys + assert feature_stats["q01"].shape == (simple_features[key]["shape"][0],) + assert feature_stats["q50"].shape == (simple_features[key]["shape"][0],) + assert feature_stats["q99"].shape == (simple_features[key]["shape"][0],) + assert np.all(feature_stats["q01"] <= feature_stats["q50"]) + assert np.all(feature_stats["q50"] <= feature_stats["q99"]) + + +def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features): + """Test quantile aggregation across multiple episodes.""" + dataset = LeRobotDataset.create( + repo_id="test_dataset_multiple_episodes", + fps=30, + features=simple_features, + root=tmp_path / "multiple_episodes", + ) + + # Save multiple episodes + np.random.seed(42) + for episode_idx in range(3): + for _ in range(50): + frame = { + "action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32), + "observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32), + "task": f"task_{episode_idx}", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Verify final stats include properly aggregated quantiles + stats = dataset.meta.stats + for key in ["action", "observation.state"]: + feature_stats = stats[key] + assert "q01" in feature_stats and "q99" in feature_stats + assert feature_stats["count"][0] == 150 # 3 episodes * 50 frames diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py new file mode 100644 index 000000000..65f64e6bc --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!""" + +import os + +import pytest +import torch + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local OpenPI installation and is not meant for CI", +) + +from lerobot.policies.factory import make_policy_config # noqa: E402 +from lerobot.policies.pi0 import ( # noqa: E402 + PI0Config, + PI0Policy, + make_pi0_pre_post_processors, # noqa: E402 +) +from lerobot.utils.random_utils import set_seed # noqa: E402 +from tests.utils import require_cuda # noqa: E402 + + +@require_cuda +def test_policy_instantiation(): + # Create config + set_seed(42) + config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32") + + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(14,), + ), + "observation.images.base_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + }, + } + + # Instantiate policy + policy = PI0Policy(config) + preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats) + # Test forward pass with dummy data + batch_size = 1 + device = config.device + batch = { + "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + loss, loss_dict = policy.forward(batch) + print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + + try: + with torch.no_grad(): + action = policy.select_action(batch) + action = postprocessor(action) + print(f"Action: {action}") + print(f"Action prediction successful. Action shape: {action.shape}") + except Exception as e: + print(f"Action prediction failed: {e}") + raise + + +@require_cuda +def test_config_creation(): + """Test policy config creation through factory.""" + try: + config = make_policy_config( + policy_type="pi0", + max_action_dim=7, + max_state_dim=14, + ) + print("Config created successfully through factory") + print(f" Config type: {type(config).__name__}") + print(f" PaliGemma variant: {config.paligemma_variant}") + print(f" Action expert variant: {config.action_expert_variant}") + except Exception as e: + print(f"Config creation failed: {e}") + raise diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py new file mode 100644 index 000000000..72828a02f --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python + +"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!""" + +import os + +import pytest +import torch + +from lerobot.utils.random_utils import set_seed + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local OpenPI installation and is not meant for CI", +) + +from lerobot.policies.factory import make_policy_config # noqa: E402 +from lerobot.policies.pi05 import ( # noqa: E402 + PI05Config, + PI05Policy, + make_pi05_pre_post_processors, # noqa: E402 +) +from tests.utils import require_cuda # noqa: E402 + + +@require_cuda +def test_policy_instantiation(): + # Create config + set_seed(42) + config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32") + + # Set up input_features and output_features in the config + from lerobot.configs.types import FeatureType, PolicyFeature + + config.input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(14,), + ), + "observation.images.base_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), + ), + } + + config.output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(7,), + ), + } + + assert config.tokenizer_max_length == 200, ( + f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}" + ) + + # Create dummy dataset stats + dataset_stats = { + "observation.state": { + "mean": torch.zeros(14), + "std": torch.ones(14), + "min": torch.zeros(14), + "max": torch.ones(14), + "q01": torch.zeros(14), + "q99": torch.ones(14), + }, + "action": { + "mean": torch.zeros(7), + "std": torch.ones(7), + "min": torch.zeros(7), + "max": torch.ones(7), + "q01": torch.zeros(7), + "q99": torch.ones(7), + }, + "observation.images.base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + } + + # Instantiate policy + policy = PI05Policy(config) + # Test forward pass with dummy data + batch_size = 1 + preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats) + device = config.device + batch = { + "observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device), + "action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device), + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), # Use rand for [0,1] range + "task": ["Pick up the object"] * batch_size, + } + batch = preprocessor(batch) + try: + loss, loss_dict = policy.forward(batch) + print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}") + except Exception as e: + print(f"Forward pass failed: {e}") + raise + try: + with torch.no_grad(): + action = policy.select_action(batch) + action = postprocessor(action) + print(f"Action: {action}") + print(f"Action prediction successful. Action shape: {action.shape}") + except Exception as e: + print(f"Action prediction failed: {e}") + raise + + # Verify pi05 model components exist + # Check that time_mlp layers exist (for AdaRMS conditioning) + assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05" + assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05" + + # Check that action_time_mlp layers don't exist (pi0 only) + assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode" + assert not hasattr(policy.model, "action_time_mlp_out"), ( + "action_time_mlp_out should not exist in pi05 mode" + ) + + # Check that state_proj doesn't exist in pi05 mode + assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode" + + # Check AdaRMS configuration in the underlying model + adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms + assert adarms_config == False, f"PaliGemma should not use AdaRMS, got {adarms_config}" # noqa: E712 + + adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms + assert adarms_expert_config == True, ( # noqa: E712 + f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}" + ) + + +@require_cuda +def test_config_creation(): + """Test policy config creation through factory.""" + try: + config = make_policy_config( + policy_type="pi0", + max_action_dim=7, + max_state_dim=14, + ) + print("Config created successfully through factory") + print(f" Config type: {type(config).__name__}") + print(f" PaliGemma variant: {config.paligemma_variant}") + print(f" Action expert variant: {config.action_expert_variant}") + except Exception as e: + print(f"Config creation failed: {e}") + raise diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py new file mode 100644 index 000000000..7bea89486 --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -0,0 +1,419 @@ +"""Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!""" + +import os +from copy import deepcopy +from typing import Any + +import numpy as np +import pytest +import torch + +# Skip if openpi or transformers is not available +pytest.importorskip("openpi") +pytest.importorskip("transformers") + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local OpenPI installation and is not meant for CI", +) + +from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402 + +# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402 +from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402 +from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 + +# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 200 +DEVICE = "cpu" # Use CPU to avoid memory issues for testing + +DUMMY_DATASET_STATS = { + "observation.state": { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + "q01": torch.zeros(DUMMY_STATE_DIM), + "q99": torch.ones(DUMMY_STATE_DIM), + }, + "action": { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + "q01": torch.zeros(DUMMY_ACTION_DIM), + "q99": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + "base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + "left_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + "right_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + }, +} + + +class PI05BaseOriginalConfig: + action_dim: int = DUMMY_ACTION_DIM + action_horizon: int = DUMMY_ACTION_HORIZON + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + precision: str = "float32" + pi05: bool = True + dtype: str = "float32" + + +def instantiate_lerobot_pi05( + from_pretrained: bool = False, +) -> tuple[ + PI05Policy, + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + if from_pretrained: + # Load the policy first + policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True) + else: + config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32") + policy = PI05Policy(config) + + policy.to(DEVICE) + policy.config.device = DEVICE + preprocessor, postprocessor = make_pi05_pre_post_processors( + config=policy.config, dataset_stats=DUMMY_DATASET_STATS + ) + return (policy, preprocessor, postprocessor) + + +def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None): + config = PI05BaseOriginalConfig() + policy = PI0Pytorch(config) + + if from_pretrained: + try: + print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...") + + # Download the model from HuggingFace Hub + import safetensors.torch + from huggingface_hub import snapshot_download + + # Download the entire repository + if model_path and os.path.exists(model_path): + cache_dir = model_path + print(f"Using cached model from: {cache_dir}") + else: + cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model") + print(f"Downloaded model to: {cache_dir}") + + # Try to load safetensors format first + model_file = os.path.join(cache_dir, "model.safetensors") + if os.path.exists(model_file): + state_dict = safetensors.torch.load_file(model_file) + print(f"Loaded {len(state_dict)} parameters from safetensors") + else: + raise FileNotFoundError(f"No safetensors file found in {cache_dir}") + + # Load the state dict into the model + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + + if missing_keys: + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All pretrained weights loaded successfully!") + else: + print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") + + except Exception as e: + print(f"Failed to load pretrained weights: {e}") + print(" Using randomly initialized weights...") + import traceback + + traceback.print_exc() + + policy.to(DEVICE) + return policy + + +def create_dummy_data(): + batch_size = 2 # Reduce batch size for testing + device = DEVICE + + # Use the exact same prompt for both implementations + prompt = "Pick up the red block and place it in the bin" + + batch = { + "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "action": torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + ), + # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.left_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.right_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + # Add the task prompt for LeRobot - provide as list with single element to trigger expansion + "task": [prompt for _ in range(batch_size)], + } + return batch + + +def extract_lerobot_processed_inputs(lerobot_pi0, batch): + """Extract the exact same processed inputs that LeRobot uses internally.""" + # Get the tokenized language from LeRobot's internal method + lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) + + # Get the preprocessed images from LeRobot's internal method + images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) + + # Create dummy token_ar_mask and token_loss_mask for original implementation + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask + + +class PI05Observation: + """Observation class that matches the original OpenPI format.""" + + def __init__( + self, + state, + images, + image_masks, + tokenized_prompt, + tokenized_prompt_mask, + token_ar_mask, + token_loss_mask, + ): + self.state = state + self.images = images + self.image_masks = image_masks + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = token_ar_mask + self.token_loss_mask = token_loss_mask + + +def create_original_observation_with_openpi_preprocessing(batch): + """Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer.""" + batch_size = batch["observation.state"].shape[0] + device = batch["observation.state"].device + + # Create tokenizer for OpenPI (same as LeRobot uses) + tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + + # Get task description (pi05 processor handles all text formatting) + tasks = batch.get("task", ["Pick up the object"] * batch_size) + if isinstance(tasks, str): + tasks = [tasks] * batch_size + elif len(tasks) == 1: + tasks = tasks * batch_size + + # Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep) + state = batch["observation.state"] + state = deepcopy(state) + + # Prepare state (pad to max_state_dim) + from lerobot.policies.pi05.modeling_pi05 import pad_vector + + state = pad_vector(state, DUMMY_STATE_DIM) + + # Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + # Create pi05-formatted prompts that include state information + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: " + full_prompts.append(full_prompt) + + # Tokenize with max_length padding to match OpenPI's expected format + tokenized = tokenizer( + full_prompts, + padding="max_length", + padding_side="right", + truncation=True, + max_length=DUMMY_MAX_TOKEN_LEN, + return_tensors="pt", + ) + + lang_tokens = tokenized["input_ids"].to(device) + lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + + # Create dummy token_ar_mask and token_loss_mask for OpenPI + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) + image_dict = { + "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, + "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, + "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, + } + + # Create image masks (all ones for real images) + image_masks_dict = {} + for key in image_dict: + image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) + + # Create raw observation object (before preprocessing) + raw_observation = PI05Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + # Now use OpenPI's preprocessing + processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) + + return processed_obs + + +def create_original_observation_from_lerobot(lerobot_pi0, batch): + """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" + _batch_size = batch["observation.state"].shape[0] + _device = batch["observation.state"].device + + # Extract the exact same processed inputs that LeRobot uses + images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( + extract_lerobot_processed_inputs(lerobot_pi0, batch) + ) + + # Convert images list to dict with original OpenPI keys + image_dict = { + "base_0_rgb": images[0], + "left_wrist_0_rgb": images[1], + "right_wrist_0_rgb": images[2], + } + + # Convert image masks list to dict with original OpenPI keys + image_masks_dict = { + "base_0_rgb": img_masks[0], + "left_wrist_0_rgb": img_masks[1], + "right_wrist_0_rgb": img_masks[2], + } + + return PI05Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + +def test_pi05_original_vs_lerobot(): + """Test PI05 original implementation vs LeRobot implementation.""" + print("Initializing models...") + lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05( + from_pretrained=True + ) # Load pretrained LeRobot model + original_pi0 = instantiate_original_pi05( + from_pretrained=True + ) # Load pretrained OpenPI model from HuggingFace Hub + + print("Creating dummy data...") + batch = create_dummy_data() + batch_lerobot = deepcopy(batch) + + # Test each model with its own preprocessing (more realistic end-to-end test) + print("\nTest each model with its own preprocessing") + print("Creating observation for OpenPI using OpenPI's own preprocessing...") + pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) + + print(f"Task prompt: '{batch['task'][0]}'") + print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") + print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") + print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") + + print("Testing OpenPI with own preprocessing...") + original_pi0.eval() + torch.manual_seed(42) # Set seed for reproducibility + batch_size = batch["observation.state"].shape[0] + noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) + fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) + + with torch.no_grad(): + openpi_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 + ) + openpi_actions_unit = openpi_actions[:, 0, :] + print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") + print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}") + print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") + print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") + + print("Testing LeRobot with own preprocessing...") + lerobot_pi05.eval() + torch.manual_seed(42) # Set the same seed + + batch_lerobot_processed = lerobot_preprocessor(batch_lerobot) + with torch.no_grad(): + lerobot_actions_own = lerobot_pi05.predict_action_chunk( + batch_lerobot_processed + ) # batch_size, n_action_steps, action_dim + lerobot_actions_unit = lerobot_actions_own[:, 0, :] + print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") + print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}") + print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") + print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") + + print("\nComparing end-to-end implementations:") + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") + + assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4) + assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2) + assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4 diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py new file mode 100644 index 000000000..d91f716f1 --- /dev/null +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -0,0 +1,410 @@ +"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!""" + +import os +from copy import deepcopy +from typing import Any + +import pytest +import torch + +# Skip if openpi or transformers is not available +pytest.importorskip("openpi") +pytest.importorskip("transformers") + +# Skip this entire module in CI +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires local OpenPI installation and is not meant for CI", +) + +from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402 + +# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions. +from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402 +from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402 +from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 + +# TODO: ADDING DEFAULT IMAGES_FEATURES TO CONFIG +DUMMY_ACTION_DIM = 32 +DUMMY_STATE_DIM = 32 +DUMMY_ACTION_HORIZON = 50 +DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05) +DEVICE = "cpu" # Use CPU to avoid memory issues for testing + +DUMMY_DATASET_STATS = { + "observation.state": { + "mean": torch.zeros(DUMMY_STATE_DIM), + "std": torch.ones(DUMMY_STATE_DIM), + "q01": torch.zeros(DUMMY_STATE_DIM), + "q99": torch.ones(DUMMY_STATE_DIM), + }, + "action": { + "mean": torch.zeros(DUMMY_ACTION_DIM), + "std": torch.ones(DUMMY_ACTION_DIM), + "q01": torch.zeros(DUMMY_ACTION_DIM), + "q99": torch.ones(DUMMY_ACTION_DIM), + }, + "images": { + "base_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + "left_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + "right_wrist_0_rgb": { + "mean": torch.zeros(3, 224, 224), + "std": torch.ones(3, 224, 224), + "q01": torch.zeros(3, 224, 224), + "q99": torch.ones(3, 224, 224), + }, + }, +} + + +class PI0BaseOriginalConfig: + action_dim: int = DUMMY_ACTION_DIM + action_horizon: int = DUMMY_ACTION_HORIZON + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + precision: str = "float32" + pi05: bool = False + dtype: str = "float32" + + +def instantiate_lerobot_pi0( + from_pretrained: bool = False, +) -> tuple[ + PI0Policy, + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + if from_pretrained: + # Load the policy first + policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True) + else: + config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32") + policy = PI0Policy(config) + + policy.to(DEVICE) + policy.config.device = DEVICE + preprocessor, postprocessor = make_pi0_pre_post_processors( + config=policy.config, dataset_stats=DUMMY_DATASET_STATS + ) + return (policy, preprocessor, postprocessor) + + +def instantiate_original_pi0(from_pretrained: bool = False, model_path: str = None): + config = PI0BaseOriginalConfig() + policy = PI0Pytorch(config) + + if from_pretrained: + try: + print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...") + + # Download the model from HuggingFace Hub + import safetensors.torch + from huggingface_hub import snapshot_download + + # Download the entire repository + if model_path and os.path.exists(model_path): + cache_dir = model_path + print(f"Using cached model from: {cache_dir}") + else: + cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model") + print(f"Downloaded model to: {cache_dir}") + + # Try to load safetensors format first + model_file = os.path.join(cache_dir, "model.safetensors") + if os.path.exists(model_file): + state_dict = safetensors.torch.load_file(model_file) + print(f"Loaded {len(state_dict)} parameters from safetensors") + else: + raise FileNotFoundError(f"No safetensors file found in {cache_dir}") + + # Load the state dict into the model + missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False) + + if missing_keys: + print(f"Missing keys: {len(missing_keys)}") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys: {len(unexpected_keys)}") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All pretrained weights loaded successfully!") + else: + print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)") + + except Exception as e: + print(f"Failed to load pretrained weights: {e}") + print(" Using randomly initialized weights...") + import traceback + + traceback.print_exc() + + policy.to(DEVICE) + return policy + + +def create_dummy_data(): + batch_size = 2 # Reduce batch size for testing + device = DEVICE + + # Use the exact same prompt for both implementations + prompt = "Pick up the red block and place it in the bin" + + batch = { + "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "action": torch.randn( + batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device + ), + # Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally) + "observation.images.base_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.left_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + "observation.images.right_wrist_0_rgb": torch.rand( + batch_size, 3, 224, 224, dtype=torch.float32, device=device + ), + # Add the task prompt for LeRobot - provide as list with single element to trigger expansion + "task": [prompt for _ in range(batch_size)], + } + return batch + + +def extract_lerobot_processed_inputs(lerobot_pi0, batch): + """Extract the exact same processed inputs that LeRobot uses internally.""" + # Get the tokenized language from LeRobot's internal method + lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch) + + # Get the preprocessed images from LeRobot's internal method + images, img_masks = lerobot_pi0._preprocess_images(batch, train=False) + + # Create dummy token_ar_mask and token_loss_mask for original implementation + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask + + +class PI0Observation: + """Observation class that matches the original OpenPI format.""" + + def __init__( + self, + state, + images, + image_masks, + tokenized_prompt, + tokenized_prompt_mask, + token_ar_mask, + token_loss_mask, + ): + self.state = state + self.images = images + self.image_masks = image_masks + self.tokenized_prompt = tokenized_prompt + self.tokenized_prompt_mask = tokenized_prompt_mask + self.token_ar_mask = token_ar_mask + self.token_loss_mask = token_loss_mask + + +def create_original_observation_with_openpi_preprocessing(batch): + """Create observation object for OpenPI using OpenPI's own preprocessing.""" + batch_size = batch["observation.state"].shape[0] + device = batch["observation.state"].device + + # Create tokenizer for OpenPI (same as LeRobot uses) + tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + + # Get task description + if "task" in batch: + tasks = batch["task"] + if isinstance(tasks, str): + # Single string: add newline if not present, then convert to list + if not tasks.endswith("\n"): + tasks = f"{tasks}\n" + tasks = [tasks] + elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks): + # List of strings: add newline to each if not present + tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks] + if len(tasks) == 1: + # Expand to batch size + tasks = tasks * batch_size + if len(tasks) != batch_size: + raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}") + # If task is neither string nor list of strings, leave unchanged + else: + # Default task if not provided + tasks = ["Pick up the object\n"] * batch_size + + # Tokenize with max_length padding to match OpenPI's expected format + tokenized = tokenizer( + tasks, + padding="max_length", + padding_side="right", + truncation=True, + max_length=DUMMY_MAX_TOKEN_LEN, + return_tensors="pt", + ) + + lang_tokens = tokenized["input_ids"].to(device) + lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool) + + # Create dummy token_ar_mask and token_loss_mask for OpenPI + token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32) + token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool) + + # Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range) + image_dict = { + "base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0, + "left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0, + "right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0, + } + + # Create image masks (all ones for real images) + image_masks_dict = {} + for key in image_dict: + image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device) + + # Create raw observation object (before preprocessing) + raw_observation = PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + # Now use OpenPI's preprocessing + processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False) + + return processed_obs + + +def create_original_observation_from_lerobot(lerobot_pi0, batch): + """Create observation object compatible with original OpenPI using the exact same inputs as LeRobot.""" + _batch_size = batch["observation.state"].shape[0] + _device = batch["observation.state"].device + + # Extract the exact same processed inputs that LeRobot uses + images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = ( + extract_lerobot_processed_inputs(lerobot_pi0, batch) + ) + + # Convert images list to dict with original OpenPI keys + image_dict = { + "base_0_rgb": images[0], + "left_wrist_0_rgb": images[1], + "right_wrist_0_rgb": images[2], + } + + # Convert image masks list to dict with original OpenPI keys + image_masks_dict = { + "base_0_rgb": img_masks[0], + "left_wrist_0_rgb": img_masks[1], + "right_wrist_0_rgb": img_masks[2], + } + + return PI0Observation( + state=batch["observation.state"], + images=image_dict, + image_masks=image_masks_dict, + tokenized_prompt=lang_tokens, + tokenized_prompt_mask=lang_masks, + token_ar_mask=token_ar_mask, + token_loss_mask=token_loss_mask, + ) + + +def test_pi0_original_vs_lerobot(): + """Test PI0 original implementation vs LeRobot implementation.""" + print("Initializing models...") + lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0( + from_pretrained=True + ) # Load pretrained LeRobot model + original_pi0 = instantiate_original_pi0( + from_pretrained=True + ) # Load pretrained OpenPI model from HuggingFace Hub + + print("Creating dummy data...") + batch = create_dummy_data() + batch_lerobot = deepcopy(batch) + + # Test each model with its own preprocessing (more realistic end-to-end test) + print("\nTest each model with its own preprocessing") + print("Creating observation for OpenPI using OpenPI's own preprocessing...") + pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch) + + print(f"Task prompt: '{batch['task'][0]}'") + print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}") + print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}") + print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}") + + print("Testing OpenPI with own preprocessing...") + original_pi0.eval() + torch.manual_seed(42) # Set seed for reproducibility + batch_size = batch["observation.state"].shape[0] + noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM) + fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE) + + with torch.no_grad(): + openpi_actions = original_pi0.sample_actions( + device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10 + ) + openpi_actions_unit = openpi_actions[:, 0, :] + print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}") + print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}") + print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}") + print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}") + + print("Testing LeRobot with own preprocessing...") + lerobot_pi0.eval() + torch.manual_seed(42) # Set the same seed + + batch_lerobot_processed = lerobot_preprocessor(batch_lerobot) + with torch.no_grad(): + lerobot_actions_own = lerobot_pi0.predict_action_chunk( + batch_lerobot_processed + ) # batch_size, n_action_steps, action_dim + lerobot_actions_unit = lerobot_actions_own[:, 0, :] + print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}") + print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}") + print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}") + print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}") + + print("\nComparing end-to-end implementations:") + print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}") + print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}") + print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}") + + assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4) + assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2) + assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4 diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index 98c9e0b23..208a6b5c5 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -166,6 +166,226 @@ def test_min_max_normalization(observation_normalizer): assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6) +def test_quantile_normalization(): + """Test QUANTILES mode using 1st-99th percentiles.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = { + "observation.state": { + "q01": np.array([0.1, -0.8]), # 1st percentile + "q99": np.array([0.9, 0.8]), # 99th percentile + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check quantile normalization to [-1, 1] + # For state[0]: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 2 * 0.4 / 0.8 - 1 = 0.0 + # For state[1]: 2 * (0.0 - (-0.8)) / (0.8 - (-0.8)) - 1 = 2 * 0.8 / 1.6 - 1 = 0.0 + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_quantile10_normalization(): + """Test QUANTILE10 mode using 10th-90th percentiles.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILE10, + } + stats = { + "observation.state": { + "q10": np.array([0.2, -0.6]), # 10th percentile + "q90": np.array([0.8, 0.6]), # 90th percentile + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check quantile normalization to [-1, 1] + # For state[0]: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 2 * 0.3 / 0.6 - 1 = 0.0 + # For state[1]: 2 * (0.0 - (-0.6)) / (0.6 - (-0.6)) - 1 = 2 * 0.6 / 1.2 - 1 = 0.0 + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_quantile_unnormalization(): + """Test that quantile normalization can be reversed properly.""" + features = { + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.ACTION: NormalizationMode.QUANTILES, + } + stats = { + "action": { + "q01": np.array([0.1, -0.8]), + "q99": np.array([0.9, 0.8]), + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + # Test round-trip normalization + original_action = torch.tensor([0.5, 0.0]) + transition = create_transition(action=original_action) + + # Normalize then unnormalize + normalized = normalizer(transition) + unnormalized = unnormalizer(normalized) + + # Should recover original values + recovered_action = unnormalized[TransitionKey.ACTION] + assert torch.allclose(recovered_action, original_action, atol=1e-6) + + +def test_quantile_division_by_zero(): + """Test quantile normalization handles edge case where q01 == q99.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (1,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = { + "observation.state": { + "q01": np.array([0.5]), # Same value + "q99": np.array([0.5]), # Same value -> division by zero case + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5]), + } + transition = create_transition(observation=observation) + + # Should not crash and should handle gracefully + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # When quantiles are identical, should normalize to 0 (due to epsilon handling) + assert torch.isfinite(normalized_obs["observation.state"]).all() + + +def test_quantile_partial_stats(): + """Test that quantile normalization handles missing quantile stats by raising.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + + # Missing q99 - should pass through unchanged + stats_partial = { + "observation.state": { + "q01": np.array([0.1, -0.8]), # Only q01, missing q99 + }, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"): + _ = normalizer(transition) + + +def test_quantile_mixed_with_other_modes(): + """Test quantile normalization mixed with other normalization modes.""" + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3,)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, # Standard normalization + FeatureType.STATE: NormalizationMode.QUANTILES, # Quantile normalization + FeatureType.ACTION: NormalizationMode.QUANTILE10, # Different quantile mode + } + stats = { + "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]}, + "observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]}, + "action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]}, + } + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), # Should use QUANTILES + } + action = torch.tensor([0.5, 0.0]) # Should use QUANTILE10 + transition = create_transition(observation=observation, action=action) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + normalized_action = normalized_transition[TransitionKey.ACTION] + + # Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc. + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(normalized_obs["observation.image"], expected_image) + + # State should be quantile normalized: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 0.0, etc. + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + # Action should be quantile10 normalized: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 0.0, etc. + expected_action = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_action, expected_action, atol=1e-6) + + +def test_quantile_with_missing_stats(): + """Test that quantile normalization handles completely missing stats gracefully.""" + features = { + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + norm_map = { + FeatureType.STATE: NormalizationMode.QUANTILES, + } + stats = {} # No stats provided + + normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats) + + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Should pass through unchanged when no stats available + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + def test_selective_normalization(observation_stats): features = _create_observation_features() norm_map = _create_observation_norm_map() @@ -547,7 +767,7 @@ def test_empty_stats(): def test_partial_stats(): - """If statistics are incomplete, the value should pass through unchanged.""" + """If statistics are incomplete, we should raise.""" stats = {OBS_IMAGE: {"mean": [0.5]}} # Missing std / (min,max) features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} @@ -555,8 +775,8 @@ def test_partial_stats(): observation = {OBS_IMAGE: torch.tensor([0.7])} transition = create_transition(observation=observation) - processed = normalizer(transition)[TransitionKey.OBSERVATION] - assert torch.allclose(processed[OBS_IMAGE], observation[OBS_IMAGE]) + with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"): + _ = normalizer(transition)[TransitionKey.OBSERVATION] def test_missing_action_stats_no_error(): diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py deleted file mode 100644 index 24afc648f..000000000 --- a/tests/processor/test_pi0_processor.py +++ /dev/null @@ -1,424 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for PI0 policy processor.""" - -from unittest.mock import patch - -import pytest -import torch - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - EnvTransition, - NormalizerProcessorStep, - ProcessorStep, - RenameObservationsProcessorStep, - TransitionKey, - UnnormalizerProcessorStep, -) -from lerobot.processor.converters import create_transition, transition_to_batch -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE - - -class MockTokenizerProcessorStep(ProcessorStep): - """Mock tokenizer processor step for testing.""" - - def __init__(self, *args, **kwargs): - # Accept any arguments to mimic the real TokenizerProcessorStep interface - pass - - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Pass through transition unchanged - return transition - - def transform_features(self, features): - # Pass through features unchanged - return features - - -def create_default_config(): - """Create a default PI0 configuration for testing.""" - config = PI0Config() - config.input_features = { - OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), - OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), - } - config.output_features = { - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), - } - config.normalization_mapping = { - FeatureType.STATE: NormalizationMode.MEAN_STD, - FeatureType.VISUAL: NormalizationMode.IDENTITY, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - config.device = "cpu" - config.tokenizer_max_length = 128 - return config - - -def create_default_stats(): - """Create default dataset statistics for testing.""" - return { - OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, - OBS_IMAGE: {}, # No normalization for images - ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, - } - - -def test_make_pi0_processor_basic(): - """Test basic creation of PI0 processor.""" - config = create_default_config() - stats = create_default_stats() - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - stats, - ) - - # Check processor names - assert preprocessor.name == "policy_preprocessor" - assert postprocessor.name == "policy_postprocessor" - - # Check steps in preprocessor - assert len(preprocessor.steps) == 6 - assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) - assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) - assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor) - # Step 3 would be TokenizerProcessorStep but it's mocked - assert isinstance(preprocessor.steps[4], DeviceProcessorStep) - assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) - - # Check steps in postprocessor - assert len(postprocessor.steps) == 2 - assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) - assert isinstance(postprocessor.steps[1], DeviceProcessorStep) - - -def test_pi0_newline_processor_single_task(): - """Test Pi0NewLineProcessor with single task string.""" - processor = Pi0NewLineProcessor() - - # Test with task that doesn't have newline - transition = create_transition(complementary_data={"task": "test task"}) - result = processor(transition) - assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" - - # Test with task that already has newline - transition = create_transition(complementary_data={"task": "test task\n"}) - result = processor(transition) - assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n" - - -def test_pi0_newline_processor_list_of_tasks(): - """Test Pi0NewLineProcessor with list of task strings.""" - processor = Pi0NewLineProcessor() - - # Test with list of tasks - tasks = ["task1", "task2\n", "task3"] - transition = create_transition(complementary_data={"task": tasks}) - result = processor(transition) - expected = ["task1\n", "task2\n", "task3\n"] - assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected - - -def test_pi0_newline_processor_empty_transition(): - """Test Pi0NewLineProcessor with empty transition.""" - processor = Pi0NewLineProcessor() - - # Test with no complementary_data - transition = create_transition() - result = processor(transition) - assert result == transition - - # Test with complementary_data but no task - transition = create_transition(complementary_data={"other": "data"}) - result = processor(transition) - assert result == transition - - # Test with None task - transition = create_transition(complementary_data={"task": None}) - result = processor(transition) - assert result == transition - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_pi0_processor_cuda(): - """Test PI0 processor with CUDA device.""" - config = create_default_config() - config.device = "cuda" - stats = create_default_stats() - - # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep(ProcessorStep): - def __init__(self, *args, **kwargs): - pass - - def __call__(self, transition): - return transition - - def state_dict(self): - return {} - - def load_state_dict(self, state): - pass - - def reset(self): - pass - - def get_config(self): - return {"tokenizer_name": "google/paligemma-3b-pt-224"} - - def transform_features(self, features): - return features - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - stats, - ) - - # Create CPU data - observation = { - OBS_STATE: torch.randn(10), - OBS_IMAGE: torch.randn(3, 224, 224), - } - action = torch.randn(6) - transition = create_transition(observation, action, complementary_data={"task": "test task"}) - batch = transition_to_batch(transition) - - # Process through preprocessor - processed = preprocessor(batch) - - # Check that data is on CUDA - assert processed[OBS_STATE].device.type == "cuda" - assert processed[OBS_IMAGE].device.type == "cuda" - assert processed[TransitionKey.ACTION.value].device.type == "cuda" - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_pi0_processor_accelerate_scenario(): - """Test PI0 processor in simulated Accelerate scenario.""" - config = create_default_config() - config.device = "cuda:0" - stats = create_default_stats() - - # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep(ProcessorStep): - def __init__(self, *args, **kwargs): - pass - - def __call__(self, transition): - return transition - - def state_dict(self): - return {} - - def load_state_dict(self, state): - pass - - def reset(self): - pass - - def get_config(self): - return {"tokenizer_name": "google/paligemma-3b-pt-224"} - - def transform_features(self, features): - return features - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - stats, - ) - - # Simulate Accelerate: data already on GPU and batched - device = torch.device("cuda:0") - observation = { - OBS_STATE: torch.randn(1, 10).to(device), - OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), - } - action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) - batch = transition_to_batch(transition) - - # Process through preprocessor - processed = preprocessor(batch) - - # Check that data stays on same GPU - assert processed[OBS_STATE].device == device - assert processed[OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION.value].device == device - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") -def test_pi0_processor_multi_gpu(): - """Test PI0 processor with multi-GPU setup.""" - config = create_default_config() - config.device = "cuda:0" - stats = create_default_stats() - - # Mock the tokenizer processor to act as pass-through - class MockTokenizerProcessorStep(ProcessorStep): - def __init__(self, *args, **kwargs): - pass - - def __call__(self, transition): - return transition - - def state_dict(self): - return {} - - def load_state_dict(self, state): - pass - - def reset(self): - pass - - def get_config(self): - return {"tokenizer_name": "google/paligemma-3b-pt-224"} - - def transform_features(self, features): - return features - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - stats, - ) - - # Simulate data on different GPU - device = torch.device("cuda:1") - observation = { - OBS_STATE: torch.randn(1, 10).to(device), - OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device), - } - action = torch.randn(1, 6).to(device) - transition = create_transition(observation, action, complementary_data={"task": ["test task"]}) - batch = transition_to_batch(transition) - - # Process through preprocessor - processed = preprocessor(batch) - - # Check that data stays on cuda:1 - assert processed[OBS_STATE].device == device - assert processed[OBS_IMAGE].device == device - assert processed[TransitionKey.ACTION.value].device == device - - -def test_pi0_processor_without_stats(): - """Test PI0 processor creation without dataset statistics.""" - config = create_default_config() - - # Mock the tokenizer processor - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - dataset_stats=None, - ) - - # Should still create processors - assert preprocessor is not None - assert postprocessor is not None - - -def test_pi0_newline_processor_state_dict(): - """Test Pi0NewLineProcessor state dict methods.""" - processor = Pi0NewLineProcessor() - - # Test state_dict (should be empty) - state = processor.state_dict() - assert state == {} - - # Test load_state_dict (should do nothing) - processor.load_state_dict({}) - - # Test reset (should do nothing) - processor.reset() - - # Test get_config - config = processor.get_config() - assert config == {} - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_pi0_processor_bfloat16_device_float32_normalizer(): - """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation""" - config = create_default_config() - stats = create_default_stats() - config.device = "cuda" - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, _ = make_pi0_pre_post_processors( - config, - stats, - ) - - # Modify the pipeline to use bfloat16 device processor with float32 normalizer - modified_steps = [] - for step in preprocessor.steps: - if isinstance(step, DeviceProcessorStep): - # Device processor converts to bfloat16 - modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16")) - elif isinstance(step, NormalizerProcessorStep): - # Normalizer stays configured as float32 (will auto-adapt to bfloat16) - norm_step = step # Now type checker knows this is NormalizerProcessorStep - modified_steps.append( - NormalizerProcessorStep( - features=norm_step.features, - norm_map=norm_step.norm_map, - stats=norm_step.stats, - device=config.device, - dtype=torch.float32, # Deliberately configured as float32 - ) - ) - else: - modified_steps.append(step) - preprocessor.steps = modified_steps - - # Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5) - normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep - assert normalizer_step.dtype == torch.float32 - - # Create test data with both state and visual observations - observation = { - OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10 - OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32), - } - action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6 - transition = create_transition( - observation, action, complementary_data={"task": "test bfloat16 adaptation"} - ) - batch = transition_to_batch(transition) - - # Process through full pipeline - processed = preprocessor(batch) - - # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16 - assert processed[OBS_STATE].dtype == torch.bfloat16 - assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion - assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16 - - # Verify normalizer automatically adapted its internal state - assert normalizer_step.dtype == torch.bfloat16 - # Check state stats (has normalization) - for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values(): - assert stat_tensor.dtype == torch.bfloat16 - # OBS_IMAGE uses IDENTITY normalization, so no stats to check From 38f6fc816b40cef92bbb32f35539a53a5819ce5d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 2 Oct 2025 15:49:18 +0200 Subject: [PATCH 150/158] (chore) improve v3 message, allow converting local datasets to V3 (#1948) Co-authored-by: CarolinePascal --- .../datasets/backward_compatibility.py | 3 + .../v30/convert_dataset_v21_to_v30.py | 111 ++++++++++++++---- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/src/lerobot/datasets/backward_compatibility.py b/src/lerobot/datasets/backward_compatibility.py index 1d600434a..ae95c5f7b 100644 --- a/src/lerobot/datasets/backward_compatibility.py +++ b/src/lerobot/datasets/backward_compatibility.py @@ -23,6 +23,9 @@ Please, update your dataset to the new format using this command: python -m lerobot.datasets.v30.convert_dataset_v21_to_v30 --repo-id={repo_id} ``` +If you already have a converted version uploaded to the hub, then this error might be because of +an older version in your local cache. Consider deleting the cached version and retrying. + If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). """ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 03d135d7c..42ab2f642 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -26,11 +26,20 @@ This script will help you convert any LeRobot dataset already pushed to the hub Usage: +Convert a dataset from the hub: ```bash python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht ``` +Convert a local dataset (works in place): +```bash +python src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ + --repo-id=lerobot/pusht \ + --root=/path/to/local/dataset/directory + --push-to-hub=false +``` + """ import argparse @@ -75,7 +84,7 @@ from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.utils import init_logging V21 = "v2.1" - +V30 = "v3.0" """ ------------------------- @@ -145,6 +154,17 @@ def legacy_load_tasks(local_dir: Path) -> tuple[dict, dict]: return tasks, task_to_task_index +def validate_local_dataset_version(local_path: Path) -> None: + """Validate that the local dataset has the expected v2.1 version.""" + info = load_info(local_path) + dataset_version = info.get("codebase_version", "unknown") + if dataset_version != V21: + raise ValueError( + f"Local dataset has codebase version '{dataset_version}', expected '{V21}'. " + f"This script is specifically for converting v2.1 datasets to v3.0." + ) + + def convert_tasks(root, new_root): logging.info(f"Converting tasks from {root} to {new_root}") tasks, _ = legacy_load_tasks(root) @@ -407,7 +427,7 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_ def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb): info = load_info(root) - info["codebase_version"] = "v3.0" + info["codebase_version"] = V30 del info["total_chunks"] del info["total_videos"] info["data_files_size_in_mb"] = data_file_size_in_mb @@ -429,16 +449,36 @@ def convert_dataset( branch: str | None = None, data_file_size_in_mb: int | None = None, video_file_size_in_mb: int | None = None, + root: str | Path | None = None, + push_to_hub: bool = True, + force_conversion: bool = False, ): - root = HF_LEROBOT_HOME / repo_id - old_root = HF_LEROBOT_HOME / f"{repo_id}_old" - new_root = HF_LEROBOT_HOME / f"{repo_id}_v30" - if data_file_size_in_mb is None: data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB if video_file_size_in_mb is None: video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB + # First check if the dataset already has a v3.0 version + if root is None and not force_conversion: + try: + print("Trying to download v3.0 version of the dataset from the hub...") + snapshot_download(repo_id, repo_type="dataset", revision=V30, local_dir=HF_LEROBOT_HOME / repo_id) + return + except Exception: + print("Dataset does not have an uploaded v3.0 version. Continuing with conversion.") + + # Set root based on whether local dataset path is provided + use_local_dataset = False + root = HF_LEROBOT_HOME / repo_id if root is None else Path(root) / repo_id + if root.exists(): + validate_local_dataset_version(root) + use_local_dataset = True + print(f"Using local dataset at {root}") + + old_root = root.parent / f"{root.name}_old" + new_root = root.parent / f"{root.name}_v30" + + # Handle old_root cleanup if both old_root and root exist if old_root.is_dir() and root.is_dir(): shutil.rmtree(str(root)) shutil.move(str(old_root), str(root)) @@ -446,12 +486,13 @@ def convert_dataset( if new_root.is_dir(): shutil.rmtree(new_root) - snapshot_download( - repo_id, - repo_type="dataset", - revision=V21, - local_dir=root, - ) + if not use_local_dataset: + snapshot_download( + repo_id, + repo_type="dataset", + revision=V21, + local_dir=root, + ) convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb) convert_tasks(root, new_root) @@ -462,21 +503,22 @@ def convert_dataset( shutil.move(str(root), str(old_root)) shutil.move(str(new_root), str(root)) - hub_api = HfApi() - try: - hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset") - except HTTPError as e: - print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})") - pass - hub_api.delete_files( - delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"], - repo_id=repo_id, - revision=branch, - repo_type="dataset", - ) - hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + if push_to_hub: + hub_api = HfApi() + try: + hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + except HTTPError as e: + print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})") + pass + hub_api.delete_files( + delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"], + repo_id=repo_id, + revision=branch, + repo_type="dataset", + ) + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") - LeRobotDataset(repo_id).push_to_hub() + LeRobotDataset(repo_id).push_to_hub() if __name__ == "__main__": @@ -507,6 +549,23 @@ if __name__ == "__main__": default=None, help="File size in MB. Defaults to 100 for data and 500 for videos.", ) + parser.add_argument( + "--root", + type=str, + default=None, + help="Local directory to use for downloading/writing the dataset.", + ) + parser.add_argument( + "--push-to-hub", + type=lambda input: input.lower() == "true", + default=True, + help="Push the converted dataset to the hub.", + ) + parser.add_argument( + "--force-conversion", + action="store_true", + help="Force conversion even if the dataset already has a v3.0 version.", + ) args = parser.parse_args() convert_dataset(**vars(args)) From 5c8dd883be7518a18fcee33a06418ea13026f8bf Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Thu, 2 Oct 2025 18:28:44 +0200 Subject: [PATCH 151/158] =?UTF-8?q?fix=20bug=20in=20`augment=5Fdataset=5Fq?= =?UTF-8?q?uantile=5Fstats.py`=20that=20was=20not=20detecting=E2=80=A6=20(?= =?UTF-8?q?#2106)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix bug in `augment_dataset_quantile_stats.py` that was not detecting the image features because we were looping over hf_dataset. Now we loop over the dataset itself * Update src/lerobot/datasets/v30/augment_dataset_quantile_stats.py Signed-off-by: Michel Aractingi --------- Signed-off-by: Michel Aractingi Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../v30/augment_dataset_quantile_stats.py | 79 +++++++++++++------ 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py index ff4689efa..900a43a4f 100644 --- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py +++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py @@ -40,10 +40,12 @@ from pathlib import Path import numpy as np import torch +from huggingface_hub import HfApi +from requests import HTTPError from tqdm import tqdm from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats -from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.datasets.utils import write_stats from lerobot.utils.utils import init_logging @@ -85,13 +87,27 @@ def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict: start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"] end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"] + collected_data: dict[str, list] = {} + for idx in range(start_idx, end_idx): + item = dataset[idx] + for key, value in item.items(): + if key not in dataset.features: + continue + + if key not in collected_data: + collected_data[key] = [] + collected_data[key].append(value) + ep_stats = {} - for key, data in dataset.hf_dataset[start_idx:end_idx].items(): + for key, data_list in collected_data.items(): if dataset.features[key]["dtype"] == "string": continue - data = torch.stack(data).cpu().numpy() + data = torch.stack(data_list).cpu().numpy() if dataset.features[key]["dtype"] in ["image", "video"]: + if data.dtype == np.uint8: + data = data.astype(np.float32) / 255.0 + axes_to_reduce = (0, 2, 3) keepdims = True else: @@ -103,12 +119,9 @@ def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict: ) if dataset.features[key]["dtype"] in ["image", "video"]: - for k, v in ep_stats[key].items(): - if dataset.features[key]["dtype"] == "video": - v = v / 255.0 - if k != "count": - v = np.squeeze(v, axis=0) - ep_stats[key][k] = v + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() + } return ep_stats @@ -121,25 +134,39 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic Returns: Dictionary containing aggregated statistics with quantiles + + Note: + Video decoding operations are not thread-safe, so we process episodes sequentially + when video keys are present. For datasets without videos, we use parallel processing + with ThreadPoolExecutor for better performance. """ logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes") episode_stats_list = [] - max_workers = min(dataset.num_episodes, 16) + has_videos = len(dataset.meta.video_keys) > 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_episode = { - executor.submit(process_single_episode, dataset, episode_idx): episode_idx - for episode_idx in range(dataset.num_episodes) - } + if has_videos: + logging.info("Dataset contains video keys - using sequential processing for thread safety") + for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"): + ep_stats = process_single_episode(dataset, episode_idx) + episode_stats_list.append(ep_stats) + else: + logging.info("Dataset has no video keys - using parallel processing for better performance") + max_workers = min(dataset.num_episodes, 16) - episode_results = {} - with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar: - for future in concurrent.futures.as_completed(future_to_episode): - episode_idx = future_to_episode[future] - ep_stats = future.result() - episode_results[episode_idx] = ep_stats - pbar.update(1) + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_episode = { + executor.submit(process_single_episode, dataset, episode_idx): episode_idx + for episode_idx in range(dataset.num_episodes) + } + + episode_results = {} + with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar: + for future in concurrent.futures.as_completed(future_to_episode): + episode_idx = future_to_episode[future] + ep_stats = future.result() + episode_results[episode_idx] = ep_stats + pbar.update(1) for episode_idx in range(dataset.num_episodes): if episode_idx in episode_results: @@ -186,6 +213,14 @@ def augment_dataset_with_quantile_stats( logging.info("Successfully updated dataset with quantile statistics") dataset.push_to_hub() + hub_api = HfApi() + try: + hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + except HTTPError as e: + logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})") + pass + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset") + def main(): """Main function to run the augmentation script.""" From a4bed41132a3f5ffb05579003a8fece26d188e56 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 3 Oct 2025 12:06:18 +0200 Subject: [PATCH 152/158] Improve docs pi (#2110) * Improve docs and add numpy to pi install requirments * fix formatting * update command * remvoe numpy dep --- docs/source/pi0.mdx | 2 +- docs/source/pi05.mdx | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 10260ee72..d36fe0ce4 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -49,7 +49,7 @@ policy.type=pi0 For training π₀, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/train.py \ +python src/lerobot/scripts/lerobot_train.py \ --dataset.repo_id=your_dataset \ --policy.type=pi0 \ --output_dir=./outputs/pi0_training \ diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index b777fcd58..b6267fc5e 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -51,13 +51,13 @@ policy.type=pi05 Here's a complete training command for finetuning the base π₀.₅ model on your own dataset: ```bash -python src/lerobot/scripts/train.py \ +python src/lerobot/scripts/lerobot_train.py\ --dataset.repo_id=your_dataset \ --policy.type=pi05 \ - --output_dir=./outputs/pi0_training \ - --job_name=pi0_training \ - --policy.repo_id=lerobot/pi05_base \ - --policy.pretrained_path=your_repo_id \ + --output_dir=./outputs/pi05_training \ + --job_name=pi05_training \ + --policy.repo_id=your_repo_id \ + --policy.pretrained_path=lerobot/pi05_base \ --policy.compile_model=true \ --policy.gradient_checkpointing=true \ --wandb.enable=true \ @@ -77,6 +77,15 @@ python src/lerobot/scripts/train.py \ - [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base) - [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset) +If your dataset is not converted with `quantiles`, you can convert it with the following command: + +```bash +python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \ + --repo-id=your_dataset \ +``` + +Or train pi05 with this normalization mapping: `--policy.normalization_mapping='{"ACTION": "MEAN_STD", "STATE": "MEAN_STD", "VISUAL": "IDENTITY"}'` + ## Performance Results ### Libero Benchmark Results From b74e2a61133b695eca35997334f22389321ed6db Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 5 Oct 2025 17:53:43 +0200 Subject: [PATCH 153/158] feat(deps): ceil dependency versions (#2091) --- .github/workflows/unbound_deps_tests.yml | 183 +++++++++++++++++++++++ docker/Dockerfile.internal | 8 + docker/Dockerfile.user | 8 + pyproject.toml | 64 ++++---- 4 files changed, 231 insertions(+), 32 deletions(-) create mode 100644 .github/workflows/unbound_deps_tests.yml diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml new file mode 100644 index 000000000..902074a83 --- /dev/null +++ b/.github/workflows/unbound_deps_tests.yml @@ -0,0 +1,183 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This workflow handles full testing with unboud dependencies versions. +name: Unbound Dependency Tests + +on: + # Allows running this workflow manually from the Actions tab + workflow_dispatch: + + # Run on the 1st and 15th of every month at 09:00 UTC + schedule: + - cron: '0 2 1,15 * *' + +permissions: + contents: read + +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + DOCKER_IMAGE_NAME: huggingface/lerobot-gpu:unbound + +# Ensures that only the latest action is built, canceling older runs. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + + # This job runs the E2E tests + pytest with all unbound extras + full-tests: + name: Full Unbound Tests + runs-on: ubuntu-latest + env: + MUJOCO_GL: egl + steps: + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + + - name: Install apt dependencies + run: | + sudo apt-get update && sudo apt-get install -y build-essential \ + git curl libglib2.0-0 libegl1-mesa-dev ffmpeg libusb-1.0-0-dev \ + speech-dispatcher libgeos-dev portaudio19-dev + + - name: Setup uv and Python + uses: astral-sh/setup-uv@v6 # zizmor: ignore[unpinned-uses] + with: + enable-cache: true + version: ${{ env.UV_VERSION }} + python-version: ${{ env.PYTHON_VERSION }} + + - name: Unbound dependencies + run: | + sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml + echo "Dependencies unbound:" && cat pyproject.toml + + - name: Install lerobot with all extras + run: uv sync --all-extras + + - name: Run pytest (all extras) + run: uv run pytest tests -vv + + - name: Run end-to-end tests + run: uv run make test-end-to-end + + # This job builds a GPU enabled image for testing + build-and-push-docker: + name: Build and Push Docker + runs-on: + group: aws-general-8-plus + outputs: + image_tag: ${{ env.DOCKER_IMAGE_NAME }} + env: + GITHUB_REF: ${{ github.ref }} + steps: + - name: Install Git LFS + run: | + sudo apt-get update + sudo apt-get install git-lfs + git lfs install + - uses: actions/checkout@v4 + with: + lfs: true + persist-credentials: false + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 # zizmor: ignore[unpinned-uses] + with: + cache-binary: false + - name: Login to Docker Hub + uses: docker/login-action@v3 # zizmor: ignore[unpinned-uses] + with: + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + - name: Build and push Docker image + uses: docker/build-push-action@v6 # zizmor: ignore[unpinned-uses] + with: + context: . + file: ./docker/Dockerfile.internal + push: true + tags: ${{ env.DOCKER_IMAGE_NAME }} + build-args: | + UNBOUND_DEPS=true + + # This job runs pytest with all unbound extras in a GPU enabled host + # It runs everytime a test image is created + gpu-tests: + name: GPU Unbound Tests + needs: [build-and-push-docker] + runs-on: + group: aws-g6-4xlarge-plus + env: + HF_HOME: /home/user_lerobot/.cache/huggingface + HF_LEROBOT_HOME: /home/user_lerobot/.cache/huggingface/lerobot + TORCH_HOME: /home/user_lerobot/.cache/torch + TRITON_CACHE_DIR: /home/user_lerobot/.cache/triton + container: + image: ${{ needs.build-and-push-docker.outputs.image_tag }} # zizmor: ignore[unpinned-images] + options: --gpus all --shm-size "16gb" + credentials: + username: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + password: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + defaults: + run: + shell: bash + working-directory: /lerobot + steps: + - name: Run pytest on GPU + run: pytest tests -vv + - name: Run end-to-end tests + run: make test-end-to-end + + # This job deletes the test image recently created + # It runs everytime after the gpu-tests have finished + delete-unbound-image: + name: Delete Unbound Image + needs: [gpu-tests, build-and-push-docker] + if: always() && needs.build-and-push-docker.result == 'success' + runs-on: ubuntu-latest + steps: + - name: Get Docker Hub Token and Delete Image + # zizmor: ignore[template-injection] + run: | + IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1) + IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2) + + echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" + + TOKEN=$(curl -s -H "Content-Type: application/json" \ + -X POST \ + -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \ + https://hub.docker.com/v2/users/login/ | jq -r .token) + + if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then + echo "::error::Failed to get Docker Hub token." + exit 1 + fi + + HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ + -H "Authorization: JWT ${TOKEN}" \ + -X DELETE \ + https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/) + + if [ "$HTTP_RESPONSE" -eq 204 ]; then + echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" + else + echo "::error::Failed to delete Docker image. HTTP status: $HTTP_RESPONSE" + exit 1 + fi diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index 52becb830..2616cd06c 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -75,6 +75,14 @@ RUN uv venv --python python${PYTHON_VERSION} # Install Python dependencies for caching COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ + +ARG UNBOUND_DEPS=false + +RUN if [ "$UNBOUND_DEPS" = "true" ]; then \ + sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \ + echo "Dependencies unbound:" && cat pyproject.toml; \ + fi + RUN uv pip install --no-cache ".[all]" # Copy the rest of the application source code diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index 59fd3e0b3..c1b284453 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -61,6 +61,14 @@ RUN uv venv # Install Python dependencies for caching COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ + +ARG UNBOUND_DEPS=false + +RUN if [ "$UNBOUND_DEPS" = "true" ]; then \ + sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \ + echo "Dependencies unbound:" && cat pyproject.toml; \ + fi + RUN uv pip install --no-cache ".[all]" # Copy the rest of the application code diff --git a/pyproject.toml b/pyproject.toml index f350fac0a..c67b481f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,20 +59,20 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici dependencies = [ # Hugging Face dependencies - "datasets>=4.0.0", - "diffusers>=0.27.2", - "huggingface-hub[hf-transfer,cli]>=0.34.2", + "datasets>=4.0.0,<4.2.0", + "diffusers>=0.27.2,<0.36.0", + "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", # Core dependencies - "cmake>=3.29.0.1", - "einops>=0.8.0", - "opencv-python-headless>=4.9.0", - "av>=14.2.0", - "jsonlines>=4.0.0", - "packaging>=24.2", - "pynput>=1.7.7", - "pyserial>=3.5", - "wandb>=0.20.0", + "cmake>=3.29.0.1,<4.2.0", + "einops>=0.8.0,<0.9.0", + "opencv-python-headless>=4.9.0,<4.13.0", + "av>=14.2.0,<16.0.0", + "jsonlines>=4.0.0,<5.0.0", + "packaging>=24.2,<26.0", + "pynput>=1.7.7,<1.9.0", + "pyserial>=3.5,<4.0", + "wandb>=0.20.0,<0.23.0", "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency @@ -92,26 +92,26 @@ dependencies = [ [project.optional-dependencies] # Common -pygame-dep = ["pygame>=2.5.1"] -placo-dep = ["placo>=0.9.6"] -transformers-dep = ["transformers>=4.53.0"] +pygame-dep = ["pygame>=2.5.1,<2.7.0"] +placo-dep = ["placo>=0.9.6,<0.10.0"] +transformers-dep = ["transformers>=4.53.0,<5.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # Motors -feetech = ["feetech-servo-sdk>=1.0.0"] -dynamixel = ["dynamixel-sdk>=3.7.31"] +feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] +dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] # Robots -gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0"] +gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] -lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1"] -reachy2 = ["reachy2_sdk>=1.0.14"] +lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] +reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"] kinematics = ["lerobot[placo-dep]"] intelrealsense = [ - "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", - "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", + "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", + "pyrealsense2-macosx>=2.54,<2.55.0 ; sys_platform == 'darwin'", ] -phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] +phone = ["hebi-py>=2.8.0,<2.12.0", "teleop>=0.1.0,<0.2.0"] # stretch = [ # "hello-robot-stretch-body>=0.7.27 ; sys_platform == 'linux'", # "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", @@ -120,21 +120,21 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"] # Policies pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] -smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] -hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features -async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] +async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] # Development -dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] -test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] -video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] +dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] +test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"] +video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation -aloha = ["gym-aloha>=0.1.1"] -pusht = ["gym-pusht>=0.1.5", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -xarm = ["gym-xarm>=0.1.1"] +aloha = ["gym-aloha>=0.1.1,<0.2.0"] +pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead +xarm = ["gym-xarm>=0.1.1,<0.2.0"] libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] From 5ac9356135400f38255e772578ef92b150f4de2a Mon Sep 17 00:00:00 2001 From: Iulia Feroli Date: Tue, 7 Oct 2025 09:43:32 +0200 Subject: [PATCH 154/158] Update README.md to fix broken link to example notebook for visuals (#2117) Folder structure of examples seems to have changed with extra `dataset` folder and the notebook has also changed names. Signed-off-by: Iulia Feroli Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a59f96deb..357e62cc1 100644 --- a/README.md +++ b/README.md @@ -197,7 +197,7 @@ wandb login ### Visualize datasets -Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. +Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/dataset/load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: From fcaa0ea5f9b4ba3fee3360589a6b7a1ac12d6ff9 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Tue, 7 Oct 2025 14:09:36 +0200 Subject: [PATCH 155/158] remove extra time base set. (#2133) Co-authored-by: CarolinePascal --- src/lerobot/datasets/video_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 2c0e116cb..1d4f07c76 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -451,11 +451,6 @@ def concatenate_video_files( stream_map[input_stream.index] = output_container.add_stream_from_template( template=input_stream, opaque=True ) - stream_map[ - input_stream.index - ].time_base = ( - input_stream.time_base - ) # set the time base to the input stream time base (missing in the codec context) # Demux + remux packets (no re-encode) for packet in input_container.demux(): From 9f32e00f9046ea65cad0bd543866862e27039138 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:10:31 +0200 Subject: [PATCH 156/158] fix(async): Add pre and post processing to async inference and update docs (#2132) * Add pre and post processing to async inference and update docs * precommit fix typo * fix tests * refactor(async): no None branching for processors in _predict_action_chunk --------- Co-authored-by: Steven Palma --- docs/source/async.mdx | 16 +-- src/lerobot/async_inference/constants.py | 2 +- src/lerobot/async_inference/helpers.py | 6 +- src/lerobot/async_inference/policy_server.py | 116 ++++++++++++------- src/lerobot/async_inference/robot_client.py | 5 +- tests/async_inference/test_e2e.py | 3 + tests/async_inference/test_helpers.py | 28 ++--- tests/async_inference/test_policy_server.py | 3 + 8 files changed, 103 insertions(+), 76 deletions(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index c66cdb143..be10f8baf 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -31,15 +31,15 @@ Then, spin up a policy server (in one terminal, or in a separate machine) specif You can spin up a policy server running: ```shell -python src/lerobot/async_inference/policy_server.py \ - --host=127.0.0.1 \ - --port=8080 \ +python -m lerobot.async_inference.policy_server \ + --host=127.0.0.1 \ + --port=8080 ``` This will start a policy server listening on `127.0.0.1:8080` (`localhost`, port 8080). At this stage, the policy server is empty, as all information related to which policy to run and with which parameters are specified during the first handshake with the client. Spin up a client with: ```shell -python src/lerobot/async_inference/robot_client.py \ +python -m lerobot.async_inference.robot_client \ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server --robot.type=so100_follower \ # ROBOT: your robot type --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port @@ -113,9 +113,9 @@ As such, spinning up a policy server is as easy as specifying the host address a ```bash -python -m lerobot.scripts.server.policy_server \ - --host="localhost" \ - --port=8080 +python -m lerobot.async_inference.policy_server \ + --host=127.0.0.1 \ + --port=8080 ``` @@ -148,7 +148,7 @@ The `RobotClient` streams observations to the `PolicyServer`, and receives actio ```bash -python src/lerobot/async_inference/robot_client.py \ +python -m lerobot.async_inference.robot_client \ --server_address=127.0.0.1:8080 \ # SERVER: the host address and port of the policy server --robot.type=so100_follower \ # ROBOT: your robot type --robot.port=/dev/tty.usbmodem585A0076841 \ # ROBOT: your robot port diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index 5ebf3780c..1b1dac0f5 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2 SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] # TODO: Add all other robots -SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"] +SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower"] diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 88fb00a3f..54fad8c54 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -92,11 +92,11 @@ def resize_robot_observation_image(image: torch.tensor, resize_dims: tuple[int, return resized.squeeze(0) +# TODO(Steven): Consider implementing a pipeline step for this def raw_observation_to_observation( raw_observation: RawObservation, lerobot_features: dict[str, dict], policy_image_features: dict[str, PolicyFeature], - device: str, ) -> Observation: observation = {} @@ -105,9 +105,7 @@ def raw_observation_to_observation( if isinstance(v, torch.Tensor): # VLAs present natural-language instructions in observations if "image" in k: # Policy expects images in shape (B, C, H, W) - observation[k] = prepare_image(v).unsqueeze(0).to(device) - else: - observation[k] = v.to(device) + observation[k] = prepare_image(v).unsqueeze(0) else: observation[k] = v diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index 125727060..f7e00dea4 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -15,7 +15,7 @@ """ Example: ```shell -python src/lerobot/async_inference/policy_server.py \ +python -m lerobot.async_inference.policy_server \ --host=127.0.0.1 \ --port=8080 \ --fps=30 \ @@ -32,12 +32,17 @@ from concurrent import futures from dataclasses import asdict from pprint import pformat from queue import Empty, Queue +from typing import Any import draccus import grpc import torch -from lerobot.policies.factory import get_policy_class +from lerobot.policies.factory import get_policy_class, make_pre_post_processors +from lerobot.processor import ( + PolicyAction, + PolicyProcessorPipeline, +) from lerobot.transport import ( services_pb2, # type: ignore services_pb2_grpc, # type: ignore @@ -82,6 +87,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): self.lerobot_features = None self.actions_per_chunk = None self.policy = None + self.preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]] | None = None + self.postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction] | None = None @property def running(self): @@ -146,6 +153,16 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): start = time.perf_counter() self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path) self.policy.to(self.device) + + # Load preprocessor and postprocessor, overriding device to match requested device + device_override = {"device": self.device} + self.preprocessor, self.postprocessor = make_pre_post_processors( + self.policy.config, + pretrained_path=policy_specs.pretrained_name_or_path, + preprocessor_overrides={"device_processor": device_override}, + postprocessor_overrides={"device_processor": device_override}, + ) + end = time.perf_counter() self.logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds") @@ -173,7 +190,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): # Calculate FPS metrics fps_metrics = self.fps_tracker.calculate_fps_metrics(obs_timestamp) - self.logger.info( + self.logger.debug( f"Received observation #{obs_timestep} | " f"Avg FPS: {fps_metrics['avg_fps']:.2f} | " # fps at which observations are received from client f"Target: {fps_metrics['target_fps']:.2f} | " @@ -189,7 +206,7 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): if not self._enqueue_observation( timed_observation # wrapping a RawObservation ): - self.logger.info(f"Observation #{obs_timestep} has been filtered out") + self.logger.debug(f"Observation #{obs_timestep} has been filtered out") return services_pb2.Empty() @@ -301,23 +318,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): for i, action in enumerate(action_chunk) ] - def _prepare_observation(self, observation_t: TimedObservation) -> Observation: - """ - Prepare observation, ready for policy inference. - E.g.: To keep observation sampling rate high (and network packet tiny) we send int8 [0,255] images from the - client and then convert them to float32 [0,1] images here, before running inference. - """ - # RawObservation from robot.get_observation() - wrong keys, wrong dtype, wrong image shape - observation: Observation = raw_observation_to_observation( - observation_t.get_observation(), - self.lerobot_features, - self.policy_image_features, - self.device, - ) - # processed Observation - right keys, right dtype, right image shape - - return observation - def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor: """Get an action chunk from the policy. The chunk contains only""" chunk = self.policy.predict_action_chunk(observation) @@ -327,44 +327,76 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): return chunk[:, : self.actions_per_chunk, :] def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]: - """Predict an action chunk based on an observation""" - inference_starts = time.perf_counter() + """Predict an action chunk based on an observation. + Pipeline: + 1. Convert raw observation to LeRobot format + 2. Apply preprocessor (tokenization, normalization, batching, device placement) + 3. Run policy inference to get action chunk + 4. Apply postprocessor (unnormalization, device movement) + 5. Convert to TimedAction list + """ """1. Prepare observation""" - start_time = time.perf_counter() - observation = self._prepare_observation(observation_t) - preprocessing_time = time.perf_counter() - start_time + start_prepare = time.perf_counter() + observation: Observation = raw_observation_to_observation( + observation_t.get_observation(), + self.lerobot_features, + self.policy_image_features, + ) + prepare_time = time.perf_counter() - start_prepare + """2. Apply preprocessor""" + start_preprocess = time.perf_counter() + observation = self.preprocessor(observation) self.last_processed_obs: TimedObservation = observation_t + preprocessing_time = time.perf_counter() - start_preprocess - """2. Get action chunk""" - start_time = time.perf_counter() + """3. Get action chunk""" + start_inference = time.perf_counter() action_tensor = self._get_action_chunk(observation) - inference_time = time.perf_counter() - start_time + inference_time = time.perf_counter() - start_inference + self.logger.info( + f"Preprocessing and inference took {inference_time:.4f}s, action shape: {action_tensor.shape}" + ) - """3. Post-inference processing""" - start_time = time.perf_counter() - # Move to CPU before serializing - action_tensor = action_tensor.cpu().squeeze(0) + """4. Apply postprocessor""" + # Apply postprocessor (handles unnormalization and device movement) + # Postprocessor expects (B, action_dim) per action, but we have (B, chunk_size, action_dim) + # So we process each action in the chunk individually + start_postprocess = time.perf_counter() + _, chunk_size, _ = action_tensor.shape + # Process each action in the chunk + processed_actions = [] + for i in range(chunk_size): + # Extract action at timestep i: (B, action_dim) + single_action = action_tensor[:, i, :] + processed_action = self.postprocessor(single_action) + processed_actions.append(processed_action) + + # Stack back to (B, chunk_size, action_dim), then remove batch dim + action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) + self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") + + """5. Convert to TimedAction list""" action_chunk = self._time_action_chunk( observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() ) - postprocessing_time = time.perf_counter() - start_time - inference_stops = time.perf_counter() + postprocess_stops = time.perf_counter() + postprocessing_time = postprocess_stops - start_postprocess self.logger.info( - f"Observation {observation_t.get_timestep()} |" - f"Inference time: {1000 * (inference_stops - inference_starts):.2f}ms" + f"Observation {observation_t.get_timestep()} | " + f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms" ) - # full-process latency breakdown for debugging purposes self.logger.debug( f"Observation {observation_t.get_timestep()} | " - f"Preprocessing time: {1000 * (preprocessing_time - inference_starts):.2f}ms | " - f"Inference time: {1000 * (inference_time - preprocessing_time):.2f}ms | " - f"Postprocessing time: {1000 * (postprocessing_time - inference_time):.2f}ms | " - f"Total time: {1000 * (postprocessing_time - inference_starts):.2f}ms" + f"Prepare time: {1000 * prepare_time:.2f}ms | " + f"Preprocessing time: {1000 * preprocessing_time:.2f}ms | " + f"Inference time: {1000 * inference_time:.2f}ms | " + f"Postprocessing time: {1000 * postprocessing_time:.2f}ms | " + f"Total time: {1000 * (postprocess_stops - start_prepare):.2f}ms" ) return action_chunk diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index c969bc605..8c4425c6b 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -52,6 +52,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so100_follower, koch_follower, make_robot_from_config, so100_follower, @@ -214,7 +215,7 @@ class RobotClient: ) _ = self.stub.SendObservations(observation_iterator) obs_timestep = obs.get_timestep() - self.logger.info(f"Sent observation #{obs_timestep} | ") + self.logger.debug(f"Sent observation #{obs_timestep} | ") return True @@ -467,7 +468,7 @@ class RobotClient: if self._ready_to_send_observation(): _captured_observation = self.control_loop_observation(task, verbose) - self.logger.info(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}") + self.logger.debug(f"Control loop (ms): {(time.perf_counter() - control_loop_start) * 1000:.2f}") # Dynamically adjust sleep time to maintain the desired control frequency time.sleep(max(0, self.config.environment_dt - (time.perf_counter() - control_loop_start))) diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 2689f0618..ebaef2ef1 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -91,6 +91,9 @@ def test_async_inference_e2e(monkeypatch): policy_server.policy = MockPolicy() policy_server.actions_per_chunk = 20 policy_server.device = "cpu" + # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix. + policy_server.preprocessor = lambda obs: obs + policy_server.postprocessor = lambda tensor: tensor # Set up robot config and features robot_config = MockRobotConfig() diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py index acf5870d5..1e2d1e311 100644 --- a/tests/async_inference/test_helpers.py +++ b/tests/async_inference/test_helpers.py @@ -333,9 +333,8 @@ def test_raw_observation_to_observation_basic(): robot_obs = _create_mock_robot_observation() lerobot_features = _create_mock_lerobot_features() policy_image_features = _create_mock_policy_image_features() - device = "cpu" - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) # Check that all expected keys are present assert OBS_STATE in observation @@ -345,7 +344,6 @@ def test_raw_observation_to_observation_basic(): # Check state processing state = observation[OBS_STATE] assert isinstance(state, torch.Tensor) - assert state.device.type == device assert state.shape == (1, 4) # Batched # Check image processing @@ -356,10 +354,6 @@ def test_raw_observation_to_observation_basic(): assert laptop_img.shape == (1, 3, 224, 224) assert phone_img.shape == (1, 3, 160, 160) - # Check device placement - assert laptop_img.device.type == device - assert phone_img.device.type == device - # Check image dtype and range (should be float32 in [0, 1]) assert laptop_img.dtype == torch.float32 assert phone_img.dtype == torch.float32 @@ -374,9 +368,8 @@ def test_raw_observation_to_observation_with_non_tensor_data(): lerobot_features = _create_mock_lerobot_features() policy_image_features = _create_mock_policy_image_features() - device = "cpu" - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) # Check that task string is preserved assert "task" in observation @@ -386,19 +379,17 @@ def test_raw_observation_to_observation_with_non_tensor_data(): @torch.no_grad() def test_raw_observation_to_observation_device_handling(): - """Test that tensors are properly moved to the specified device.""" - device = "mps" if torch.backends.mps.is_available() else "cpu" - + """Test that tensors are created (device placement is handled by preprocessor).""" robot_obs = _create_mock_robot_observation() lerobot_features = _create_mock_lerobot_features() policy_image_features = _create_mock_policy_image_features() - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) - # Check that all tensors are on the correct device + # Check that all expected keys produce tensors (device placement handled by preprocessor later) for key, value in observation.items(): if isinstance(value, torch.Tensor): - assert value.device.type == device, f"Tensor {key} not on {device}" + assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device" def test_raw_observation_to_observation_deterministic(): @@ -406,11 +397,10 @@ def test_raw_observation_to_observation_deterministic(): robot_obs = _create_mock_robot_observation() lerobot_features = _create_mock_lerobot_features() policy_image_features = _create_mock_policy_image_features() - device = "cpu" # Run twice with same input - obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) - obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, device) + obs1 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) + obs2 = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) # Results should be identical assert set(obs1.keys()) == set(obs2.keys()) @@ -448,7 +438,7 @@ def test_image_processing_pipeline_preserves_content(): ) } - observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features, "cpu") + observation = raw_observation_to_observation(robot_obs, lerobot_features, policy_image_features) processed_img = observation[f"{OBS_IMAGES}.laptop"].squeeze(0) # Remove batch dim diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index de441ff09..29583d4fa 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -196,6 +196,9 @@ def test_predict_action_chunk(monkeypatch, policy_server): # Force server to act-style policy; patch method to return deterministic tensor policy_server.policy_type = "act" + # NOTE(Steven): Smelly tests as the Server is a state machine being partially mocked. Adding these processors as a quick fix. + policy_server.preprocessor = lambda obs: obs + policy_server.postprocessor = lambda tensor: tensor action_dim = 6 batch_size = 1 actions_per_chunk = policy_server.actions_per_chunk From bf3c8746b72aedcd23c29d4a3d1ee53df810ec22 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 7 Oct 2025 17:46:22 +0200 Subject: [PATCH 157/158] feat(devices): add lazy loading for 3rd party robots cameras and teleoperators (#2123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(devices): add lazy loading for 3rd party robots cameras and teleoperators Co-authored-by: Darko Lukić * feat(devices): load device class based on assumptions in naming * docs(devices): instructions for using 3rd party devices * docs: address review feedback * chore(docs): add example for 3rd party devices --------- Co-authored-by: Darko Lukić --- docs/source/integrate_hardware.mdx | 128 +++++++++++++++++++++ src/lerobot/cameras/utils.py | 11 +- src/lerobot/robots/utils.py | 10 +- src/lerobot/scripts/lerobot_calibrate.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 2 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + src/lerobot/teleoperators/utils.py | 9 +- src/lerobot/utils/import_utils.py | 94 +++++++++++++++ 9 files changed, 255 insertions(+), 5 deletions(-) diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index 7b2e3833f..7e7fe0bff 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -335,6 +335,134 @@ For implementing teleoperation devices, we also provide a [`Teleoperator`](https The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods. +## Using Your Own `LeRobot` Devices 🔌 + +You can easily extend `lerobot` with your own custom hardware—be it a camera, robot, or teleoperation device—by creating a separate, installable Python package. If you follow a few simple conventions, the `lerobot` command-line tools (like `lerobot-teleop` and `lerobot-record`) will **automatically discover and integrate your creations** without requiring any changes to the `lerobot` source code. + +This guide outlines the conventions your plugin must follow. + +### The 4 Core Conventions + +To ensure your custom device is discoverable, you must adhere to the following four rules. + +#### 1\. Create an Installable Package with a Specific Prefix + +Your project must be a standard, installable Python package. Crucially, the name of your package (as defined in `pyproject.toml` or `setup.py`) must begin with one of these prefixes: + +- `lerobot_robot_` for a robot. +- `lerobot_camera_` for a camera. +- `lerobot_teleoperator_` for a teleoperation device. + +This prefix system is how `lerobot` automatically finds your plugin in the Python environment. + +#### 2\. Follow the `SomethingConfig`/`Something` Naming Pattern + +Your device's implementation class must be named after its configuration class, simply by removing the `Config` suffix. + +- **Config Class:** `MyAwesomeTeleopConfig` +- **Device Class:** `MyAwesomeTeleop` + +#### 3\. Place Your Files in a Predictable Structure + +The device class (`MyAwesomeTeleop`) must be located in a predictable module relative to its configuration class (`MyAwesomeTeleopConfig`). `lerobot` will automatically search in these locations: + +- In the **same module** as the config class. +- In a **submodule named after the device** (e.g., `my_awesome_teleop.py`). + +The recommended and simplest structure is to place them in separate, clearly named files within the same directory. + +#### 4\. Expose Classes in `__init__.py` + +Your package's `__init__.py` file should import and expose both the configuration and the device classes, making them easily accessible. + +### Putting It All Together: A Complete Example + +Let's create a new teleoperator called `my_awesome_teleop`. + +#### Directory Structure + +Here is what the project folder should look like. The package name, `lerobot_teleoperator_my_awesome_teleop`, follows **Convention \#1**. + +``` +lerobot_teleoperator_my_awesome_teleop/ +├── pyproject.toml # (or setup.py) lists lerobot as a dependency +└── lerobot_teleoperator_my_awesome_teleop/ + ├── __init__.py + ├── config_my_awesome_teleop.py + └── my_awesome_teleop.py +``` + +#### File Contents + +- **`config_my_awesome_teleop.py`**: Defines the configuration class. Note the `Config` suffix (**Convention \#2**). + + ```python + from dataclasses import dataclass + + from lerobot.teleoperators.config import TeleoperatorConfig + + @TeleoperatorConfig.register_subclass("my_awesome_teleop") + @dataclass + class MyAwesomeTeleopConfig(TeleoperatorConfig): + # Your configuration fields go here + port: str = "192.168.1.1" + ``` + +- **`my_awesome_teleop.py`**: Implements the device. The class name `MyAwesomeTeleop` matches its config class name (**Convention \#2**). This file structure adheres to **Convention \#3**. + + ```python + from lerobot.teleoperators.teleoperator import Teleoperator + + from .config_my_awesome_teleop import MyAwesomeTeleopConfig + + class MyAwesomeTeleop(Teleoperator): + config_class = MyAwesomeTeleopConfig + name = "my_awesome_teleop" + + def __init__(self, config: MyAwesomeTeleopConfig): + super().__init__(config) + self.config = config + + # Your device logic (e.g., connect) goes here + ``` + +- **`__init__.py`**: Exposes the key classes (**Convention \#4**). + + ```python + from .config_my_awesome_teleop import MyAwesomeTeleopConfig + from .my_awesome_teleop import MyAwesomeTeleop + ``` + +### Installation and Usage + +1. **Install your new plugin in your Python environment.** You can install your local plugin package using `pip`'s editable mode or from PyPi. + + ```bash + # Locally + # Navigate to your plugin's root directory and install it + cd lerobot_teleoperator_my_awesome_teleop + pip install -e . + + # From PyPi + pip install lerobot_teleoperator_my_awesome_teleop + ``` + +2. **Use it directly from the command line.** Now, you can use your custom device by referencing its type. + + ```bash + lerobot-teleoperate --teleop.type=my_awesome_teleop \ + # other arguments + ``` + +And that's it\! Your custom device is now fully integrated. + +### Looking for an example ? + +Check out these two packages from the community: + +- https://github.com/SpesRobotics/lerobot-robot-xarm +- https://github.com/SpesRobotics/lerobot-teleoperator-teleop + ## Wrapping Up Once your robot class is complete, you can leverage the LeRobot ecosystem: diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index 4a23843b2..aa6ff98b4 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -15,15 +15,19 @@ # limitations under the License. import platform +from typing import cast + +from lerobot.utils.import_utils import make_device_from_device_class from .camera import Camera from .configs import CameraConfig, Cv2Rotation def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]: - cameras = {} + cameras: dict[str, Camera] = {} for key, cfg in camera_configs.items(): + # TODO(Steven): Consider just using the make_device_from_device_class for all types if cfg.type == "opencv": from .opencv import OpenCVCamera @@ -40,7 +44,10 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s cameras[key] = Reachy2Camera(cfg) else: - raise ValueError(f"The camera type '{cfg.type}' is not valid.") + try: + cameras[key] = cast(Camera, make_device_from_device_class(cfg)) + except Exception as e: + raise ValueError(f"Error creating camera {key} with config {cfg}: {e}") from e return cameras diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 0455bce3f..aca5c8716 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -14,13 +14,16 @@ import logging from pprint import pformat +from typing import cast -from lerobot.robots import RobotConfig +from lerobot.utils.import_utils import make_device_from_device_class +from .config import RobotConfig from .robot import Robot def make_robot_from_config(config: RobotConfig) -> Robot: + # TODO(Steven): Consider just using the make_device_from_device_class for all types if config.type == "koch_follower": from .koch_follower import KochFollower @@ -66,7 +69,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: return MockRobot(config) else: - raise ValueError(config.type) + try: + return cast(Robot, make_device_from_device_class(config)) + except Exception as e: + raise ValueError(f"Error creating robot with config {config}: {e}") from e # TODO(pepijn): Move to pipeline step to make sure we don't have to do this in the robot code and send action to robot is clean for use in dataset diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 0aa61a2f9..0f247caef 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -52,6 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401 so100_leader, so101_leader, ) +from lerobot.utils.import_utils import register_third_party_devices from lerobot.utils.utils import init_logging @@ -83,6 +84,7 @@ def calibrate(cfg: CalibrateConfig): def main(): + register_third_party_devices() calibrate() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index ddb21e917..55846ff63 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -117,6 +117,7 @@ from lerobot.utils.control_utils import ( sanity_check_dataset_name, sanity_check_dataset_robot_compatibility, ) +from lerobot.utils.import_utils import register_third_party_devices from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( get_safe_torch_device, @@ -513,6 +514,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: def main(): + register_third_party_devices() record() diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index b899745b6..ffd7b2b22 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -61,6 +61,7 @@ from lerobot.robots import ( # noqa: F401 so101_follower, ) from lerobot.utils.constants import ACTION +from lerobot.utils.import_utils import register_third_party_devices from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import ( init_logging, @@ -126,6 +127,7 @@ def replay(cfg: ReplayConfig): def main(): + register_third_party_devices() replay() diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index ab9a6361d..0a418f3bc 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -88,6 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401 so100_leader, so101_leader, ) +from lerobot.utils.import_utils import register_third_party_devices from lerobot.utils.robot_utils import busy_wait from lerobot.utils.utils import init_logging, move_cursor_up from lerobot.utils.visualization_utils import init_rerun, log_rerun_data @@ -215,6 +216,7 @@ def teleoperate(cfg: TeleoperateConfig): def main(): + register_third_party_devices() teleoperate() diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index bad7d9c37..ada7ee8a1 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -13,6 +13,9 @@ # limitations under the License. from enum import Enum +from typing import cast + +from lerobot.utils.import_utils import make_device_from_device_class from .config import TeleoperatorConfig from .teleoperator import Teleoperator @@ -29,6 +32,7 @@ class TeleopEvents(Enum): def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: + # TODO(Steven): Consider just using the make_device_from_device_class for all types if config.type == "keyboard": from .keyboard import KeyboardTeleop @@ -82,4 +86,7 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: return Reachy2Teleoperator(config) else: - raise ValueError(config.type) + try: + return cast(Teleoperator, make_device_from_device_class(config)) + except Exception as e: + raise ValueError(f"Error creating robot with config {config}: {e}") from e diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 5f41ea3a3..de43e58db 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -15,6 +15,10 @@ # limitations under the License. import importlib import logging +import pkgutil +from typing import Any + +from draccus.choice_types import ChoiceRegistry def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: @@ -58,3 +62,93 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _transformers_available = is_package_available("transformers") + + +def make_device_from_device_class(config: ChoiceRegistry) -> Any: + """ + Dynamically instantiates an object from its `ChoiceRegistry` configuration. + + This factory uses the module path and class name from the `config` object's + type to locate and instantiate the corresponding device class (not the config). + It derives the device class name by removing a trailing 'Config' from the config + class name and tries a few candidate modules where the device implementation is + commonly located. + """ + if not isinstance(config, ChoiceRegistry): + raise ValueError(f"Config should be an instance of `ChoiceRegistry`, got {type(config)}") + + config_cls = config.__class__ + module_path = config_cls.__module__ # typical: lerobot_teleop_mydevice.config_mydevice + config_name = config_cls.__name__ # typical: MyDeviceConfig + + # Derive device class name (strip "Config") + if not config_name.endswith("Config"): + raise ValueError(f"Config class name '{config_name}' does not end with 'Config'") + + device_class_name = config_name[:-6] # typical: MyDeviceConfig -> MyDevice + + # Build candidate modules to search for the device class + parts = module_path.split(".") + parent_module = ".".join(parts[:-1]) if len(parts) > 1 else module_path + candidates = [ + parent_module, # typical: lerobot_teleop_mydevice + parent_module + "." + device_class_name.lower(), # typical: lerobot_teleop_mydevice.mydevice + ] + + # handle modules named like "config_xxx" -> try replacing that piece with "xxx" + last = parts[-1] if parts else "" + if last.startswith("config_"): + candidates.append(".".join(parts[:-1] + [last.replace("config_", "")])) + + # de-duplicate while preserving order + seen: set[str] = set() + candidates = [c for c in candidates if not (c in seen or seen.add(c))] + + tried: list[str] = [] + for candidate in candidates: + tried.append(candidate) + try: + module = importlib.import_module(candidate) + except ImportError: + continue + + if hasattr(module, device_class_name): + cls = getattr(module, device_class_name) + if callable(cls): + try: + return cls(config) + except TypeError as e: + raise TypeError( + f"Failed to instantiate '{device_class_name}' from module '{candidate}': {e}" + ) from e + + raise ImportError( + f"Could not locate device class '{device_class_name}' for config '{config_name}'. " + f"Tried modules: {tried}. Ensure your device class name is the config class name without " + f"'Config' and that it's importable from one of those modules." + ) + + +def register_third_party_devices() -> None: + """ + Discover and import third-party lerobot_* plugins so they can register themselves. + + Scans top-level modules on sys.path for packages starting with + 'lerobot_robot_', 'lerobot_camera_' or 'lerobot_teleoperator_' and imports them. + """ + prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_") + imported: list[str] = [] + failed: list[str] = [] + + for module_info in pkgutil.iter_modules(): + name = module_info.name + if name.startswith(prefixes): + try: + importlib.import_module(name) + imported.append(name) + logging.info("Imported third-party plugin: %s", name) + except Exception: + logging.exception("Could not import third-party plugin: %s", name) + failed.append(name) + + logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed) From 6c28ef894af215bfaaf665aa3015c5645d91e53f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 8 Oct 2025 14:27:52 +0200 Subject: [PATCH 158/158] chore(docs): add missing license headers (#2140) --- src/lerobot/motors/__init__.py | 16 ++++++++++++++++ src/lerobot/processor/policy_robot_bridge.py | 16 ++++++++++++++++ src/lerobot/robots/__init__.py | 16 ++++++++++++++++ tests/plugins/reachy2_sdk.py | 16 ++++++++++++++++ tests/policies/pi0_pi05/test_pi0.py | 14 ++++++++++++++ tests/policies/pi0_pi05/test_pi05.py | 14 ++++++++++++++ .../pi0_pi05/test_pi05_original_vs_lerobot.py | 16 ++++++++++++++++ .../pi0_pi05/test_pi0_original_vs_lerobot.py | 16 ++++++++++++++++ tests/processor/test_batch_conversion.py | 16 ++++++++++++++++ tests/processor/test_converters.py | 16 ++++++++++++++++ tests/processor/test_tokenizer_processor.py | 16 ++++++++++++++++ tests/utils/test_io_utils.py | 5 ++++- tests/utils/test_logging_utils.py | 5 ++++- tests/utils/test_random_utils.py | 5 ++++- tests/utils/test_train_utils.py | 5 ++++- tests/utils/test_visualization_utils.py | 16 ++++++++++++++++ 16 files changed, 204 insertions(+), 4 deletions(-) diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index dfbfbaee8..850ef33d7 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -1 +1,17 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus diff --git a/src/lerobot/processor/policy_robot_bridge.py b/src/lerobot/processor/policy_robot_bridge.py index 845ee065a..25887d414 100644 --- a/src/lerobot/processor/policy_robot_bridge.py +++ b/src/lerobot/processor/policy_robot_bridge.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict, dataclass from typing import Any diff --git a/src/lerobot/robots/__init__.py b/src/lerobot/robots/__init__.py index d8fd0de93..1dba0f1b0 100644 --- a/src/lerobot/robots/__init__.py +++ b/src/lerobot/robots/__init__.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .config import RobotConfig from .robot import Robot from .utils import make_robot_from_config diff --git a/tests/plugins/reachy2_sdk.py b/tests/plugins/reachy2_sdk.py index f56b59efb..457fcf0f9 100644 --- a/tests/plugins/reachy2_sdk.py +++ b/tests/plugins/reachy2_sdk.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import sys import types from unittest.mock import MagicMock diff --git a/tests/policies/pi0_pi05/test_pi0.py b/tests/policies/pi0_pi05/test_pi0.py index 65f64e6bc..b580310eb 100644 --- a/tests/policies/pi0_pi05/test_pi0.py +++ b/tests/policies/pi0_pi05/test_pi0.py @@ -1,5 +1,19 @@ #!/usr/bin/env python +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi05.py b/tests/policies/pi0_pi05/test_pi05.py index 72828a02f..964539446 100644 --- a/tests/policies/pi0_pi05/test_pi05.py +++ b/tests/policies/pi0_pi05/test_pi05.py @@ -1,5 +1,19 @@ #!/usr/bin/env python +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py index 7bea89486..0d5244e1c 100644 --- a/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi05_original_vs_lerobot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0OpenPI policy integration with LeRobot vs the original implementation, only meant to be run locally!""" import os diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py index d91f716f1..41db2dceb 100644 --- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py +++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!""" import os diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py index 88b873128..477381618 100644 --- a/tests/processor/test_batch_conversion.py +++ b/tests/processor/test_batch_conversion.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from lerobot.processor import DataProcessorPipeline, TransitionKey diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index bc58f7a61..47a6eea18 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import numpy as np import pytest import torch diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index b81710db1..d6f87f567 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Tests for the TokenizerProcessorStep class. """ diff --git a/tests/utils/test_io_utils.py b/tests/utils/test_io_utils.py index 9768a5ef9..0beea639d 100644 --- a/tests/utils/test_io_utils.py +++ b/tests/utils/test_io_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json from pathlib import Path from typing import Any diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 927fdc14d..560ba5701 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pytest from lerobot.utils.logging_utils import AverageMeter, MetricsTracker diff --git a/tests/utils/test_random_utils.py b/tests/utils/test_random_utils.py index 5865361d0..e3a5d420f 100644 --- a/tests/utils/test_random_utils.py +++ b/tests/utils/test_random_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import random import numpy as np diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 0eeaf907c..892503e97 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -1,4 +1,6 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from pathlib import Path from unittest.mock import Mock, patch diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 65a97c6a3..08a827570 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -1,3 +1,19 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import importlib import sys from types import SimpleNamespace