mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 04:07:02 +00:00
chore(format): format code
This commit is contained in:
@@ -128,7 +128,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
if key in dataset.meta.depth_keys:
|
||||
continue # Exclude depth keys from ImageNet stats
|
||||
continue # Exclude depth keys from ImageNet stats
|
||||
for stats_type, stats in IMAGENET_STATS.items():
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ def check_chw_float32(frame: torch.Tensor) -> None:
|
||||
c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
check_chw_float32(chw_float32_torch)
|
||||
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
|
||||
|
||||
@@ -108,9 +108,7 @@ def log_rerun_data(
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
if arr.shape[-1] == 1:
|
||||
img_entity = (
|
||||
rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
|
||||
)
|
||||
img_entity = rr.DepthImage(arr, colormap=rr.components.Colormap.Viridis)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
|
||||
Reference in New Issue
Block a user