more refactor

This commit is contained in:
Jade Choghari
2025-11-21 11:24:54 +01:00
parent a3a5cb1bac
commit e61722fa78
6 changed files with 15 additions and 8 deletions
+1
View File
@@ -5,4 +5,5 @@ lerobot-eval \
--env.control_mode=absolute \ --env.control_mode=absolute \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--env.episode_length=800 \
--seed=142 --seed=142
-1
View File
@@ -37,7 +37,6 @@ class NormalizationMode(str, Enum):
IDENTITY = "IDENTITY" IDENTITY = "IDENTITY"
QUANTILES = "QUANTILES" QUANTILES = "QUANTILES"
QUANTILE10 = "QUANTILE10" QUANTILE10 = "QUANTILE10"
IMAGENET = "IMAGENET"
@dataclass @dataclass
+1 -1
View File
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig): class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30 fps: int = 30
episode_length: int = 520 episode_length: int | None = None
obs_type: str = "pixels_agent_pos" obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array" render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image" camera_name: str = "agentview_image,robot0_eye_in_hand_image"
+1
View File
@@ -144,6 +144,7 @@ def make_env(
gym_kwargs=cfg.gym_kwargs, gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls, env_cls=env_cls,
control_mode=cfg.control_mode, control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
) )
elif "metaworld" in cfg.type: elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs from lerobot.envs.metaworld import create_metaworld_envs
+12 -5
View File
@@ -80,14 +80,11 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1] return [0, 0, 0, 0, 0, 0, -1]
OBS_STATE_DIM = 8
ACTION_DIM = 7 ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0 ACTION_LOW = -1.0
ACTION_HIGH = 1.0 ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = { TASK_SUITE_MAX_STEPS: dict[str, int] = {
"libero_spatial": 800, # longest training demo has 193 steps "libero_spatial": 280, # longest training demo has 193 steps
"libero_object": 280, # longest training demo has 254 steps "libero_object": 280, # longest training demo has 254 steps
"libero_goal": 300, # longest training demo has 270 steps "libero_goal": 300, # longest training demo has 270 steps
"libero_10": 520, # longest training demo has 505 steps "libero_10": 520, # longest training demo has 505 steps
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
task_suite: Any, task_suite: Any,
task_id: int, task_id: int,
task_suite_name: str, task_suite_name: str,
episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image", camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels", obs_type: str = "pixels",
render_mode: str = "rgb_array", render_mode: str = "rgb_array",
@@ -142,13 +140,18 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait self.num_steps_wait = num_steps_wait
self.episode_index = episode_index self.episode_index = episode_index
self.episode_length = episode_length
# Load once and keep # Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id) self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500 default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
if self.episode_length is None
else self.episode_length
)
self.control_mode = control_mode self.control_mode = control_mode
images = {} images = {}
for cam in self.camera_name: for cam in self.camera_name:
@@ -351,6 +354,7 @@ def _make_env_fns(
task_id: int, task_id: int,
n_envs: int, n_envs: int,
camera_names: list[str], camera_names: list[str],
episode_length: int | None,
init_states: bool, init_states: bool,
gym_kwargs: Mapping[str, Any], gym_kwargs: Mapping[str, Any],
control_mode: str, control_mode: str,
@@ -365,6 +369,7 @@ def _make_env_fns(
task_suite_name=suite_name, task_suite_name=suite_name,
camera_name=camera_names, camera_name=camera_names,
init_states=init_states, init_states=init_states,
episode_length=episode_length,
episode_index=episode_index, episode_index=episode_index,
control_mode=control_mode, control_mode=control_mode,
**local_kwargs, **local_kwargs,
@@ -387,6 +392,7 @@ def create_libero_envs(
init_states: bool = True, init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative", control_mode: str = "relative",
episode_length: int | None = None,
) -> dict[str, dict[int, Any]]: ) -> dict[str, dict[int, Any]]:
""" """
Create vectorized LIBERO environments with a consistent return shape. Create vectorized LIBERO environments with a consistent return shape.
@@ -428,6 +434,7 @@ def create_libero_envs(
for tid in selected: for tid in selected:
fns = _make_env_fns( fns = _make_env_fns(
suite=suite, suite=suite,
episode_length=episode_length,
suite_name=suite_name, suite_name=suite_name,
task_id=tid, task_id=tid,
n_envs=n_envs, n_envs=n_envs,
-1
View File
@@ -175,7 +175,6 @@ def rollout(
with torch.inference_mode(): with torch.inference_mode():
action = policy.select_action(observation) action = policy.select_action(observation)
action = postprocessor(action) action = postprocessor(action)
action_transition = {"action": action} action_transition = {"action": action}
action_transition = env_postprocessor(action_transition) action_transition = env_postprocessor(action_transition)
action = action_transition["action"] action = action_transition["action"]