mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6407a244c0 | |||
| 0511c12b8f |
@@ -180,26 +180,24 @@ class WandBLogger:
|
|||||||
self._wandb_custom_step_key.add(new_custom_key)
|
self._wandb_custom_step_key.add(new_custom_key)
|
||||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||||
|
|
||||||
batch_data = {}
|
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
# Skip the custom step key here, it's added to the batch below.
|
|
||||||
if custom_step_key is not None and k == custom_step_key:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not isinstance(v, (int | float | str)):
|
if not isinstance(v, (int | float | str)):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch_data[f"{mode}/{k}"] = v
|
# Do not log the custom step key itself.
|
||||||
|
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||||
|
continue
|
||||||
|
|
||||||
if batch_data:
|
|
||||||
if custom_step_key is not None:
|
if custom_step_key is not None:
|
||||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
value_custom_step = d[custom_step_key]
|
||||||
self._wandb.log(batch_data)
|
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||||
else:
|
self._wandb.log(data)
|
||||||
self._wandb.log(data=batch_data, step=step)
|
continue
|
||||||
|
|
||||||
|
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||||
|
|
||||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||||
if mode not in {"train", "eval"}:
|
if mode not in {"train", "eval"}:
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: The statistics dictionary with values cast to numpy arrays.
|
dict: The statistics dictionary with values cast to numpy arrays.
|
||||||
"""
|
"""
|
||||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||||
return unflatten_dict(stats)
|
return unflatten_dict(stats)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
if "camera_obs" in observations:
|
if "camera_obs" in observations:
|
||||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||||
|
|
||||||
|
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||||
|
# so env plugins can expose extra observation keys via get_env_processors().
|
||||||
|
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||||
|
for key, value in observations.items():
|
||||||
|
if key in _handled:
|
||||||
|
continue
|
||||||
|
target = f"{OBS_STR}.{key}"
|
||||||
|
if target in return_observations:
|
||||||
|
continue
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
val = torch.from_numpy(value).float()
|
||||||
|
if val.dim() == 1:
|
||||||
|
val = val.unsqueeze(0)
|
||||||
|
return_observations[target] = val
|
||||||
|
elif isinstance(value, Tensor):
|
||||||
|
val = value.float()
|
||||||
|
if val.dim() == 1:
|
||||||
|
val = val.unsqueeze(0)
|
||||||
|
return_observations[target] = val
|
||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
|
|||||||
|
|
||||||
This function uses `importlib.metadata` to find packages installed in the environment
|
This function uses `importlib.metadata` to find packages installed in the environment
|
||||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
|
||||||
"""
|
"""
|
||||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
prefixes = (
|
||||||
|
"lerobot_robot_",
|
||||||
|
"lerobot_camera_",
|
||||||
|
"lerobot_teleoperator_",
|
||||||
|
"lerobot_policy_",
|
||||||
|
"lerobot_env_",
|
||||||
|
)
|
||||||
imported: list[str] = []
|
imported: list[str] = []
|
||||||
failed: list[str] = []
|
failed: list[str] = []
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user