mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6407a244c0 | |||
| 0511c12b8f |
@@ -180,32 +180,24 @@ class WandBLogger:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
batch_data = {}
|
||||
for k, v in d.items():
|
||||
# Skip the custom step key here, it's added to the batch below.
|
||||
if custom_step_key is not None and k == custom_step_key:
|
||||
continue
|
||||
|
||||
if isinstance(v, list):
|
||||
for i, elem in enumerate(v):
|
||||
if isinstance(elem, (int | float)):
|
||||
batch_data[f"{mode}/{k}_{i}"] = elem
|
||||
continue
|
||||
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
batch_data[f"{mode}/{k}"] = v
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
|
||||
if batch_data:
|
||||
if custom_step_key is not None:
|
||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||
self._wandb.log(batch_data)
|
||||
else:
|
||||
self._wandb.log(data=batch_data, step=step)
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
|
||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if "camera_obs" in observations:
|
||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||
|
||||
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||
# so env plugins can expose extra observation keys via get_env_processors().
|
||||
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||
for key, value in observations.items():
|
||||
if key in _handled:
|
||||
continue
|
||||
target = f"{OBS_STR}.{key}"
|
||||
if target in return_observations:
|
||||
continue
|
||||
if isinstance(value, np.ndarray):
|
||||
val = torch.from_numpy(value).float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
elif isinstance(value, Tensor):
|
||||
val = value.float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
|
||||
|
||||
This function uses `importlib.metadata` to find packages installed in the environment
|
||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
||||
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
prefixes = (
|
||||
"lerobot_robot_",
|
||||
"lerobot_camera_",
|
||||
"lerobot_teleoperator_",
|
||||
"lerobot_policy_",
|
||||
"lerobot_env_",
|
||||
)
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user