From c11d8f1bb6ab08acc1773b53491fb8c499731f12 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 11 Nov 2025 00:17:31 +0100 Subject: [PATCH] fix style --- src/lerobot/envs/libero.py | 2 +- tests/envs/test_libero.py | 56 ++++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 2df9e8a66..ee7b214a3 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -241,7 +241,7 @@ class LiberoEnv(gym.Env): if self.init_states and self._init_states is not None: self._env.set_init_state(self._init_states[self._init_state_id]) raw_obs = self._env.env._get_observations() - + # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. # Increasing this value can improve determinism and reproducibility across resets. diff --git a/tests/envs/test_libero.py b/tests/envs/test_libero.py index eeb34ced5..f28d9dbbd 100644 --- a/tests/envs/test_libero.py +++ b/tests/envs/test_libero.py @@ -28,7 +28,7 @@ os.environ["MUJOCO_GL"] = "egl" def assert_observations_equal(obs1, obs2, path="", atol=1e-8): """ Recursively compare two observations and assert they are equal. - + Args: obs1: First observation (dict or numpy array) obs2: Second observation (dict or numpy array) @@ -37,7 +37,7 @@ def assert_observations_equal(obs1, obs2, path="", atol=1e-8): """ if isinstance(obs1, dict) and isinstance(obs2, dict): assert obs1.keys() == obs2.keys(), f"Keys differ at {path}: {obs1.keys()} != {obs2.keys()}" - for key in obs1.keys(): + for key in obs1: assert_observations_equal(obs1[key], obs2[key], path=f"{path}.{key}" if path else key, atol=atol) elif isinstance(obs1, np.ndarray) and isinstance(obs2, np.ndarray): assert obs1.shape == obs2.shape, f"Shape mismatch at {path}: {obs1.shape} != {obs2.shape}" @@ -46,29 +46,27 @@ def assert_observations_equal(obs1, obs2, path="", atol=1e-8): f"Array values differ at {path}: max abs diff = {np.abs(obs1 - obs2).max()}" ) else: - assert type(obs1) == type(obs2), f"Type mismatch at {path}: {type(obs1)} != {type(obs2)}" + assert type(obs1) is type(obs2), f"Type mismatch at {path}: {type(obs1)} != {type(obs2)}" assert obs1 == obs2, f"Values differ at {path}: {obs1} != {obs2}" - def test_libero_env_creation(): """Test that the libero environment can be created successfully.""" config = make_env_config("libero", task="libero_spatial") envs_dict = make_env(config) - + assert "libero_spatial" in envs_dict assert 0 in envs_dict["libero_spatial"] - + env = envs_dict["libero_spatial"][0] assert env is not None - + # Test basic reset observation, info = env.reset(seed=42) assert observation is not None assert info is not None - - env.close() + env.close() def test_libero_reset_determinism(): @@ -76,19 +74,18 @@ def test_libero_reset_determinism(): config = make_env_config("libero", task="libero_spatial") envs_dict = make_env(config) env = envs_dict["libero_spatial"][0] - + # Reset multiple times with the same seed obs1, info1 = env.reset(seed=42) obs2, info2 = env.reset(seed=42) obs3, info3 = env.reset(seed=42) - + # All observations should be identical assert_observations_equal(obs1, obs2) assert_observations_equal(obs1, obs3) assert_observations_equal(obs2, obs3) - - env.close() + env.close() def test_libero_step_determinism(): @@ -96,31 +93,32 @@ def test_libero_step_determinism(): config = make_env_config("libero", task="libero_spatial") envs_dict = make_env(config) env = envs_dict["libero_spatial"][0] - + seed = 42 - + # First rollout obs1, info1 = env.reset(seed=seed) action = env.action_space.sample() obs_after_step1, reward1, terminated1, truncated1, info_step1 = env.step(action) - + # Second rollout with identical seed and action obs2, info2 = env.reset(seed=seed) obs_after_step2, reward2, terminated2, truncated2, info_step2 = env.step(action) - + # Initial observations should be identical assert_observations_equal(obs1, obs2) - + # Post-step observations should be identical assert_observations_equal(obs_after_step1, obs_after_step2) - + # Rewards and termination flags should be identical assert np.allclose(reward1, reward2), f"Rewards differ: {reward1} != {reward2}" - assert np.array_equal(terminated1, terminated2), f"Terminated flags differ: {terminated1} != {terminated2}" + assert np.array_equal(terminated1, terminated2), ( + f"Terminated flags differ: {terminated1} != {terminated2}" + ) assert np.array_equal(truncated1, truncated2), f"Truncated flags differ: {truncated1} != {truncated2}" - - env.close() + env.close() @pytest.mark.parametrize("task", ["libero_spatial", "libero_object", "libero_goal", "libero_10"]) @@ -128,23 +126,23 @@ def test_libero_tasks(task): """Test that different libero tasks can be created and used.""" config = make_env_config("libero", task=task) envs_dict = make_env(config) - + assert task in envs_dict assert 0 in envs_dict[task] - + env = envs_dict[task][0] observation, info = env.reset(seed=42) - + assert observation is not None assert info is not None - + # Take a step action = env.action_space.sample() obs, reward, terminated, truncated, info = env.step(action) - + assert obs is not None assert reward is not None assert isinstance(terminated, (bool, np.ndarray)) assert isinstance(truncated, (bool, np.ndarray)) - - env.close() \ No newline at end of file + + env.close()