diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 62d31478e..4a2f95f93 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -258,6 +258,11 @@ class RTCEvaluator: policy_class = get_policy_class(self.cfg.policy.type) config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path) + + if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0": + config.compile_model = self.cfg.torch_compile_mode + config.compile_mode = self.cfg.torch_compile_mode + policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config) policy = policy.to(self.device) policy.eval()