mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix: close envs between tasks to prevent worker process accumulation
eval_policy_all never closed environments after each task completed, causing AsyncVectorEnv worker processes to accumulate (N_tasks × n_envs). This led to OOM, BrokenPipeError, and EOFError on multi-task benchmarks.
Also fixes:
- AsyncVectorEnv compatibility in envs/utils.py (use get_attr/call instead of .envs)
- Tuple task handling in tokenizer_processor and lerobot_eval
- _LazyAsyncVectorEnv for deferred worker spawning in LIBERO
Made-with: Cursor
This commit is contained in:
@@ -44,6 +44,13 @@ from lerobot.utils.constants import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vec_env_cls(use_async: bool, n_envs: int):
    """Pick the vector-env class matching the requested execution mode.

    ``gym.vector.AsyncVectorEnv`` is chosen only when async execution was
    requested *and* more than one env is needed; every other combination
    gets ``gym.vector.SyncVectorEnv``.
    """
    wants_workers = use_async and n_envs > 1
    return gym.vector.AsyncVectorEnv if wants_workers else gym.vector.SyncVectorEnv
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
task: str | None = None
|
task: str | None = None
|
||||||
@@ -405,7 +412,7 @@ class LiberoEnv(EnvConfig):
|
|||||||
|
|
||||||
if self.task is None:
|
if self.task is None:
|
||||||
raise ValueError("LiberoEnv requires a task to be specified")
|
raise ValueError("LiberoEnv requires a task to be specified")
|
||||||
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
|
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||||
return create_libero_envs(
|
return create_libero_envs(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
@@ -474,7 +481,7 @@ class MetaworldEnv(EnvConfig):
|
|||||||
|
|
||||||
if self.task is None:
|
if self.task is None:
|
||||||
raise ValueError("MetaWorld requires a task to be specified")
|
raise ValueError("MetaWorld requires a task to be specified")
|
||||||
env_cls = gym.vector.AsyncVectorEnv if (use_async_envs and n_envs > 1) else gym.vector.SyncVectorEnv
|
env_cls = _make_vec_env_cls(use_async_envs, n_envs)
|
||||||
return create_metaworld_envs(
|
return create_metaworld_envs(
|
||||||
task=self.task,
|
task=self.task,
|
||||||
n_envs=n_envs,
|
n_envs=n_envs,
|
||||||
|
|||||||
@@ -403,6 +403,57 @@ def _make_env_fns(
|
|||||||
return fns
|
return fns
|
||||||
|
|
||||||
|
|
||||||
|
class _LazyAsyncVectorEnv:
|
||||||
|
"""Wrapper that defers AsyncVectorEnv creation until first use.
|
||||||
|
|
||||||
|
Creating all tasks' AsyncVectorEnvs upfront spawns N_tasks × n_envs worker
|
||||||
|
processes, all of which allocate EGL/GPU resources immediately. Since tasks
|
||||||
|
are evaluated sequentially, only one task's workers need to be alive at a
|
||||||
|
time. This wrapper stores the factory functions and creates the real
|
||||||
|
AsyncVectorEnv on first reset(), keeping peak process count = n_envs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env_fns: list[Callable]):
|
||||||
|
self._env_fns = env_fns
|
||||||
|
self._env: gym.vector.AsyncVectorEnv | None = None
|
||||||
|
self.num_envs = len(env_fns)
|
||||||
|
# Instantiate one env to expose spaces (no GPU — _ensure_env is lazy).
|
||||||
|
tmp = env_fns[0]()
|
||||||
|
self.observation_space = tmp.observation_space
|
||||||
|
self.action_space = tmp.action_space
|
||||||
|
self.single_observation_space = tmp.observation_space
|
||||||
|
self.single_action_space = tmp.action_space
|
||||||
|
tmp.close()
|
||||||
|
|
||||||
|
def _ensure(self):
|
||||||
|
if self._env is None:
|
||||||
|
self._env = gym.vector.AsyncVectorEnv(self._env_fns, context="forkserver")
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.reset(**kwargs)
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.step(actions)
|
||||||
|
|
||||||
|
def call(self, name, *args, **kwargs):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.call(name, *args, **kwargs)
|
||||||
|
|
||||||
|
def get_attr(self, name):
|
||||||
|
self._ensure()
|
||||||
|
return self._env.get_attr(name)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if self._env is not None:
|
||||||
|
self._env.close()
|
||||||
|
self._env = None
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
# ---- Main API ----------------------------------------------------------------
|
# ---- Main API ----------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -446,6 +497,8 @@ def create_libero_envs(
|
|||||||
if task_ids_filter is not None:
|
if task_ids_filter is not None:
|
||||||
print(f"Restricting to task_ids={task_ids_filter}")
|
print(f"Restricting to task_ids={task_ids_filter}")
|
||||||
|
|
||||||
|
is_async = env_cls is gym.vector.AsyncVectorEnv
|
||||||
|
|
||||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||||
for suite_name in suite_names:
|
for suite_name in suite_names:
|
||||||
suite = _get_suite(suite_name)
|
suite = _get_suite(suite_name)
|
||||||
@@ -467,8 +520,10 @@ def create_libero_envs(
|
|||||||
control_mode=control_mode,
|
control_mode=control_mode,
|
||||||
camera_name_mapping=camera_name_mapping,
|
camera_name_mapping=camera_name_mapping,
|
||||||
)
|
)
|
||||||
out[suite_name][tid] = env_cls(fns)
|
if is_async:
|
||||||
|
out[suite_name][tid] = _LazyAsyncVectorEnv(fns)
|
||||||
|
else:
|
||||||
|
out[suite_name][tid] = env_cls(fns)
|
||||||
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
|
||||||
|
|
||||||
# return plain dicts for predictability
|
|
||||||
return {suite: dict(task_map) for suite, task_map in out.items()}
|
return {suite: dict(task_map) for suite, task_map in out.items()}
|
||||||
|
|||||||
+21
-26
@@ -130,56 +130,51 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
|||||||
return policy_features
|
return policy_features
|
||||||
|
|
||||||
|
|
||||||
def are_all_envs_same_type(env: gym.vector.VectorEnv) -> bool:
|
def _get_sub_env_attr(env: gym.vector.VectorEnv, attr: str, index: int = 0):
|
||||||
first_type = type(env.envs[0]) # Get type of first env
|
"""Retrieve an attribute from a sub-environment, works for both Sync and Async."""
|
||||||
return all(type(e) is first_type for e in env.envs) # Fast type check
|
try:
|
||||||
|
return env.get_attr(attr)[index]
|
||||||
|
except (AttributeError, Exception):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _sub_env_has_attr(env: gym.vector.VectorEnv, attr: str) -> bool:
|
||||||
|
try:
|
||||||
|
env.get_attr(attr)
|
||||||
|
return True
|
||||||
|
except (AttributeError, Exception):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
warnings.simplefilter("once", UserWarning)
|
||||||
|
|
||||||
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
if not (_sub_env_has_attr(env, "task_description") and _sub_env_has_attr(env, "task")):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||||
UserWarning,
|
UserWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
if not are_all_envs_same_type(env):
|
|
||||||
warnings.warn(
|
|
||||||
"The environments have different types. Make sure you infer the right task from each environment. Empty task will be passed instead.",
|
|
||||||
UserWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
|
def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation:
|
||||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||||
if hasattr(env.envs[0], "task_description"):
|
if _sub_env_has_attr(env, "task_description"):
|
||||||
task_result = env.call("task_description")
|
task_result = list(env.call("task_description"))
|
||||||
|
|
||||||
if isinstance(task_result, tuple):
|
|
||||||
task_result = list(task_result)
|
|
||||||
|
|
||||||
if not isinstance(task_result, list):
|
|
||||||
raise TypeError(f"Expected task_description to return a list, got {type(task_result)}")
|
|
||||||
if not all(isinstance(item, str) for item in task_result):
|
if not all(isinstance(item, str) for item in task_result):
|
||||||
raise TypeError("All items in task_description result must be strings")
|
raise TypeError("All items in task_description result must be strings")
|
||||||
|
|
||||||
observation["task"] = task_result
|
observation["task"] = task_result
|
||||||
elif hasattr(env.envs[0], "task"):
|
elif _sub_env_has_attr(env, "task"):
|
||||||
task_result = env.call("task")
|
task_result = list(env.call("task"))
|
||||||
|
|
||||||
if isinstance(task_result, tuple):
|
|
||||||
task_result = list(task_result)
|
|
||||||
|
|
||||||
if not isinstance(task_result, list):
|
|
||||||
raise TypeError(f"Expected task to return a list, got {type(task_result)}")
|
|
||||||
if not all(isinstance(item, str) for item in task_result):
|
if not all(isinstance(item, str) for item in task_result):
|
||||||
raise TypeError("All items in task result must be strings")
|
raise TypeError("All items in task result must be strings")
|
||||||
|
|
||||||
observation["task"] = task_result
|
observation["task"] = task_result
|
||||||
else: # For envs without language instructions, e.g. aloha transfer cube and etc.
|
else:
|
||||||
num_envs = observation[list(observation.keys())[0]].shape[0]
|
num_envs = observation[list(observation.keys())[0]].shape[0]
|
||||||
observation["task"] = ["" for _ in range(num_envs)]
|
observation["task"] = ["" for _ in range(num_envs)]
|
||||||
return observation
|
return observation
|
||||||
|
|||||||
@@ -136,8 +136,8 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
|||||||
# Standardize to a list of strings for the tokenizer
|
# Standardize to a list of strings for the tokenizer
|
||||||
if isinstance(task, str):
|
if isinstance(task, str):
|
||||||
return [task]
|
return [task]
|
||||||
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
|
elif isinstance(task, (list, tuple)) and all(isinstance(t, str) for t in task):
|
||||||
return task
|
return list(task)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ def rollout(
|
|||||||
|
|
||||||
# Infer "task" from sub-environments.
|
# Infer "task" from sub-environments.
|
||||||
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
|
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
|
||||||
observation["task"] = env.call("task")
|
observation["task"] = list(env.call("task"))
|
||||||
|
|
||||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||||
observation = env_preprocessor(observation)
|
observation = env_preprocessor(observation)
|
||||||
@@ -748,23 +748,27 @@ def eval_policy_all(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if max_parallel_tasks <= 1:
|
if max_parallel_tasks <= 1:
|
||||||
# sequential path (single accumulator path on the main thread)
|
|
||||||
# NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
|
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
tg, tid, metrics = task_runner(task_group, task_id, env)
|
try:
|
||||||
_accumulate_to(tg, metrics)
|
tg, tid, metrics = task_runner(task_group, task_id, env)
|
||||||
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
_accumulate_to(tg, metrics)
|
||||||
|
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
else:
|
else:
|
||||||
# threaded path: submit all tasks, consume completions on main thread and accumulate there
|
|
||||||
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
|
||||||
fut2meta = {}
|
fut2meta = {}
|
||||||
for task_group, task_id, env in tasks:
|
for task_group, task_id, env in tasks:
|
||||||
fut = executor.submit(task_runner, task_group, task_id, env)
|
fut = executor.submit(task_runner, task_group, task_id, env)
|
||||||
fut2meta[fut] = (task_group, task_id)
|
fut2meta[fut] = (task_group, task_id, env)
|
||||||
for fut in cf.as_completed(fut2meta):
|
for fut in cf.as_completed(fut2meta):
|
||||||
tg, tid, metrics = fut.result()
|
tg, tid, env = fut2meta[fut]
|
||||||
_accumulate_to(tg, metrics)
|
try:
|
||||||
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
tg, tid, metrics = fut.result()
|
||||||
|
_accumulate_to(tg, metrics)
|
||||||
|
per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
|
||||||
|
finally:
|
||||||
|
env.close()
|
||||||
|
|
||||||
# compute aggregated metrics helper (robust to lists/scalars)
|
# compute aggregated metrics helper (robust to lists/scalars)
|
||||||
def _agg_from_list(xs):
|
def _agg_from_list(xs):
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ def test_base_create_envs():
|
|||||||
envs = _Env().create_envs(n_envs=2)
|
envs = _Env().create_envs(n_envs=2)
|
||||||
assert "_dispatch_base_test" in envs
|
assert "_dispatch_base_test" in envs
|
||||||
env = envs["_dispatch_base_test"][0]
|
env = envs["_dispatch_base_test"][0]
|
||||||
assert isinstance(env, gym.vector.SyncVectorEnv)
|
assert isinstance(env, gym.vector.VectorEnv)
|
||||||
assert env.num_envs == 2
|
assert env.num_envs == 2
|
||||||
env.close()
|
env.close()
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -189,6 +189,30 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
|
|||||||
assert attention_mask.shape == (2, 8)
|
assert attention_mask.shape == (2, 8)
|
||||||
|
|
||||||
|
|
||||||
|
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_tuple_of_strings_tokenization(mock_auto_tokenizer):
    """Test tokenization of a tuple of strings (returned by VectorEnv.call())."""
    mock_auto_tokenizer.from_pretrained.return_value = MockTokenizer(vocab_size=100)
    step = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)

    tasks = ("pick up cube", "place on table")
    transition = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": tasks},
    )

    processed_obs = step(transition)[TransitionKey.OBSERVATION]
    tokens = processed_obs[f"{OBS_LANGUAGE}.tokens"]
    attention_mask = processed_obs[f"{OBS_LANGUAGE}.attention_mask"]

    # Two task strings with max_length=8.
    expected_shape = (2, 8)
    for produced in (tokens, attention_mask):
        assert produced.shape == expected_shape
|
||||||
|
|
||||||
|
|
||||||
@require_package("transformers")
|
@require_package("transformers")
|
||||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||||
def test_custom_keys(mock_auto_tokenizer):
|
def test_custom_keys(mock_auto_tokenizer):
|
||||||
|
|||||||
Reference in New Issue
Block a user