# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This script demonstrates how to train Diffusion Policy on the PushT environment."""

from pathlib import Path

import torch

from lerobot.configs import FeatureType
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.policies import make_pre_post_processors
from lerobot.policies.diffusion import DiffusionConfig, DiffusionPolicy
from lerobot.utils.feature_utils import dataset_to_policy_features


def main():
    """Train a Diffusion Policy on the `lerobot/pusht` dataset and save a checkpoint.

    The policy weights and the pre/post-processors are written to
    `outputs/train/example_pusht_diffusion`. Training runs for a fixed number
    of optimizer steps and prints the loss every `log_freq` steps.
    """
    # Create a directory to store the training checkpoint.
    output_directory = Path("outputs/train/example_pusht_diffusion")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Select your device. Fall back to the CPU when CUDA is not available so the
    # example does not crash on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Number of offline training steps (we'll only do offline training for this example.)
    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
    training_steps = 5000
    # Print the loss every `log_freq` steps (1 = every step).
    log_freq = 1

    # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
    # creating the policy:
    #   - input/output shapes: to properly size the policy
    #   - dataset stats: for normalization and denormalization of input/outputs
    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
    features = dataset_to_policy_features(dataset_metadata.features)
    # Action features are the policy outputs; every other feature is an input.
    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    input_features = {key: ft for key, ft in features.items() if key not in output_features}

    # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
    # we'll just use the defaults and so no arguments other than input/output features need to be passed.
    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)

    # We can now instantiate our policy with this config and the dataset stats.
    policy = DiffusionPolicy(cfg)
    policy.train()
    policy.to(device)
    preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

    # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number
    # of frames which can differ for inputs, outputs and rewards (if there are some).
    delta_timestamps = {
        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
        "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
    }

    # In this case with the standard configuration for Diffusion Policy, it is equivalent to this
    # (the explicit values below replace the computed ones above and are what the dataset receives):
    delta_timestamps = {
        # Load the previous image and state at -0.1 seconds before current frame,
        # then load current image and state corresponding to 0.0 second.
        "observation.image": [-0.1, 0.0],
        "observation.state": [-0.1, 0.0],
        # Load the previous action (-0.1), the next action to be executed (0.0),
        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
        # used to supervise the policy.
        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
    }

    # We can then instantiate the dataset with these delta_timestamps configuration.
    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

    # Then we create our optimizer and dataloader for offline training.
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=64,
        shuffle=True,
        # Pinned host memory is only requested when training on a non-CPU device.
        pin_memory=device.type != "cpu",
        drop_last=True,
    )

    # Run training loop. The outer `while` restarts the dataloader (a new epoch) until
    # `training_steps` optimizer steps have been taken.
    step = 0
    done = False
    while not done:
        for batch in dataloader:
            # NOTE(review): the batch is never moved to `device` explicitly here, so the
            # preprocessor is presumably responsible for device placement — confirm it
            # targets the same device as the policy.
            batch = preprocessor(batch)
            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

    # Save a policy checkpoint.
    policy.save_pretrained(output_directory)
    preprocessor.save_pretrained(output_directory)
    postprocessor.save_pretrained(output_directory)


if __name__ == "__main__":
    main()