mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
backup
This commit is contained in:
+26
-29
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
@@ -30,8 +30,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
fps: int = 30
|
fps: int = 30
|
||||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
features_map: dict[str, str] = field(default_factory=dict)
|
features_map: dict[str, str] = field(default_factory=dict)
|
||||||
multitask_eval: bool = False
|
|
||||||
max_parallel_tasks: int = 5
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
@@ -46,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
|||||||
@EnvConfig.register_subclass("aloha")
|
@EnvConfig.register_subclass("aloha")
|
||||||
@dataclass
|
@dataclass
|
||||||
class AlohaEnv(EnvConfig):
|
class AlohaEnv(EnvConfig):
|
||||||
task: str = "AlohaInsertion-v0"
|
task: str | None = "AlohaInsertion-v0"
|
||||||
fps: int = 50
|
fps: int = 50
|
||||||
episode_length: int = 400
|
episode_length: int = 400
|
||||||
obs_type: str = "pixels_agent_pos"
|
obs_type: str = "pixels_agent_pos"
|
||||||
@@ -84,7 +82,7 @@ class AlohaEnv(EnvConfig):
|
|||||||
@EnvConfig.register_subclass("pusht")
|
@EnvConfig.register_subclass("pusht")
|
||||||
@dataclass
|
@dataclass
|
||||||
class PushtEnv(EnvConfig):
|
class PushtEnv(EnvConfig):
|
||||||
task: str = "PushT-v0"
|
task: str | None = "PushT-v0"
|
||||||
fps: int = 10
|
fps: int = 10
|
||||||
episode_length: int = 300
|
episode_length: int = 300
|
||||||
obs_type: str = "pixels_agent_pos"
|
obs_type: str = "pixels_agent_pos"
|
||||||
@@ -126,7 +124,7 @@ class PushtEnv(EnvConfig):
|
|||||||
@EnvConfig.register_subclass("xarm")
|
@EnvConfig.register_subclass("xarm")
|
||||||
@dataclass
|
@dataclass
|
||||||
class XarmEnv(EnvConfig):
|
class XarmEnv(EnvConfig):
|
||||||
task: str = "XarmLift-v0"
|
task: str | None = "XarmLift-v0"
|
||||||
fps: int = 15
|
fps: int = 15
|
||||||
episode_length: int = 200
|
episode_length: int = 200
|
||||||
obs_type: str = "pixels_agent_pos"
|
obs_type: str = "pixels_agent_pos"
|
||||||
@@ -181,10 +179,10 @@ class EnvTransformConfig:
|
|||||||
add_joint_velocity_to_observation: bool = False
|
add_joint_velocity_to_observation: bool = False
|
||||||
add_current_to_observation: bool = False
|
add_current_to_observation: bool = False
|
||||||
add_ee_pose_to_observation: bool = False
|
add_ee_pose_to_observation: bool = False
|
||||||
crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
|
crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
|
||||||
resize_size: Optional[tuple[int, int]] = None
|
resize_size: tuple[int, int] | None = None
|
||||||
control_time_s: float = 20.0
|
control_time_s: float = 20.0
|
||||||
fixed_reset_joint_positions: Optional[Any] = None
|
fixed_reset_joint_positions: Any | None = None
|
||||||
reset_time_s: float = 5.0
|
reset_time_s: float = 5.0
|
||||||
use_gripper: bool = True
|
use_gripper: bool = True
|
||||||
gripper_quantization_threshold: float | None = 0.8
|
gripper_quantization_threshold: float | None = 0.8
|
||||||
@@ -197,24 +195,25 @@ class EnvTransformConfig:
|
|||||||
class HILSerlRobotEnvConfig(EnvConfig):
|
class HILSerlRobotEnvConfig(EnvConfig):
|
||||||
"""Configuration for the HILSerlRobotEnv environment."""
|
"""Configuration for the HILSerlRobotEnv environment."""
|
||||||
|
|
||||||
robot: Optional[RobotConfig] = None
|
robot: RobotConfig | None = None
|
||||||
teleop: Optional[TeleoperatorConfig] = None
|
teleop: TeleoperatorConfig | None = None
|
||||||
wrapper: Optional[EnvTransformConfig] = None
|
wrapper: EnvTransformConfig | None = None
|
||||||
fps: int = 10
|
fps: int = 10
|
||||||
name: str = "real_robot"
|
name: str = "real_robot"
|
||||||
mode: str = None # Either "record", "replay", None
|
mode: str | None = None # Either "record", "replay", None
|
||||||
repo_id: Optional[str] = None
|
repo_id: str | None = None
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: str | None = None
|
||||||
task: str = ""
|
task: str | None = ""
|
||||||
num_episodes: int = 10 # only for record mode
|
num_episodes: int = 10 # only for record mode
|
||||||
episode: int = 0
|
episode: int = 0
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
pretrained_policy_name_or_path: Optional[str] = None
|
pretrained_policy_name_or_path: str | None = None
|
||||||
reward_classifier_pretrained_path: Optional[str] = None
|
reward_classifier_pretrained_path: str | None = None
|
||||||
# For the reward classifier, to record more positive examples after a success
|
# For the reward classifier, to record more positive examples after a success
|
||||||
number_of_steps_after_success: int = 0
|
number_of_steps_after_success: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
def gym_kwargs(self) -> dict:
|
def gym_kwargs(self) -> dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -224,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
|||||||
class HILEnvConfig(EnvConfig):
|
class HILEnvConfig(EnvConfig):
|
||||||
"""Configuration for the HIL environment."""
|
"""Configuration for the HIL environment."""
|
||||||
|
|
||||||
type: str = "hil"
|
|
||||||
name: str = "PandaPickCube"
|
name: str = "PandaPickCube"
|
||||||
task: str = "PandaPickCubeKeyboard-v0"
|
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||||
use_viewer: bool = True
|
use_viewer: bool = True
|
||||||
gripper_penalty: float = 0.0
|
gripper_penalty: float = 0.0
|
||||||
use_gamepad: bool = True
|
use_gamepad: bool = True
|
||||||
@@ -250,18 +248,18 @@ class HILEnvConfig(EnvConfig):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
################# args from hilserlrobotenv
|
################# args from hilserlrobotenv
|
||||||
reward_classifier_pretrained_path: Optional[str] = None
|
reward_classifier_pretrained_path: str | None = None
|
||||||
robot_config: Optional[RobotConfig] = None
|
robot_config: RobotConfig | None = None
|
||||||
teleop_config: Optional[TeleoperatorConfig] = None
|
teleop_config: TeleoperatorConfig | None = None
|
||||||
wrapper: Optional[EnvTransformConfig] = None
|
wrapper: EnvTransformConfig | None = None
|
||||||
mode: str = None # Either "record", "replay", None
|
mode: str | None = None # Either "record", "replay", None
|
||||||
repo_id: Optional[str] = None
|
repo_id: str | None = None
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: str | None = None
|
||||||
num_episodes: int = 10 # only for record mode
|
num_episodes: int = 10 # only for record mode
|
||||||
episode: int = 0
|
episode: int = 0
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
pretrained_policy_name_or_path: Optional[str] = None
|
pretrained_policy_name_or_path: str | None = None
|
||||||
# For the reward classifier, to record more positive examples after a success
|
# For the reward classifier, to record more positive examples after a success
|
||||||
number_of_steps_after_success: int = 0
|
number_of_steps_after_success: int = 0
|
||||||
############################
|
############################
|
||||||
@@ -274,7 +272,6 @@ class HILEnvConfig(EnvConfig):
|
|||||||
"gripper_penalty": self.gripper_penalty,
|
"gripper_penalty": self.gripper_penalty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@EnvConfig.register_subclass("libero")
|
@EnvConfig.register_subclass("libero")
|
||||||
@dataclass
|
@dataclass
|
||||||
class LiberoEnv(EnvConfig):
|
class LiberoEnv(EnvConfig):
|
||||||
|
|||||||
Reference in New Issue
Block a user