mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 15:49:53 +00:00
68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
import torch
|
|
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
|
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
|
|
|
|
|
def main():
|
|
# Device to use for training
|
|
device = "mps" # or "cuda", or "cpu"
|
|
|
|
# Load the dataset used for training
|
|
repo_id = "lerobot/example_hil_serl_dataset"
|
|
dataset = LeRobotDataset(repo_id)
|
|
|
|
# Configure the policy to extract features from the image frames
|
|
camera_keys = dataset.meta.camera_keys
|
|
|
|
config = RewardClassifierConfig(
|
|
num_cameras=len(camera_keys),
|
|
device=device,
|
|
# backbone model to extract features from the image frames
|
|
model_name="microsoft/resnet-18",
|
|
)
|
|
|
|
# Make policy, preprocessor, and optimizer
|
|
policy = make_policy(config, ds_meta=dataset.meta)
|
|
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
|
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
|
|
|
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
|
|
|
# Instantiate a dataloader
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
|
|
|
# Training loop
|
|
num_epochs = 5
|
|
for epoch in range(num_epochs):
|
|
total_loss = 0
|
|
total_accuracy = 0
|
|
for batch in dataloader:
|
|
# Preprocess the batch and move it to the correct device.
|
|
batch = preprocessor(batch)
|
|
|
|
# Forward pass
|
|
loss, output_dict = policy.forward(batch)
|
|
|
|
# Backward pass and optimization
|
|
optimizer.zero_grad()
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
total_loss += loss.item()
|
|
total_accuracy += output_dict["accuracy"]
|
|
|
|
avg_loss = total_loss / len(dataloader)
|
|
avg_accuracy = total_accuracy / len(dataloader)
|
|
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
|
|
|
print("Training finished!")
|
|
|
|
# You can now save the trained policy.
|
|
policy.push_to_hub(classifier_id)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|