mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
feat(train): Integrate preprocessor into training pipeline
This commit is contained in:
committed by
Steven Palma
parent
7bd0d62ce5
commit
f14ac5d486
@@ -31,7 +31,7 @@ from lerobot.datasets.sampler import EpisodeAwareSampler
|
|||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies.factory import make_policy
|
from lerobot.policies.factory import make_policy, make_processor
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import get_device_from_parameters
|
from lerobot.policies.utils import get_device_from_parameters
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
@@ -140,6 +140,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
|
preprocessor, _ = make_processor(cfg.policy, cfg.policy.pretrained_path)
|
||||||
|
|
||||||
logging.info("Creating optimizer and scheduler")
|
logging.info("Creating optimizer and scheduler")
|
||||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||||
@@ -203,6 +204,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
batch = preprocessor(batch)
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
@@ -284,6 +286,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
policy.push_model_to_hub(cfg)
|
policy.push_model_to_hub(cfg)
|
||||||
|
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user