commit 4bc356b7f3 (parent 21a961ecbb)
Author: Jade Choghari
Date:   2025-08-06 00:00:45 -04:00
Diffstat: +26 -29
@@ -14,7 +14,7 @@
 import abc
 from dataclasses import dataclass, field
-from typing import Any, Optional
+from typing import Any

 import draccus
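Almost every change in this commit is the same mechanical rewrite: `Optional[X]` becomes `X | None` (PEP 604 union syntax, available since Python 3.10), which lets the `Optional` import go. The two spellings are interchangeable:

```python
from typing import Optional

a: Optional[int] = None  # old spelling, needs the typing import
b: int | None = None     # new spelling used throughout this commit

# both forms describe the same type and compare equal on Python 3.10+
assert Optional[int] == (int | None)
```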
@@ -30,8 +30,6 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
     fps: int = 30
     features: dict[str, PolicyFeature] = field(default_factory=dict)
     features_map: dict[str, str] = field(default_factory=dict)
-    multitask_eval: bool = False
-    max_parallel_tasks: int = 5

     @property
     def type(self) -> str:
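For context, `EnvConfig` is a draccus `ChoiceRegistry`: each environment dataclass registers itself under a string name, and the `type` property reports that name. A minimal sketch of the pattern; the hunk cuts off the body of `type`, so the `get_choice_name` call below is an assumption about the usual draccus helper, not taken from this diff:

```python
import abc
from dataclasses import dataclass

import draccus


@dataclass
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
    fps: int = 30

    @property
    def type(self) -> str:
        # assumed body: ChoiceRegistry records the name each subclass
        # was registered under
        return self.get_choice_name(self.__class__)


@EnvConfig.register_subclass("aloha")
@dataclass
class AlohaEnv(EnvConfig):
    task: str | None = "AlohaInsertion-v0"
    fps: int = 50
```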
@@ -46,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
 @EnvConfig.register_subclass("aloha")
 @dataclass
 class AlohaEnv(EnvConfig):
-    task: str = "AlohaInsertion-v0"
+    task: str | None = "AlohaInsertion-v0"
     fps: int = 50
     episode_length: int = 400
     obs_type: str = "pixels_agent_pos"
@@ -84,7 +82,7 @@ class AlohaEnv(EnvConfig):
 @EnvConfig.register_subclass("pusht")
 @dataclass
 class PushtEnv(EnvConfig):
-    task: str = "PushT-v0"
+    task: str | None = "PushT-v0"
     fps: int = 10
     episode_length: int = 300
     obs_type: str = "pixels_agent_pos"
@@ -126,7 +124,7 @@ class PushtEnv(EnvConfig):
 @EnvConfig.register_subclass("xarm")
 @dataclass
 class XarmEnv(EnvConfig):
-    task: str = "XarmLift-v0"
+    task: str | None = "XarmLift-v0"
     fps: int = 15
     episode_length: int = 200
     obs_type: str = "pixels_agent_pos"
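The three sim-env hunks above (aloha, pusht, xarm) make the identical change: `task` widens from `str` to `str | None`, so a config can now be built with no fixed task. A minimal usage sketch, assuming the `AlohaEnv` defined in this file:

```python
cfg = AlohaEnv(task=None)  # valid after this commit; previously a type error
assert cfg.task is None
```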
@@ -181,10 +179,10 @@ class EnvTransformConfig:
     add_joint_velocity_to_observation: bool = False
     add_current_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
-    crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None
-    resize_size: Optional[tuple[int, int]] = None
+    crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
+    resize_size: tuple[int, int] | None = None
     control_time_s: float = 20.0
-    fixed_reset_joint_positions: Optional[Any] = None
+    fixed_reset_joint_positions: Any | None = None
     reset_time_s: float = 5.0
     use_gripper: bool = True
     gripper_quantization_threshold: float | None = 0.8
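These fields keep their meanings; only the spelling of the optionality changes. A construction sketch under assumed values (the camera key and the crop-tuple ordering are illustrative, not taken from the diff):

```python
# hypothetical values; field names and annotations are from the hunk above
wrapper = EnvTransformConfig(
    crop_params_dict={"front_camera": (0, 0, 128, 128)},  # assumed key and order
    resize_size=(128, 128),
    fixed_reset_joint_positions=None,  # Any | None, so None remains valid
)
```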
@@ -197,24 +195,25 @@ class EnvTransformConfig:
 class HILSerlRobotEnvConfig(EnvConfig):
     """Configuration for the HILSerlRobotEnv environment."""

-    robot: Optional[RobotConfig] = None
-    teleop: Optional[TeleoperatorConfig] = None
-    wrapper: Optional[EnvTransformConfig] = None
+    robot: RobotConfig | None = None
+    teleop: TeleoperatorConfig | None = None
+    wrapper: EnvTransformConfig | None = None
     fps: int = 10
     name: str = "real_robot"
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
-    task: str = ""
+    mode: str | None = None  # Either "record", "replay", None
+    repo_id: str | None = None
+    dataset_root: str | None = None
+    task: str | None = ""
     num_episodes: int = 10  # only for record mode
     episode: int = 0
     device: str = "cuda"
     push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
-    reward_classifier_pretrained_path: Optional[str] = None
+    pretrained_policy_name_or_path: str | None = None
+    reward_classifier_pretrained_path: str | None = None
     # For the reward classifier, to record more positive examples after a success
     number_of_steps_after_success: int = 0
+
     @property
     def gym_kwargs(self) -> dict:
         return {}
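Beyond the mechanical rewrites, this hunk fixes a real annotation bug: `mode: str = None` declares a `str` field whose default is `None`, which mypy and similar checkers reject. The widened form is the honest type:

```python
mode_before: str = None        # mypy: None is incompatible with "str"
mode_after: str | None = None  # default now matches the declared type
```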
@@ -224,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
 class HILEnvConfig(EnvConfig):
     """Configuration for the HIL environment."""

     type: str = "hil"
-    name: str = "PandaPickCube"
-    task: str = "PandaPickCubeKeyboard-v0"
+    task: str | None = "PandaPickCubeKeyboard-v0"
     use_viewer: bool = True
     gripper_penalty: float = 0.0
     use_gamepad: bool = True
@@ -250,18 +248,18 @@ class HILEnvConfig(EnvConfig):
         }
     )
     ################# args from hilserlrobotenv
-    reward_classifier_pretrained_path: Optional[str] = None
-    robot_config: Optional[RobotConfig] = None
-    teleop_config: Optional[TeleoperatorConfig] = None
-    wrapper: Optional[EnvTransformConfig] = None
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
+    reward_classifier_pretrained_path: str | None = None
+    robot_config: RobotConfig | None = None
+    teleop_config: TeleoperatorConfig | None = None
+    wrapper: EnvTransformConfig | None = None
+    mode: str | None = None  # Either "record", "replay", None
+    repo_id: str | None = None
+    dataset_root: str | None = None
     num_episodes: int = 10  # only for record mode
     episode: int = 0
     device: str = "cuda"
     push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
+    pretrained_policy_name_or_path: str | None = None
     # For the reward classifier, to record more positive examples after a success
     number_of_steps_after_success: int = 0
     ############################
@@ -274,7 +272,6 @@ class HILEnvConfig(EnvConfig):
             "gripper_penalty": self.gripper_penalty,
         }

-
 @EnvConfig.register_subclass("libero")
 @dataclass
 class LiberoEnv(EnvConfig):