Files
lerobot/tests/artifacts/policies/save_policy_to_safetensors.py
T
Steven Palma ce665160ae feat(processor): multiple improvements to the pipeline porting (#1749)
* [Port codebase pipeline] General fixes for RL and scripts (#1748)

* Refactor dataset configuration in documentation and codebase

- Updated dataset configuration keys from `dataset_root` to `root` and `num_episodes` to `num_episodes_to_record` for consistency.
- Adjusted replay episode handling by renaming `episode` to `replay_episode`.
- Enhanced documentation
- added specific processor to transform from policy actions to delta actions

* Added Robot action to tensor processor
Added new processor script for dealing with gym specific action processing

* removed RobotAction2Tensor processor; improved choosing observations in actor

* nit in delta action

* added missing reset functions to kinematics

* Adapt teleoperate and replay to pipeline similar to record

* refactor(processors): move to inheritance (#1750)

* fix(teleoperator): improvements phone implementation (#1752)

* fix(teleoperator): protect shared state in phone implementation

* refactor(teleop): separate classes in phone

* fix: solve breaking changes (#1753)

* refactor(policies): multiple improvements (#1754)

* refactor(processor): simpler logic in device processor (#1755)

* refactor(processor): euclidean distance in delta action processor (#1757)

* refactor(processor): improvements to joint observations processor migration (#1758)

* refactor(processor): improvements to tokenizer migration (#1759)

* refactor(processor): improvements to tokenizer migration

* fix(tests): tokenizer tests regression from #1750

* fix(processors): fix float comparison and config in hil processors (#1760)

* chore(teleop): remove unnecessary callbacks in KeyboardEndEffectorTeleop (#1761)

* refactor(processor): improvements normalize pipeline migration (#1756)

* refactor(processor): several improvements normalize processor step

* refactor(processor): more improvements normalize processor

* refactor(processor): more changes to normalizer

* refactor(processor): take a different approach to DRY

* refactor(processor): final design

* chore(record): revert comment and continue deleted (#1764)

* refactor(examples): pipeline phone examples (#1769)

* refactor(examples): phone teleop + teleop script

* refactor(examples): phone replay + replay

* chore(examples): rename phone example files & folders

* feat(processor): fix improvements to the pipeline porting (#1796)

* refactor(processor): enhance tensor device handling in normalization process (#1795)

* refactor(tests): remove unsupported device detection test for complementary data (#1797)

* chore(tests): update ToBatchProcessor test (#1798)

* refactor(tests): remove in-place mutation tests for actions and complementary data in batch processor

* test(tests): add tests for action and task processing in batch processor

* add names for android and ios phone (#1799)

* use _tensor_stats in normalize processor (#1800)

* fix(normalize_processor): correct device reference for tensor epsilon handling (#1801)

* add point 5 add missing feature contracts (#1806)

* Fix PR comments 1452 (#1807)

* use key to determine image

* Address rest of PR comments

* use PolicyFeatures in transform_features

---------

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
2025-08-31 20:38:52 +02:00

156 lines
5.9 KiB
Python

#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path
import torch
from safetensors.torch import save_file
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies.factory import make_policy, make_policy_config, make_pre_post_processors
from lerobot.processor import TransitionKey
from lerobot.utils.random_utils import set_seed
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    """Run one optimization step and a short action rollout, collecting tensors along the way.

    The function calls ``set_seed(1337)``, builds a training config for
    ``policy_name`` on episode 0 of ``ds_repo_id``, performs a single
    forward/backward/optimizer step on the first batch, and then queries
    ``policy.select_action`` repeatedly on the first batch's observations.

    Args:
        ds_repo_id: Dataset repository id forwarded to ``DatasetConfig``.
        policy_name: Policy type name forwarded to ``make_policy_config``.
        policy_kwargs: Extra keyword arguments forwarded to ``make_policy_config``.

    Returns:
        A 4-tuple ``(output_dict, grad_stats, param_stats, actions)``:
        ``output_dict`` holds the tensor entries of the forward output plus ``"loss"``;
        ``grad_stats`` and ``param_stats`` map ``"<param>_mean"`` / ``"<param>_std"`` to tensors
        (gradients are read before ``optimizer.step()``, parameters after it);
        ``actions`` maps the stringified step index to the post-processed action.
    """

    def _std_or_zero(tensor):
        # Fall back to a constant 0.0 tensor when there are not at least two elements.
        if tensor.numel() > 1:
            return tensor.std()
        return torch.tensor(0.0)

    set_seed(1337)

    # TODO(rcadene, aliberts): remove dataset download
    dataset_cfg = DatasetConfig(repo_id=ds_repo_id, episodes=[0])
    policy_cfg = make_policy_config(policy_name, push_to_hub=False, **policy_kwargs)
    train_cfg = TrainPipelineConfig(dataset=dataset_cfg, policy=policy_cfg)
    train_cfg.validate()  # Needed for auto-setting some parameters

    dataset = make_dataset(train_cfg)
    stats = dataset.meta.stats
    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
    preprocessor, postprocessor = make_pre_post_processors(train_cfg.policy, dataset_stats=stats)
    policy.train()
    optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)

    dataloader = torch.utils.data.DataLoader(
        dataset, num_workers=0, batch_size=train_cfg.batch_size, shuffle=False
    )

    # --- Single training step on the first batch ---
    first_batch = preprocessor(next(iter(dataloader)))
    loss, forward_out = policy.forward(first_batch)
    if forward_out is None:
        output_dict = {"loss": loss}
    else:
        # Only tensor-valued entries are kept.
        output_dict = {name: value for name, value in forward_out.items() if isinstance(value, torch.Tensor)}
        output_dict["loss"] = loss

    loss.backward()

    # Gradient statistics are gathered before the optimizer updates the weights.
    grad_stats = {}
    for name, param in policy.named_parameters():
        if not param.requires_grad:
            continue
        grad_stats[f"{name}_mean"] = param.grad.mean()
        grad_stats[f"{name}_std"] = _std_or_zero(param.grad)

    optimizer.step()

    # Parameter statistics are gathered after the update, for every parameter.
    param_stats = {}
    for name, param in policy.named_parameters():
        param_stats[f"{name}_mean"] = param.mean()
        param_stats[f"{name}_std"] = _std_or_zero(param)

    optimizer.zero_grad()
    policy.reset()

    # --- Action rollout ---
    # HACK: `select_action` won't expect a timestamps dimension, so a batch is reloaded with
    # delta_indices set to None. This simulates an environment from the dataset; padding
    # indicator tensors (keys ending with "_is_pad") are dropped below.
    dataset.delta_indices = None
    rollout_batch = next(iter(dataloader))

    # TODO: regenerate the safetensors
    # The "_is_pad" and "task" exclusions are kept for backward compatibility.
    obs = {
        key: rollout_batch[key]
        for key in rollout_batch
        if not key.endswith("_is_pad") and key != "task" and key.startswith("observation")
    }

    rollout_cfg = train_cfg.policy
    if hasattr(rollout_cfg, "n_action_steps"):
        n_queries = rollout_cfg.n_action_steps
    else:
        n_queries = rollout_cfg.n_action_repeats

    actions = {}
    for step in range(n_queries):
        policy_action = policy.select_action(obs).contiguous()
        transition = postprocessor({TransitionKey.ACTION: policy_action})
        actions[str(step)] = transition.get(TransitionKey.ACTION)

    return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
    """Compute policy statistics and write them as safetensors files into ``output_dir``.

    An existing ``output_dir`` is announced on stdout and removed first, then the
    directory is (re)created. Four files are written: ``output_dict.safetensors``,
    ``grad_stats.safetensors``, ``param_stats.safetensors`` and ``actions.safetensors``.

    Args:
        output_dir: Destination directory for the safetensors files.
        ds_repo_id: Dataset repository id forwarded to ``get_policy_stats``.
        policy_name: Policy type name forwarded to ``get_policy_stats``.
        policy_kwargs: Policy keyword arguments forwarded to ``get_policy_stats``.
    """
    if output_dir.exists():
        overwrite_notice = (
            f"Overwrite existing safetensors in '{output_dir}':",
            f" - Validate with: `git add {output_dir}`",
            f" - Revert with: `git checkout -- {output_dir}`",
        )
        for line in overwrite_notice:
            print(line)
        shutil.rmtree(output_dir)

    output_dir.mkdir(parents=True, exist_ok=True)

    output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)

    # File name -> tensors, written in the same order as the values are returned.
    tensors_by_file = {
        "output_dict.safetensors": output_dict,
        "grad_stats.safetensors": grad_stats,
        "param_stats.safetensors": param_stats,
        "actions.safetensors": actions,
    }
    for file_name, tensors in tensors_by_file.items():
        save_file(tensors, output_dir / file_name)
if __name__ == "__main__":
    # Each entry: (dataset repo id, policy name, policy kwargs, output directory suffix).
    artifacts_cfg = [
        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
        (
            "lerobot/pusht",
            "diffusion",
            {"n_action_steps": 8, "num_inference_steps": 10, "down_dims": [128, 256, 512]},
            "",
        ),
        ("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
        (
            "lerobot/aloha_sim_insertion_human",
            "act",
            {"n_action_steps": 1000, "chunk_size": 1000},
            "1000_steps",
        ),
    ]
    if not artifacts_cfg:
        raise RuntimeError("No policies were provided!")

    artifacts_root = Path("tests/artifacts/policies")
    for repo_id, policy_type, kwargs, suffix in artifacts_cfg:
        # The directory name is "<dataset>_<policy>_<suffix>"; the trailing underscore is
        # kept even when the suffix is empty.
        dataset_name = repo_id.split("/")[-1]
        target_dir = artifacts_root / f"{dataset_name}_{policy_type}_{suffix}"
        save_policy_to_safetensors(target_dir, repo_id, policy_type, kwargs)