diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index f815ca3b3..c9db0979f 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -14,7 +14,7 @@ import abc from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any import draccus @@ -30,8 +30,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): fps: int = 30 features: dict[str, PolicyFeature] = field(default_factory=dict) features_map: dict[str, str] = field(default_factory=dict) - multitask_eval: bool = False - max_parallel_tasks: int = 5 @property def type(self) -> str: @@ -46,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): @EnvConfig.register_subclass("aloha") @dataclass class AlohaEnv(EnvConfig): - task: str = "AlohaInsertion-v0" + task: str | None = "AlohaInsertion-v0" fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" @@ -84,7 +82,7 @@ class AlohaEnv(EnvConfig): @EnvConfig.register_subclass("pusht") @dataclass class PushtEnv(EnvConfig): - task: str = "PushT-v0" + task: str | None = "PushT-v0" fps: int = 10 episode_length: int = 300 obs_type: str = "pixels_agent_pos" @@ -126,7 +124,7 @@ class PushtEnv(EnvConfig): @EnvConfig.register_subclass("xarm") @dataclass class XarmEnv(EnvConfig): - task: str = "XarmLift-v0" + task: str | None = "XarmLift-v0" fps: int = 15 episode_length: int = 200 obs_type: str = "pixels_agent_pos" @@ -181,10 +179,10 @@ class EnvTransformConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False add_ee_pose_to_observation: bool = False - crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None - resize_size: Optional[tuple[int, int]] = None + crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None + resize_size: tuple[int, int] | None = None control_time_s: float = 20.0 - fixed_reset_joint_positions: Optional[Any] = None + fixed_reset_joint_positions: Any | None = None reset_time_s: float = 5.0 use_gripper: bool = True gripper_quantization_threshold: float | None = 0.8 @@ -197,24 +195,25 @@ class EnvTransformConfig: class HILSerlRobotEnvConfig(EnvConfig): """Configuration for the HILSerlRobotEnv environment.""" - robot: Optional[RobotConfig] = None - teleop: Optional[TeleoperatorConfig] = None - wrapper: Optional[EnvTransformConfig] = None + robot: RobotConfig | None = None + teleop: TeleoperatorConfig | None = None + wrapper: EnvTransformConfig | None = None fps: int = 10 name: str = "real_robot" - mode: str = None # Either "record", "replay", None - repo_id: Optional[str] = None - dataset_root: Optional[str] = None - task: str = "" + mode: str | None = None # Either "record", "replay", None + repo_id: str | None = None + dataset_root: str | None = None + task: str | None = "" num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" push_to_hub: bool = True - pretrained_policy_name_or_path: Optional[str] = None - reward_classifier_pretrained_path: Optional[str] = None + pretrained_policy_name_or_path: str | None = None + reward_classifier_pretrained_path: str | None = None # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 + @property def gym_kwargs(self) -> dict: return {} @@ -224,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig): class HILEnvConfig(EnvConfig): """Configuration for the HIL environment.""" - type: str = "hil" name: str = "PandaPickCube" - task: str = "PandaPickCubeKeyboard-v0" + task: str | None = "PandaPickCubeKeyboard-v0" use_viewer: bool = True gripper_penalty: float = 0.0 use_gamepad: bool = True @@ -250,18 +248,18 @@ class HILEnvConfig(EnvConfig): } ) ################# args from hilserlrobotenv - reward_classifier_pretrained_path: Optional[str] = None - robot_config: Optional[RobotConfig] = None - teleop_config: Optional[TeleoperatorConfig] = None - wrapper: Optional[EnvTransformConfig] = None - mode: str = None # Either "record", "replay", None - repo_id: Optional[str] = None - dataset_root: Optional[str] = None + reward_classifier_pretrained_path: str | None = None + robot_config: RobotConfig | None = None + teleop_config: TeleoperatorConfig | None = None + wrapper: EnvTransformConfig | None = None + mode: str | None = None # Either "record", "replay", None + repo_id: str | None = None + dataset_root: str | None = None num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" push_to_hub: bool = True - pretrained_policy_name_or_path: Optional[str] = None + pretrained_policy_name_or_path: str | None = None # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 ############################ @@ -274,7 +272,6 @@ class HILEnvConfig(EnvConfig): "gripper_penalty": self.gripper_penalty, } - @EnvConfig.register_subclass("libero") @dataclass class LiberoEnv(EnvConfig):