diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index f9c251a02..9032d2667 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -24,72 +24,108 @@ import torch
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import dataset_to_policy_features
-from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
+from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
+from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
 from lerobot.configs.types import FeatureType
 
 
+def inject_normalization_stats(policy, dataset):
+    """Load the dataset's normalization stats into the policy's state dictionary.
+
+    Overwrites the mean/std buffers used to normalize `observation.state` and to
+    normalize/unnormalize `action` with the values from `dataset.meta.stats`.
+
+    Args:
+        policy: Policy whose state dict holds the normalization buffers listed below.
+        dataset: Dataset exposing per-feature statistics as `dataset.meta.stats`.
+
+    Raises:
+        KeyError: If a statistic or a buffer listed below is missing.
+    """
+    stats = dataset.meta.stats
+    pol_state_dict = policy.state_dict()
+
+    keys_to_update = {
+        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
+        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
+        "normalize_targets.buffer_action.mean": ("action", "mean"),
+        "normalize_targets.buffer_action.std": ("action", "std"),
+        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
+        "unnormalize_outputs.buffer_action.std": ("action", "std"),
+    }
+
+    for pol_key, (stat_key, stat_type) in keys_to_update.items():
+        # `as_tensor` also accepts stats that are already tensors; keep the existing buffer's dtype.
+        buffer_dtype = pol_state_dict[pol_key].dtype
+        pol_state_dict[pol_key] = torch.as_tensor(stats[stat_key][stat_type], dtype=buffer_dtype)
+
+    policy.load_state_dict(pol_state_dict)
+    print("Normalization stats injected into the policy.")
+
+
+def prepare_batch(batch, device):
+    """Prepare a batch of dataset samples for the policy's forward pass.
+
+    Moves the tensors to `device` and renames the dataset's camera keys
+    (`observation.images.top`, `observation.images.wrist`) to `observation.image`
+    and `observation.image2`. Entries not listed below are dropped.
+
+    Args:
+        batch: Batch yielded by the dataloader.
+        device: Device the tensors are moved to.
+
+    Returns:
+        A new dict with the state, the two images, the action and the task entry.
+    """
+    return {
+        "observation.state": batch["observation.state"].to(device),
+        "observation.image": batch["observation.images.top"].to(device),
+        "observation.image2": batch["observation.images.wrist"].to(device),
+        "action": batch["action"].to(device),
+        # Passed through as-is (no `.to(device)`).
+        "task": batch["task"],
+    }
+
+
 def main():
     # Create a directory to store the training checkpoint.
-    output_directory = Path("outputs/train/example_pusht_diffusion")
+    output_directory = Path("outputs/train/smolvlaplus_training")
     output_directory.mkdir(parents=True, exist_ok=True)
 
     # # Select your device
-    device = torch.device("cuda")
+    # Pick the best available backend rather than hard-coding one, so the example runs anywhere.
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
 
     # Number of offline training steps (we'll only do offline training for this example.)
     # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-    training_steps = 5000
+    training_steps = 10_000
     log_freq = 1
+    batch_size = 32
 
     # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
     # creating the policy:
     #   - input/output shapes: to properly size the policy
     #   - dataset stats: for normalization and denormalization of input/outputs
-    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
-    features = dataset_to_policy_features(dataset_metadata.features)
-    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
-    input_features = {key: ft for key, ft in features.items() if key not in output_features}
+    dataset = LeRobotDataset("lerobot/svla_so100_stacking")
+    policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
 
-    # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
-    # we'll just use the defaults and so no arguments other than input/output features need to be passed.
-    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
-
-    # We can now instantiate our policy with this config and the dataset stats.
-    policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
+    # fix absence of normalization stats in the policy
+    inject_normalization_stats(policy, dataset)
     policy.train()
     policy.to(device)
 
-    # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
-    # which can differ for inputs, outputs and rewards (if there are some).
-    delta_timestamps = {
-        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
-        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
-        "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
-    }
-
-    # In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
-    delta_timestamps = {
-        # Load the previous image and state at -0.1 seconds before current frame,
-        # then load current image and state corresponding to 0.0 second.
-        "observation.image": [-0.1, 0.0],
-        "observation.state": [-0.1, 0.0],
-        # Load the previous action (-0.1), the next action to be executed (0.0),
-        # and 14 future actions with a 0.1 seconds spacing. All these actions will be
-        # used to supervise the policy.
-        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-    }
-
-    # We can then instantiate the dataset with these delta_timestamps configuration.
-    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
-
     # Then we create our optimizer and dataloader for offline training.
-    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+    optimizer = torch.optim.AdamW(policy.parameters(), lr=3e-4)
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=64,
+        batch_size=batch_size,
         shuffle=True,
         pin_memory=device.type != "cpu",
         drop_last=True,
     )
@@ -99,7 +135,7 @@ def main():
     done = False
     while not done:
         for batch in dataloader:
-            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
+            batch = prepare_batch(batch, device)
             loss, _ = policy.forward(batch)
             loss.backward()
             optimizer.step()