mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
fix rac data collection with rtc by disabling compile
This commit is contained in:
@@ -130,6 +130,10 @@ class RaCRTCConfig:
|
||||
resume: bool = False
|
||||
device: str = "cuda"
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Torch compile is disabled by default for real-time inference
|
||||
# First inference with compile takes minutes to compile kernels
|
||||
use_torch_compile: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
@@ -735,7 +739,14 @@ def rac_rtc_collect(cfg: RaCRTCConfig) -> LeRobotDataset:
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from: {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
# Override compile_model for real-time inference (first compile takes minutes)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
if cfg.policy.type in ["pi05", "pi0"]:
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
logger.info(f"Set compile_model={cfg.use_torch_compile} for real-time inference")
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
|
||||
Reference in New Issue
Block a user