refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Steven Palma
2025-03-06 17:59:28 +01:00
committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions
@@ -12,17 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.common.robot_devices.robots.configs import RobotConfig
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
@dataclass
@@ -57,11 +54,6 @@ class RecordControlConfig(ControlConfig):
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
policy: PreTrainedConfig | None = None
# TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
device: str | None = None # cuda | cpu | mps
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -104,27 +96,6 @@ class RecordControlConfig(ControlConfig):
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
# When no device or use_amp are given, use the one from training config.
if self.device is None or self.use_amp is None:
train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
if self.device is None:
self.device = train_cfg.device
if self.use_amp is None:
self.use_amp = train_cfg.use_amp
# Automatically switch to available device if necessary
if not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
self.device = auto_device
# Automatically deactivate AMP if necessary
if self.use_amp and not is_amp_available(self.device):
logging.warning(
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
)
self.use_amp = False
@ControlConfig.register_subclass("replay")
@dataclass
+5 -11
View File
@@ -32,6 +32,7 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -193,8 +194,6 @@ def record_episode(
episode_time_s,
display_cameras,
policy,
device,
use_amp,
fps,
single_task,
):
@@ -205,8 +204,6 @@ def record_episode(
dataset=dataset,
events=events,
policy=policy,
device=device,
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
@@ -221,9 +218,7 @@ def control_loop(
display_cameras=False,
dataset: LeRobotDataset | None = None,
events=None,
policy=None,
device: torch.device | str | None = None,
use_amp: bool | None = None,
policy: PreTrainedPolicy = None,
fps: int | None = None,
single_task: str | None = None,
):
@@ -246,9 +241,6 @@ def control_loop(
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
if isinstance(device, str):
device = get_safe_torch_device(device)
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
@@ -260,7 +252,9 @@ def control_loop(
observation = robot.capture_observation()
if policy is not None:
pred_action = predict_action(observation, policy, device, use_amp)
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)