mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
wip
This commit is contained in:
+48
-42
@@ -24,72 +24,78 @@ import torch
|
|||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
|
||||||
|
def inject_normalization_stats(policy, dataset):
    """Copy the dataset's normalization statistics into the policy's buffers.

    The mean/std entries for ``observation.state`` and ``action`` are read
    from ``dataset.meta.stats`` and written over the matching entries of the
    policy's state dict, which is then loaded back into the policy.

    Args:
        policy: Object exposing ``state_dict()`` and ``load_state_dict()``
            whose state dict contains the ``normalize_inputs``,
            ``normalize_targets`` and ``unnormalize_outputs`` buffers listed
            below.
        dataset: Object exposing ``meta.stats``, a mapping of
            ``feature name -> {"mean": ..., "std": ...}``.

    Raises:
        KeyError: If ``dataset.meta.stats`` lacks one of the required entries.
        Any error raised by ``policy.load_state_dict`` (for example when a
        statistic's shape does not match the policy's buffer) is propagated.
    """
    stats = dataset.meta.stats
    pol_state_dict = policy.state_dict()

    # Policy state-dict key -> (dataset feature, statistic) that feeds it.
    keys_to_update = {
        "normalize_inputs.buffer_observation_state.mean": ("observation.state", "mean"),
        "normalize_inputs.buffer_observation_state.std": ("observation.state", "std"),
        "normalize_targets.buffer_action.mean": ("action", "mean"),
        "normalize_targets.buffer_action.std": ("action", "std"),
        "unnormalize_outputs.buffer_action.mean": ("action", "mean"),
        "unnormalize_outputs.buffer_action.std": ("action", "std"),
    }

    for pol_key, (stat_key, stat_type) in keys_to_update.items():
        # torch.as_tensor behaves like torch.from_numpy for ndarrays but also
        # accepts tensors and plain sequences, which from_numpy rejects with
        # a TypeError.
        pol_state_dict[pol_key] = torch.as_tensor(stats[stat_key][stat_type])

    policy.load_state_dict(pol_state_dict)
    print("Normalization stats injected into the policy.")
|
||||||
|
|
||||||
|
def prepare_batch(batch, device):
    """Build the inference-ready batch from a raw dataset batch.

    Tensor entries are moved to ``device`` and the dataset's camera keys are
    renamed to the image keys used by the policy. The ``task`` entry is
    forwarded untouched. The input mapping is not modified; a new dict is
    returned.

    Args:
        batch: Mapping containing ``observation.state``,
            ``observation.images.top``, ``observation.images.wrist``,
            ``action`` and ``task``.
        device: Target passed to each tensor's ``.to()``.

    Returns:
        A dict with keys ``observation.state``, ``observation.image``,
        ``observation.image2``, ``action`` and ``task``, in that order.
    """
    # (key in the returned batch, key in the dataset batch)
    tensor_routes = (
        ("observation.state", "observation.state"),
        ("observation.image", "observation.images.top"),
        ("observation.image2", "observation.images.wrist"),
        ("action", "action"),
    )

    prepared = {}
    for target_key, source_key in tensor_routes:
        prepared[target_key] = batch[source_key].to(device)

    # The task description is not a tensor, so it is passed through as-is.
    prepared["task"] = batch["task"]
    return prepared
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Create a directory to store the training checkpoint.
|
# Create a directory to store the training checkpoint.
|
||||||
output_directory = Path("outputs/train/example_pusht_diffusion")
|
output_directory = Path("outputs/train/smolvlaplus_training")
|
||||||
output_directory.mkdir(parents=True, exist_ok=True)
|
output_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# # Select your device
|
# # Select your device
|
||||||
device = torch.device("cuda")
|
device = torch.device("mps")
|
||||||
|
|
||||||
# Number of offline training steps (we'll only do offline training for this example.)
|
# Number of offline training steps (we'll only do offline training for this example.)
|
||||||
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
|
||||||
training_steps = 5000
|
training_steps = 10_000
|
||||||
log_freq = 1
|
log_freq = 1
|
||||||
|
batch_size = 32
|
||||||
|
|
||||||
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
|
||||||
# creating the policy:
|
# creating the policy:
|
||||||
# - input/output shapes: to properly size the policy
|
# - input/output shapes: to properly size the policy
|
||||||
# - dataset stats: for normalization and denormalization of input/outputs
|
# - dataset stats: for normalization and denormalization of input/outputs
|
||||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
dataset = LeRobotDataset("lerobot/svla_so100_stacking")
|
||||||
features = dataset_to_policy_features(dataset_metadata.features)
|
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
|
||||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
|
||||||
|
|
||||||
# Policies are initialized with a configuration class, in this case `DiffusionConfig`. For this example,
|
# fix absence of normalization stats in the policy
|
||||||
# we'll just use the defaults and so no arguments other than input/output features need to be passed.
|
inject_normalization_stats(policy, dataset)
|
||||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
|
||||||
|
|
||||||
# We can now instantiate our policy with this config and the dataset stats.
|
|
||||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
|
||||||
policy.train()
|
policy.train()
|
||||||
policy.to(device)
|
policy.to(device)
|
||||||
|
|
||||||
# Another policy-dataset interaction is with the delta_timestamps. Each policy expects a given number frames
|
|
||||||
# which can differ for inputs, outputs and rewards (if there are some).
|
|
||||||
delta_timestamps = {
|
|
||||||
"observation.image": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
|
||||||
"observation.state": [i / dataset_metadata.fps for i in cfg.observation_delta_indices],
|
|
||||||
"action": [i / dataset_metadata.fps for i in cfg.action_delta_indices],
|
|
||||||
}
|
|
||||||
|
|
||||||
# In this case with the standard configuration for Diffusion Policy, it is equivalent to this:
|
|
||||||
delta_timestamps = {
|
|
||||||
# Load the previous image and state at -0.1 seconds before current frame,
|
|
||||||
# then load current image and state corresponding to 0.0 second.
|
|
||||||
"observation.image": [-0.1, 0.0],
|
|
||||||
"observation.state": [-0.1, 0.0],
|
|
||||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
|
||||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
|
||||||
# used to supervise the policy.
|
|
||||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
|
||||||
}
|
|
||||||
|
|
||||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
|
||||||
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
|
||||||
|
|
||||||
# Then we create our optimizer and dataloader for offline training.
|
# Then we create our optimizer and dataloader for offline training.
|
||||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
optimizer = torch.optim.AdamW(policy.parameters(), lr=3e-4)
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
batch_size=64,
|
batch_size=batch_size,
|
||||||
shuffle=True,
|
shuffle=False,
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
)
|
)
|
||||||
@@ -99,7 +105,7 @@ def main():
|
|||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
batch = prepare_batch(batch, device)
|
||||||
loss, _ = policy.forward(batch)
|
loss, _ = policy.forward(batch)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|||||||
Reference in New Issue
Block a user