mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4bb7281752 |
+48
-42
@@ -24,72 +24,78 @@ import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
from lerobot.configs.types import FeatureType
|
||||
|
||||
def inject_normalization_stats(policy, dataset):
|
||||
"""Manually loads normalization stats from the dataset into the policy's state dictionary."""
|
||||
stats = dataset.meta.stats
|
||||
pol_state_dict = policy.state_dict()
|
||||
|
||||
keys_to_update = {
|
||||
"normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
|
||||
"normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
|
||||
"normalize_targets.buffer_action.mean": ("action", "mean"),
|
||||
"normalize_targets.buffer_action.std": ("action", "std"),
|
||||
"unnormalize_outputs.buffer_action.mean": ("action", "mean"),
|
||||
"unnormalize_outputs.buffer_action.std": ("action", "std"),
|
||||
}
|
||||
|
||||
for pol_key, (stat_key, stat_type) in keys_to_update.items():
|
||||
pol_state_dict[pol_key] = torch.from_numpy(stats[stat_key][stat_type])
|
||||
|
||||
policy.load_state_dict(pol_state_dict)
|
||||
print("Normalization stats injected into the policy.")
|
||||
|
||||
def prepare_batch(batch, device):
|
||||
"""
|
||||
Prepares a batch of samples from the dataset for inference.
|
||||
This involves moving tensors to the correct device,
|
||||
and remapping image keys to match the policy's expectations.
|
||||
"""
|
||||
batch = {
|
||||
"observation.state": batch["observation.state"].to(device),
|
||||
"observation.image": batch["observation.images.top"].to(device),
|
||||
"observation.image2": batch["observation.images.wrist"].to(device),
|
||||
"action": batch["action"].to(device),
|
||||
"task": batch["task"],
|
||||
}
|
||||
return batch
|
||||
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
||||
output_directory = Path("outputs/train/smolvlaplus_training")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# # Select your device
|
||||
device = torch.device("cuda")
|
||||
device = torch.device("mps")
|
||||
|
||||
# Number of offline training steps (we'll only do offline training for this example.)
|
||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||
training_steps = 5000
|
||||
training_steps = 10_000
|
||||
log_freq = 1
|
||||
batch_size = 32
|
||||
|
||||
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
||||
# creating the policy:
|
||||
# - input/output shapes: to properly size the policy
|
||||
# - dataset stats: for normalization and denormalization of input/outputs
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
dataset = LeRobotDataset("lerobot/svla_so100_stacking")
|
||||
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
# fix absence of normalization stats in the policy
|
||||
inject_normalization_stats(policy, dataset)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
||||
# which can differ for inputs, outputs and rewards (if there are some).
|
||||
delta_timestamps = {
|
||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
||||
}
|
||||
|
||||
# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
|
||||
delta_timestamps = {
|
||||
# Load the previous image and state at -0.1 seconds before current frame,
|
||||
# then load current image and state corresponding to 0.0 second.
|
||||
"observation.image": [-0.1, 0.0],
|
||||
"observation.state": [-0.1, 0.0],
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
|
||||
# Then we create our optimizer and dataloader for offline training.
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
optimizer = torch.optim.AdamW(policy.parameters(), lr=3e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=64,
|
||||
shuffle=True,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
@@ -99,7 +105,7 @@ def main():
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
batch = prepare_batch(batch, device)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user