Files
lerobot/lerobot/configs/policy/hilserl_classifier.yaml
T
s1lent4gnt 83b2dc1219 [Port HIL-SERL] Balanced sampler function speed up and refactor to align with train.py (#715)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-03-28 17:18:48 +00:00

62 lines
1.6 KiB
YAML

# @package _global_
# Keep the header above as the very first line of the file.
# Hydra defaults list: `_self_` refers to the contents of this file.
defaults:
  - _self_
# Hydra runtime settings: output directory of a run and the job name used in it.
hydra:
  run:
    # Set `dir` to where you would like to save all of the run outputs.
    # If you run another training session with the same value for `dir`,
    # its contents will be overwritten unless you set `resume` to true.
    # The path interpolates the current date and time, `env.name`, and
    # `hydra.job.name`.
    dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name}
  job:
    name: default
# Seed and dataset selection.
seed: 13
dataset_repo_id: "aractingi/push_cube_square_light_reward_cropped_resized"
# Alternative dataset id kept for reference:
#   aractingi/push_cube_square_reward_1_cropped_resized
dataset_root: "data/aractingi/push_cube_square_light_reward_cropped_resized"
local_files_only: true
# NOTE(review): presumably the fraction of the dataset used for training,
# with the remainder used for validation - confirm against the trainer.
train_split_proportion: 0.8
# Required by logger.
# `env.name` is also interpolated into `hydra.run.dir`.
env:
  name: "classifier"
  task: "binary_classification"
# Training settings for the classifier.
training:
  num_epochs: 6
  batch_size: 16
  # Written with an explicit mantissa dot so that YAML 1.1 loaders also
  # resolve it as a float instead of the string "1e-4".
  learning_rate: 1.0e-4
  num_workers: 4
  grad_clip_norm: 10
  use_amp: true
  log_freq: 1
  eval_freq: 1  # How often to run validation (in epochs)
  save_freq: 1  # How often to save checkpoints (in epochs)
  save_checkpoint: true
  # The number of entries here has to match `policy.num_cameras`.
  image_keys: ["observation.images.front", "observation.images.side"]
  label_key: "next.reward"
  profile_inference_time: false
  profile_inference_time_iters: 20
# Validation settings (kept separate from `training.batch_size`).
eval:
  batch_size: 16
  num_samples_to_log: 30  # Number of validation samples to log in the table
# Classifier model selection.
policy:
  name: "hilserl/classifier"
  # Alternative model name noted by the author: "facebook/convnext-base-224"
  model_name: "helper2424/resnet10"
  model_type: "cnn"
  num_cameras: 2  # Has to be len(training.image_keys)
# Settings under the `wandb` key (presumably Weights & Biases experiment
# logging - confirm against the logger); disabled via `enable: false`.
wandb:
  enable: false
  project: "classifier-training"
  job_name: "classifier_training_0"
  disable_artifact: false
# Top-level run settings.
# NOTE(review): "mps" looks like the PyTorch Apple-Metal device name -
# confirm, and override it when running on other hardware.
device: "mps"
# Per the comment on `hydra.run.dir`: unless this is true, re-running with
# the same run directory overwrites its contents.
resume: false
# NOTE(review): this is a second output location besides `hydra.run.dir`;
# confirm which one the training script actually writes to.
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"