mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
Disable the pre/post-processor pipeline for SAC/HIL-SERL
This commit is contained in:
@@ -64,7 +64,7 @@ from lerobot.configs import parser
|
|||||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy
|
||||||
from lerobot.rl.algorithms import make_algorithm
|
from lerobot.rl.algorithms import make_algorithm
|
||||||
from lerobot.rl.buffer import ReplayBuffer
|
from lerobot.rl.buffer import ReplayBuffer
|
||||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||||
@@ -321,35 +321,8 @@ def add_actor_information_and_train(
|
|||||||
algorithm_name=cfg.algorithm,
|
algorithm_name=cfg.algorithm,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build policy preprocessor for batch normalization during training
|
# TODO: Re-enable the processor pipeline (make_pre_post_processors) once the refactoring is validated against main; until then both processors stay None.
|
||||||
processor_kwargs = {}
|
preprocessor, postprocessor = None, None
|
||||||
postprocessor_kwargs = {}
|
|
||||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
|
||||||
processor_kwargs["dataset_stats"] = cfg.policy.dataset_stats
|
|
||||||
|
|
||||||
if cfg.policy.pretrained_path is not None:
|
|
||||||
processor_kwargs["preprocessor_overrides"] = {
|
|
||||||
"device_processor": {"device": device.type},
|
|
||||||
"normalizer_processor": {
|
|
||||||
"stats": cfg.policy.dataset_stats,
|
|
||||||
"features": {**policy.config.input_features, **policy.config.output_features},
|
|
||||||
"norm_map": policy.config.normalization_mapping,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
postprocessor_kwargs["postprocessor_overrides"] = {
|
|
||||||
"unnormalizer_processor": {
|
|
||||||
"stats": cfg.policy.dataset_stats,
|
|
||||||
"features": policy.config.output_features,
|
|
||||||
"norm_map": policy.config.normalization_mapping,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
|
||||||
policy_cfg=cfg.policy,
|
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
|
||||||
**processor_kwargs,
|
|
||||||
**postprocessor_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Push initial policy weights to actors (same path as periodic push)
|
# Push initial policy weights to actors (same path as periodic push)
|
||||||
state_bytes = state_to_bytes(algorithm.get_weights())
|
state_bytes = state_to_bytes(algorithm.get_weights())
|
||||||
|
|||||||
Reference in New Issue
Block a user