Compare commits

..

2 Commits

Author SHA1 Message Date
Khalil Meftah 6407a244c0 feat(envs): add generic observation passthrough
- Add generic observation passthrough in preprocess_observation() for
unhandled ndarray/tensor keys, replacing the pattern of adding per-env
hardcoded key handlers. Extra keys are forwarded as observation.<key>
and can be shaped by env-specific ProcessorSteps via get_env_processors().
2026-06-15 14:17:59 +02:00
Khalil Meftah 0511c12b8f feat(envs): add env plugin discovery
- Add 'lerobot_env_' to third-party plugin discovery prefixes, completing
the plugin system for all component types (robots, cameras, teleoperators,
policies, and now environments). External packages named lerobot_env_*
can self-register EnvConfig subclasses on import, enabling --env.type=
resolution without lerobot code changes.
2026-06-15 14:13:12 +02:00
4 changed files with 38 additions and 20 deletions
+9 -17
View File
@@ -180,32 +180,24 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if isinstance(v, list):
for i, elem in enumerate(v):
if isinstance(elem, (int | float)):
batch_data[f"{mode}/{k}_{i}"] = elem
continue
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
continue
batch_data[f"{mode}/{k}"] = v
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if batch_data:
if custom_step_key is not None:
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
self._wandb.log(batch_data)
else:
self._wandb.log(data=batch_data, step=step)
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
+20
View File
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
# Pass through any remaining ndarray/tensor keys not already handled above,
# so env plugins can expose extra observation keys via get_env_processors().
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
for key, value in observations.items():
if key in _handled:
continue
target = f"{OBS_STR}.{key}"
if target in return_observations:
continue
if isinstance(value, np.ndarray):
val = torch.from_numpy(value).float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
elif isinstance(value, Tensor):
val = value.float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
return return_observations
+8 -2
View File
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
This function uses `importlib.metadata` to find packages installed in the environment
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
"""
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
prefixes = (
"lerobot_robot_",
"lerobot_camera_",
"lerobot_teleoperator_",
"lerobot_policy_",
"lerobot_env_",
)
imported: list[str] = []
failed: list[str] = []