mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
refactor(config): Move device & amp args to PreTrainedConfig (#812)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -32,6 +32,7 @@ from termcolor import colored
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||
@@ -193,8 +194,6 @@ def record_episode(
|
||||
episode_time_s,
|
||||
display_cameras,
|
||||
policy,
|
||||
device,
|
||||
use_amp,
|
||||
fps,
|
||||
single_task,
|
||||
):
|
||||
@@ -205,8 +204,6 @@ def record_episode(
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
policy=policy,
|
||||
device=device,
|
||||
use_amp=use_amp,
|
||||
fps=fps,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
@@ -221,9 +218,7 @@ def control_loop(
|
||||
display_cameras=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device: torch.device | str | None = None,
|
||||
use_amp: bool | None = None,
|
||||
policy: PreTrainedPolicy = None,
|
||||
fps: int | None = None,
|
||||
single_task: str | None = None,
|
||||
):
|
||||
@@ -246,9 +241,6 @@ def control_loop(
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
if isinstance(device, str):
|
||||
device = get_safe_torch_device(device)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -260,7 +252,9 @@ def control_loop(
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if policy is not None:
|
||||
pred_action = predict_action(observation, policy, device, use_amp)
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
)
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
action = robot.send_action(pred_action)
|
||||
|
||||
Reference in New Issue
Block a user