mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
feat(train): Integrate preprocessor into training pipeline
This commit is contained in:
committed by
Steven Palma
parent
7bd0d62ce5
commit
f14ac5d486
@@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -140,6 +140,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, _ = make_processor(cfg.policy, cfg.policy.pretrained_path)
|
||||
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
@@ -203,6 +204,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
batch = preprocessor(batch)
|
||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
@@ -284,6 +286,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user