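Fix style: lint and formatting cleanups in the LIBERO env and its tests (iterate dicts directly instead of via `.keys()`, compare types with `is`, wrap an over-long assert, whitespace cleanup)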

This commit is contained in:
Jade Choghari
2025-11-11 00:17:31 +01:00
parent 6001b2c3ad
commit c11d8f1bb6
2 changed files with 28 additions and 30 deletions
+1 -1
View File
@@ -241,7 +241,7 @@ class LiberoEnv(gym.Env):
if self.init_states and self._init_states is not None: if self.init_states and self._init_states is not None:
self._env.set_init_state(self._init_states[self._init_state_id]) self._env.set_init_state(self._init_states[self._init_state_id])
raw_obs = self._env.env._get_observations() raw_obs = self._env.env._get_observations()
# After reset, objects may be unstable (slightly floating, intersecting, etc.). # After reset, objects may be unstable (slightly floating, intersecting, etc.).
# Step the simulator with a no-op action for a few frames so everything settles. # Step the simulator with a no-op action for a few frames so everything settles.
# Increasing this value can improve determinism and reproducibility across resets. # Increasing this value can improve determinism and reproducibility across resets.
+27 -29
View File
@@ -28,7 +28,7 @@ os.environ["MUJOCO_GL"] = "egl"
def assert_observations_equal(obs1, obs2, path="", atol=1e-8): def assert_observations_equal(obs1, obs2, path="", atol=1e-8):
""" """
Recursively compare two observations and assert they are equal. Recursively compare two observations and assert they are equal.
Args: Args:
obs1: First observation (dict or numpy array) obs1: First observation (dict or numpy array)
obs2: Second observation (dict or numpy array) obs2: Second observation (dict or numpy array)
@@ -37,7 +37,7 @@ def assert_observations_equal(obs1, obs2, path="", atol=1e-8):
""" """
if isinstance(obs1, dict) and isinstance(obs2, dict): if isinstance(obs1, dict) and isinstance(obs2, dict):
assert obs1.keys() == obs2.keys(), f"Keys differ at {path}: {obs1.keys()} != {obs2.keys()}" assert obs1.keys() == obs2.keys(), f"Keys differ at {path}: {obs1.keys()} != {obs2.keys()}"
for key in obs1.keys(): for key in obs1:
assert_observations_equal(obs1[key], obs2[key], path=f"{path}.{key}" if path else key, atol=atol) assert_observations_equal(obs1[key], obs2[key], path=f"{path}.{key}" if path else key, atol=atol)
elif isinstance(obs1, np.ndarray) and isinstance(obs2, np.ndarray): elif isinstance(obs1, np.ndarray) and isinstance(obs2, np.ndarray):
assert obs1.shape == obs2.shape, f"Shape mismatch at {path}: {obs1.shape} != {obs2.shape}" assert obs1.shape == obs2.shape, f"Shape mismatch at {path}: {obs1.shape} != {obs2.shape}"
@@ -46,29 +46,27 @@ def assert_observations_equal(obs1, obs2, path="", atol=1e-8):
f"Array values differ at {path}: max abs diff = {np.abs(obs1 - obs2).max()}" f"Array values differ at {path}: max abs diff = {np.abs(obs1 - obs2).max()}"
) )
else: else:
assert type(obs1) == type(obs2), f"Type mismatch at {path}: {type(obs1)} != {type(obs2)}" assert type(obs1) is type(obs2), f"Type mismatch at {path}: {type(obs1)} != {type(obs2)}"
assert obs1 == obs2, f"Values differ at {path}: {obs1} != {obs2}" assert obs1 == obs2, f"Values differ at {path}: {obs1} != {obs2}"
def test_libero_env_creation(): def test_libero_env_creation():
"""Test that the libero environment can be created successfully.""" """Test that the libero environment can be created successfully."""
config = make_env_config("libero", task="libero_spatial") config = make_env_config("libero", task="libero_spatial")
envs_dict = make_env(config) envs_dict = make_env(config)
assert "libero_spatial" in envs_dict assert "libero_spatial" in envs_dict
assert 0 in envs_dict["libero_spatial"] assert 0 in envs_dict["libero_spatial"]
env = envs_dict["libero_spatial"][0] env = envs_dict["libero_spatial"][0]
assert env is not None assert env is not None
# Test basic reset # Test basic reset
observation, info = env.reset(seed=42) observation, info = env.reset(seed=42)
assert observation is not None assert observation is not None
assert info is not None assert info is not None
env.close()
env.close()
def test_libero_reset_determinism(): def test_libero_reset_determinism():
@@ -76,19 +74,18 @@ def test_libero_reset_determinism():
config = make_env_config("libero", task="libero_spatial") config = make_env_config("libero", task="libero_spatial")
envs_dict = make_env(config) envs_dict = make_env(config)
env = envs_dict["libero_spatial"][0] env = envs_dict["libero_spatial"][0]
# Reset multiple times with the same seed # Reset multiple times with the same seed
obs1, info1 = env.reset(seed=42) obs1, info1 = env.reset(seed=42)
obs2, info2 = env.reset(seed=42) obs2, info2 = env.reset(seed=42)
obs3, info3 = env.reset(seed=42) obs3, info3 = env.reset(seed=42)
# All observations should be identical # All observations should be identical
assert_observations_equal(obs1, obs2) assert_observations_equal(obs1, obs2)
assert_observations_equal(obs1, obs3) assert_observations_equal(obs1, obs3)
assert_observations_equal(obs2, obs3) assert_observations_equal(obs2, obs3)
env.close()
env.close()
def test_libero_step_determinism(): def test_libero_step_determinism():
@@ -96,31 +93,32 @@ def test_libero_step_determinism():
config = make_env_config("libero", task="libero_spatial") config = make_env_config("libero", task="libero_spatial")
envs_dict = make_env(config) envs_dict = make_env(config)
env = envs_dict["libero_spatial"][0] env = envs_dict["libero_spatial"][0]
seed = 42 seed = 42
# First rollout # First rollout
obs1, info1 = env.reset(seed=seed) obs1, info1 = env.reset(seed=seed)
action = env.action_space.sample() action = env.action_space.sample()
obs_after_step1, reward1, terminated1, truncated1, info_step1 = env.step(action) obs_after_step1, reward1, terminated1, truncated1, info_step1 = env.step(action)
# Second rollout with identical seed and action # Second rollout with identical seed and action
obs2, info2 = env.reset(seed=seed) obs2, info2 = env.reset(seed=seed)
obs_after_step2, reward2, terminated2, truncated2, info_step2 = env.step(action) obs_after_step2, reward2, terminated2, truncated2, info_step2 = env.step(action)
# Initial observations should be identical # Initial observations should be identical
assert_observations_equal(obs1, obs2) assert_observations_equal(obs1, obs2)
# Post-step observations should be identical # Post-step observations should be identical
assert_observations_equal(obs_after_step1, obs_after_step2) assert_observations_equal(obs_after_step1, obs_after_step2)
# Rewards and termination flags should be identical # Rewards and termination flags should be identical
assert np.allclose(reward1, reward2), f"Rewards differ: {reward1} != {reward2}" assert np.allclose(reward1, reward2), f"Rewards differ: {reward1} != {reward2}"
assert np.array_equal(terminated1, terminated2), f"Terminated flags differ: {terminated1} != {terminated2}" assert np.array_equal(terminated1, terminated2), (
f"Terminated flags differ: {terminated1} != {terminated2}"
)
assert np.array_equal(truncated1, truncated2), f"Truncated flags differ: {truncated1} != {truncated2}" assert np.array_equal(truncated1, truncated2), f"Truncated flags differ: {truncated1} != {truncated2}"
env.close()
env.close()
@pytest.mark.parametrize("task", ["libero_spatial", "libero_object", "libero_goal", "libero_10"]) @pytest.mark.parametrize("task", ["libero_spatial", "libero_object", "libero_goal", "libero_10"])
@@ -128,23 +126,23 @@ def test_libero_tasks(task):
"""Test that different libero tasks can be created and used.""" """Test that different libero tasks can be created and used."""
config = make_env_config("libero", task=task) config = make_env_config("libero", task=task)
envs_dict = make_env(config) envs_dict = make_env(config)
assert task in envs_dict assert task in envs_dict
assert 0 in envs_dict[task] assert 0 in envs_dict[task]
env = envs_dict[task][0] env = envs_dict[task][0]
observation, info = env.reset(seed=42) observation, info = env.reset(seed=42)
assert observation is not None assert observation is not None
assert info is not None assert info is not None
# Take a step # Take a step
action = env.action_space.sample() action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action) obs, reward, terminated, truncated, info = env.step(action)
assert obs is not None assert obs is not None
assert reward is not None assert reward is not None
assert isinstance(terminated, (bool, np.ndarray)) assert isinstance(terminated, (bool, np.ndarray))
assert isinstance(truncated, (bool, np.ndarray)) assert isinstance(truncated, (bool, np.ndarray))
env.close() env.close()