fix(examples): wrap all of them into a main function (#2524)

This commit is contained in:
Steven Palma
2025-11-26 14:28:04 +01:00
committed by GitHub
parent 87bee86640
commit 17581a9449
24 changed files with 1672 additions and 1532 deletions
+68 -62
View File
@@ -19,80 +19,86 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
return [i / fps for i in delta_indices]
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
def main():
output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
# Select your device
device = torch.device("mps") # or "cuda" or "cpu"
dataset_id = "lerobot/svla_so101_pickplace"
dataset_id = "lerobot/svla_so101_pickplace"
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
policy.train()
policy.to(device)
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}
# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
for k in cfg.image_features
}
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Number of training steps and logging frequency
training_steps = 1
log_freq = 1
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)
# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
# Save all assets to the Hub
policy.push_to_hub("<user>/robot_learning_tutorial_act")
preprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
postprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
if __name__ == "__main__":
main()
+43 -37
View File
@@ -8,50 +8,56 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
def main():
device = torch.device("mps") # or "cuda" or "cpu"
model_id = "<user>/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
obs = preprocess(obs_frame)
# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
action = model.select_action(obs)
action = postprocess(action)
# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"
action = make_robot_action(action, dataset_metadata.features)
# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}
robot.send_action(action)
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()
print("Episode finished! Starting new episode...")
for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)
obs = preprocess(obs_frame)
action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)
print("Episode finished! Starting new episode...")
if __name__ == "__main__":
main()