xvla works on libero

This commit is contained in:
Jade Choghari
2025-11-17 11:02:20 +01:00
parent 818c75713b
commit ab763abff3
5 changed files with 12 additions and 6 deletions
+2 -1
View File
@@ -3,5 +3,6 @@ lerobot-eval \
--env.type=libero \ --env.type=libero \
--env.task=libero_spatial \ --env.task=libero_spatial \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 --eval.n_episodes=1 \
--seed=142
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

+9 -4
View File
@@ -145,11 +145,9 @@ class LiberoEnv(gym.Env):
# Load once and keep # Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state self._init_state_id = self.episode_index # tie each sub-env to a fixed init state
self._env = self._make_envs_task(task_suite, self.task_id) self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500 default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps) self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
images = {} images = {}
for cam in self.camera_name: for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box( images[self.camera_name_mapping[cam]] = spaces.Box(
@@ -205,6 +203,7 @@ class LiberoEnv(gym.Env):
"camera_widths": self.observation_width, "camera_widths": self.observation_width,
} }
env = OffScreenRenderEnv(**env_args) env = OffScreenRenderEnv(**env_args)
env.seed(142)
env.reset() env.reset()
return env return env
@@ -212,7 +211,8 @@ class LiberoEnv(gym.Env):
images = {} images = {}
for camera_name in self.camera_name: for camera_name in self.camera_name:
image = raw_obs[camera_name] image = raw_obs[camera_name]
image = image[::-1, ::-1] # rotate 180 degrees if camera_name == "agentview_image":
image = image[::-1, ::-1] # rotate 180 degrees
images[self.camera_name_mapping[camera_name]] = image images[self.camera_name_mapping[camera_name]] = image
state = np.concatenate( state = np.concatenate(
( (
@@ -244,14 +244,18 @@ class LiberoEnv(gym.Env):
self._env.seed(seed) self._env.seed(seed)
if self.init_states and self._init_states is not None: if self.init_states and self._init_states is not None:
self._env.set_init_state(self._init_states[self._init_state_id]) self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.reset() raw_obs = self._env.reset()
# After reset, objects may be unstable (slightly floating, intersecting, etc.). # After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles. # Step the simulator with a no-op action for a few frames so everything settles.
# Increasing this value can improve determinism and reproducibility across resets. # Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait): for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action()) action = np.array([0., 0., 0., 0., 0., 0., -1.0])
raw_obs, _, _, _ = self._env.step(action)
observation = self._format_raw_obs(raw_obs) observation = self._format_raw_obs(raw_obs)
for robot in self._env.robots:
robot.controller.use_delta = False
info = {"is_success": False} info = {"is_success": False}
return observation, info return observation, info
@@ -261,6 +265,7 @@ class LiberoEnv(gym.Env):
f"Expected action to be 1-D (shape (action_dim,)), " f"Expected action to be 1-D (shape (action_dim,)), "
f"but got shape {action.shape} with ndim={action.ndim}" f"but got shape {action.shape} with ndim={action.ndim}"
) )
action[-1] = 1 if action[-1] > 0.5 else -1
raw_obs, reward, done, info = self._env.step(action) raw_obs, reward, done, info = self._env.step(action)
is_success = self._env.check_success() is_success = self._env.check_success()
+1 -1
View File
@@ -432,7 +432,7 @@ class XVLAPolicy(PreTrainedPolicy):
print(f"Missing keys: {missing}") print(f"Missing keys: {missing}")
if unexpected: if unexpected:
print(f"Unexpected keys: {unexpected}") print(f"Unexpected keys: {unexpected}")
# --- Step 6: Finalize --- # --- Step 6: Finalize ---
instance.to(config.device) instance.to(config.device)
instance.eval() instance.eval()