diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index c2e3a440b..5f44649da 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -461,7 +461,9 @@ def demo_cli(cfg: RTCDemoConfig): peft_pretrained_path = cfg.policy.pretrained_path peft_config = PeftConfig.from_pretrained(peft_pretrained_path) - policy = policy_class.from_pretrained(pretrained_name_or_path=peft_config.base_model_name_or_path, config=config) + policy = policy_class.from_pretrained( + pretrained_name_or_path=peft_config.base_model_name_or_path, config=config + ) policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) else: policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 6e436e90c..a9b1280bf 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -493,7 +493,6 @@ def make_policy( # Make a fresh policy. policy = policy_cls(**kwargs) - policy.to(cfg.device) assert isinstance(policy, torch.nn.Module) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 829846be0..4b7c6e99a 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -289,7 +289,6 @@ def eval_policy( except ImportError: raise exc - start = time.time() policy.eval() diff --git a/tests/test_cli.py b/tests/test_cli.py index 8e49bbfce..a915d1287 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -121,13 +121,7 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path): model_id = resolve_model_id_for_peft_training(policy_type) def dummy_update_policy( - train_metrics, - policy, - batch, - optimizer, - grad_clip_norm: float, - accelerator, - **kwargs + train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs ): params_total = sum(p.numel() for p in policy.parameters()) params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)