mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
feat(policies): Add X-VLA (#2405)
* first commit * more fixes * add franka action * update testing script * add changes * update files * logits matching * add imagenet as a norm type * logits matching atol1e-2 * more eval fixes * more changes * xvla works on libero * remove seed * more refactoring * more fixes * more changes * more changes * more fixes * migrate policy revert * major pre-commit cleanup * renaming * revert to self.transformer * refactor * new changes * clean * update libero * more changes * make it work * more changes: * remove imagenet dependency * style * more * more refactor * remove proprio * add loss * more * more * add freeze/unfreeze options * add testing * upgrade transformers version * update testing * add installation * remove .sh file * fix testing * silent linter in xvlatest * fix failing test * upgrade test, fix failing * fix testing * more fixes to testing * require cuda in tests * temp check * add xvla docs * fix styling * update libero doc * remove timm dep * add different dtype support * remove timm skip * remove white lines * Enhance X-VLA finetuning documentation with optimizer details (#2537) Added detailed instructions for implementing a custom optimizer and modifying parameter retrieval for X-VLA finetuning. Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> * fix style * iterate on review * iterate on cpilot * revert xvla dep * free up ci * test(xvla): remove main test (#2565) * Add xvla custom optim and dtype (#2567) * add custom optim * add custom optim * add auto mode * more changes * add identity to all * add auto * release * add docs * make image smaller docs * smaller image in doc * evan smaller image doc * finalize doc --------- Signed-off-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Jinliang Zheng <54488861+2toinf@users.noreply.github.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
|
||||
return [0, 0, 0, 0, 0, 0, -1]
|
||||
|
||||
|
||||
OBS_STATE_DIM = 8
|
||||
ACTION_DIM = 7
|
||||
AGENT_POS_LOW = -1000.0
|
||||
AGENT_POS_HIGH = 1000.0
|
||||
ACTION_LOW = -1.0
|
||||
ACTION_HIGH = 1.0
|
||||
TASK_SUITE_MAX_STEPS: dict[str, int] = {
|
||||
@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
|
||||
task_suite: Any,
|
||||
task_id: int,
|
||||
task_suite_name: str,
|
||||
episode_length: int | None = None,
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
obs_type: str = "pixels",
|
||||
render_mode: str = "rgb_array",
|
||||
@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
|
||||
episode_index: int = 0,
|
||||
camera_name_mapping: dict[str, str] | None = None,
|
||||
num_steps_wait: int = 10,
|
||||
control_mode: str = "relative",
|
||||
):
|
||||
super().__init__()
|
||||
self.task_id = task_id
|
||||
@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
|
||||
self.camera_name_mapping = camera_name_mapping
|
||||
self.num_steps_wait = num_steps_wait
|
||||
self.episode_index = episode_index
|
||||
self.episode_length = episode_length
|
||||
# Load once and keep
|
||||
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
|
||||
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
|
||||
|
||||
self._env = self._make_envs_task(task_suite, self.task_id)
|
||||
default_steps = 500
|
||||
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
|
||||
self._max_episode_steps = (
|
||||
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
|
||||
if self.episode_length is None
|
||||
else self.episode_length
|
||||
)
|
||||
self.control_mode = control_mode
|
||||
images = {}
|
||||
for cam in self.camera_name:
|
||||
images[self.camera_name_mapping[cam]] = spaces.Box(
|
||||
@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
|
||||
# Increasing this value can improve determinism and reproducibility across resets.
|
||||
for _ in range(self.num_steps_wait):
|
||||
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())
|
||||
|
||||
if self.control_mode == "absolute":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = False
|
||||
elif self.control_mode == "relative":
|
||||
for robot in self._env.robots:
|
||||
robot.controller.use_delta = True
|
||||
else:
|
||||
raise ValueError(f"Invalid control mode: {self.control_mode}")
|
||||
observation = self._format_raw_obs(raw_obs)
|
||||
info = {"is_success": False}
|
||||
return observation, info
|
||||
@@ -341,8 +354,10 @@ def _make_env_fns(
|
||||
task_id: int,
|
||||
n_envs: int,
|
||||
camera_names: list[str],
|
||||
episode_length: int | None,
|
||||
init_states: bool,
|
||||
gym_kwargs: Mapping[str, Any],
|
||||
control_mode: str,
|
||||
) -> list[Callable[[], LiberoEnv]]:
|
||||
"""Build n_envs factory callables for a single (suite, task_id)."""
|
||||
|
||||
@@ -354,7 +369,9 @@ def _make_env_fns(
|
||||
task_suite_name=suite_name,
|
||||
camera_name=camera_names,
|
||||
init_states=init_states,
|
||||
episode_length=episode_length,
|
||||
episode_index=episode_index,
|
||||
control_mode=control_mode,
|
||||
**local_kwargs,
|
||||
)
|
||||
|
||||
@@ -374,6 +391,8 @@ def create_libero_envs(
|
||||
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
|
||||
init_states: bool = True,
|
||||
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
|
||||
control_mode: str = "relative",
|
||||
episode_length: int | None = None,
|
||||
) -> dict[str, dict[int, Any]]:
|
||||
"""
|
||||
Create vectorized LIBERO environments with a consistent return shape.
|
||||
@@ -415,12 +434,14 @@ def create_libero_envs(
|
||||
for tid in selected:
|
||||
fns = _make_env_fns(
|
||||
suite=suite,
|
||||
episode_length=episode_length,
|
||||
suite_name=suite_name,
|
||||
task_id=tid,
|
||||
n_envs=n_envs,
|
||||
camera_names=camera_names,
|
||||
init_states=init_states,
|
||||
gym_kwargs=gym_kwargs,
|
||||
control_mode=control_mode,
|
||||
)
|
||||
out[suite_name][tid] = env_cls(fns)
|
||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||
|
||||
Reference in New Issue
Block a user