Compare commits

...

1 commit

Author             SHA1        Message  Date
Francesco Capuano  4bb7281752  wip      2025-06-15 00:24:03 +02:00
+48 -42
@@ -24,72 +24,78 @@ import torch
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import dataset_to_policy_features
-from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
+from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
+from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
 from lerobot.configs.types import FeatureType
+
+def inject_normalization_stats(policy, dataset):
+    """Manually loads normalization stats from the dataset into the policy's state dictionary."""
+    stats = dataset.meta.stats
+    pol_state_dict = policy.state_dict()
+    keys_to_update = {
+        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
+        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
+        "normalize_targets.buffer_action.mean": ("action", "mean"),
+        "normalize_targets.buffer_action.std": ("action", "std"),
+        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
+        "unnormalize_outputs.buffer_action.std": ("action", "std"),
+    }
+    for pol_key, (stat_key, stat_type) in keys_to_update.items():
+        pol_state_dict[pol_key] = torch.from_numpy(stats[stat_key][stat_type])
+    policy.load_state_dict(pol_state_dict)
+    print("Normalization stats injected into the policy.")
+
+def prepare_batch(batch, device):
+    """
+    Prepares a batch of samples from the dataset for the policy's forward pass.
+    This involves moving tensors to the correct device and remapping image keys
+    to match the policy's expectations.
+    """
+    batch = {
+        "observation.state": batch["observation.state"].to(device),
+        "observation.image": batch["observation.images.top"].to(device),
+        "observation.image2": batch["observation.images.wrist"].to(device),
+        "action": batch["action"].to(device),
+        "task": batch["task"],
+    }
+    return batch
+
 def main():
     # Create a directory to store the training checkpoint.
-    output_directory = Path("outputs/train/example_pusht_diffusion")
+    output_directory = Path("outputs/train/smolvlaplus_training")
     output_directory.mkdir(parents=True, exist_ok=True)
     # Select your device
-    device = torch.device("cuda")
+    device = torch.device("mps")
     # Number of offline training steps (we'll only do offline training for this example.)
     # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-    training_steps = 5000
+    training_steps = 10_000
     log_freq = 1
+    batch_size = 32
-    # When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
-    # creating the policy:
-    # - input/output shapes: to properly size the policy
-    # - dataset stats: for normalization and denormalization of input/outputs
-    dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
-    features = dataset_to_policy_features(dataset_metadata.features)
-    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
-    input_features = {key: ft for key, ft in features.items() if key not in output_features}
+    dataset = LeRobotDataset("lerobot/svla_so100_stacking")
+    policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
-    # Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
-    # we'll just use the defaults and so no arguments other than input/output features need to be passed.
-    cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
-    # We can now instantiate our policy with this config and the dataset stats.
-    policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
+    # fix absence of normalization stats in the policy
+    inject_normalization_stats(policy, dataset)
     policy.train()
     policy.to(device)
-    # Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number of
-    # frames, which can differ for inputs, outputs and rewards (if there are some).
-    delta_timestamps = {
-        "observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
-        "observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
-        "action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
-    }
-    # In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
-    delta_timestamps = {
-        # Load the previous image and state at -0.1 seconds before the current frame,
-        # then load the current image and state corresponding to 0.0 second.
-        "observation.image": [-0.1, 0.0],
-        "observation.state": [-0.1, 0.0],
-        # Load the previous action (-0.1), the next action to be executed (0.0),
-        # and 14 future actions with 0.1-second spacing. All these actions will be
-        # used to supervise the policy.
-        "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
-    }
-    # We can then instantiate the dataset with this delta_timestamps configuration.
-    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
     # Then we create our optimizer and dataloader for offline training.
-    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+    optimizer = torch.optim.AdamW(policy.parameters(), lr=3e-4)
     dataloader = torch.utils.data.DataLoader(
         dataset,
         num_workers=4,
-        batch_size=64,
-        shuffle=True,
+        batch_size=batch_size,
+        shuffle=False,
         pin_memory=device.type != "cpu",
         drop_last=True,
     )
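
Note on the stat injection above: the keys written by inject_normalization_stats are assigned into a copy of policy.state_dict(), so a typo in a buffer name would introduce an unexpected key and load_state_dict (strict by default) would raise rather than fail silently. A quick sanity check, sketched below under the assumption that `policy` and `dataset` are the objects built in main() above, confirms the buffers actually took on the dataset statistics:

    # Sketch (not part of the commit): verify an injected buffer matches the dataset stats.
    expected = torch.from_numpy(dataset.meta.stats["action"]["mean"])
    actual = policy.state_dict()["normalize_targets.buffer_action.mean"]
    assert torch.allclose(actual, expected.to(actual.dtype)), "action mean was not injected"
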
@@ -99,7 +105,7 @@ def main():
     done = False
     while not done:
         for batch in dataloader:
-            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
+            batch = prepare_batch(batch, device)
             loss, _ = policy.forward(batch)
             loss.backward()
             optimizer.step()
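
The hunk ends at optimizer.step(); the remaining loop body is unchanged context and therefore not shown. For orientation, a sketch of how a loop in this style typically continues, modeled on the upstream LeRobot training example (not part of this commit; `step`, `log_freq`, and `training_steps` are the variables defined earlier in main()):

    optimizer.step()
    optimizer.zero_grad()  # clear accumulated gradients before the next batch

    if step % log_freq == 0:
        print(f"step: {step} loss: {loss.item():.3f}")
    step += 1
    if step >= training_steps:
        done = True
        break
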