chore(format): formatting code

This commit is contained in:
CarolinePascal
2026-06-12 19:18:47 +02:00
parent 6b2d5cec4c
commit bf2a3f3a39
9 changed files with 45 additions and 35 deletions
+1 -1
View File
@@ -139,7 +139,7 @@ class VideoEncoderConfig:
def from_video_info(cls, video_info: dict | None) -> Self:
"""Reconstruct an encoder config from a video feature's ``info`` block.
Missing or ``None`` values fall back to the class defaults.
Missing or ``None`` values fall back to the class defaults.
"""
return cls(**cls._kwargs_from_video_info(video_info))
+5 -2
View File
@@ -529,9 +529,12 @@ def compute_episode_stats(
)
if features[key]["dtype"] in ["image", "video"]:
normalization_factor = 255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
normalization_factor = (
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
)
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
for k, v in ep_stats[key].items()
}
return ep_stats
+15 -9
View File
@@ -36,7 +36,13 @@ import pyarrow.parquet as pq
import torch
from tqdm import tqdm
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, encoder_config_from_video_info, depth_encoder_defaults
from lerobot.configs import (
DepthEncoderConfig,
VideoEncoderConfig,
camera_encoder_defaults,
depth_encoder_defaults,
encoder_config_from_video_info,
)
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
from lerobot.utils.utils import flatten_dict
@@ -1608,7 +1614,7 @@ def recompute_stats(
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
#TODO: enable image and video stats re-computation
# TODO: enable image and video stats re-computation
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
@@ -1708,9 +1714,7 @@ def convert_image_to_video_dataset(
logging.info(
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
)
logging.info(
f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}"
)
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
# Create new features dict, converting image features to video features
new_features = {}
@@ -1964,7 +1968,11 @@ def reencode_dataset(
return dataset
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
worker_args = [(path, encoder, encoder_threads) for video_key, encoder in video_keys_encoders_dict.items() for path in video_keys_paths_dict[video_key]]
worker_args = [
(path, encoder, encoder_threads)
for video_key, encoder in video_keys_encoders_dict.items()
for path in video_keys_paths_dict[video_key]
]
if num_workers and num_workers > 1:
with ProcessPoolExecutor(max_workers=num_workers) as pool:
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
@@ -1983,9 +1991,7 @@ def reencode_dataset(
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
for video_key, encoder in video_keys_encoders_dict.items():
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else None
meta.update_video_info(
video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys
)
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
write_info(meta.info, meta.root)
logging.info("Dataset metadata updated.")
+3 -1
View File
@@ -303,7 +303,9 @@ class DatasetWriter:
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
k: v
if k == "count"
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
-1
View File
@@ -207,7 +207,6 @@ def dequantize_depth(
# ── Torch path: stay on the input device, single fp32 allocation. ────────
if isinstance(quantized, torch.Tensor):
if quantized.ndim >= 3:
# Drop the single-channel dimension so the math runs on (..., H, W).
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
+12 -6
View File
@@ -82,17 +82,23 @@ def decode_video_frames(
if backend != "pyav" and is_depth:
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
# We do not actually return uint8 here, but we avoid the 255 normalization step.
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True)
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
)
if backend is None:
backend = get_safe_default_video_backend()
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
elif backend == "pyav":
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
elif backend == "video_reader":
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
return decode_video_frames_pyav(
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
)
else:
raise ValueError(f"Unsupported video backend: {backend}")
@@ -122,7 +128,7 @@ def decode_video_frames_pyav(
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
decoded frame.
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
Otherwise, return float32 in [0, 1] range.
is_depth: Set to True if the video is a depth map (1 channel, uint12).
@@ -477,13 +483,13 @@ def encode_video_frames(
with Image.open(input_data) as input_image:
if is_depth:
input_frame = quantize_depth(
np.array(input_image),
np.array(input_image),
depth_min=video_encoder.depth_min,
depth_max=video_encoder.depth_max,
shift=video_encoder.shift,
use_log=video_encoder.use_log,
pix_fmt=video_encoder.pix_fmt,
video_backend="pyav"
video_backend="pyav",
)
else:
input_image = input_image.convert("RGB")
+3 -6
View File
@@ -30,12 +30,11 @@ from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.feature_utils import features_equal_for_merge
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import (
DUMMY_CAMERA_FEATURES,
DUMMY_DEPTH_CAMERA_FEATURES,
DUMMY_REPO_ID,
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_REPO_ID,
)
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
"""Test that total number of episodes and frames are correctly aggregated."""
assert aggr_ds.num_episodes == expected_episodes, (
@@ -212,9 +211,7 @@ def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
)
for key in expected_depth_keys:
info = aggr_ds.meta.info.features[key].get("info") or {}
assert info.get("is_depth_map") is True, (
f"Depth marker lost on feature {key!r} after aggregation"
)
assert info.get("is_depth_map") is True, f"Depth marker lost on feature {key!r} after aggregation"
def assert_video_timestamps_within_bounds(aggr_ds):
+1 -1
View File
@@ -1524,7 +1524,7 @@ def test_valid_video_codecs_constant():
assert "h264_qsv" in VALID_VIDEO_CODECS
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
assert "hevc_nvenc" in VALID_VIDEO_CODECS
assert len(VALID_VIDEO_CODECS) == 11
assert len(VALID_VIDEO_CODECS) == 11
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
+5 -8
View File
@@ -26,15 +26,16 @@ from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_CHW,
DUMMY_DEPTH_CAMERA_FEATURES,
DUMMY_REPO_ID,
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
DUMMY_CHW
)
from tests.fixtures.dataset_factories import add_frames
_, H, W = DUMMY_CHW
def _depth_metres_ramp() -> np.ndarray:
"""Linearly-spaced float32 depth in metres covering the default range."""
return np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
@@ -98,9 +99,7 @@ class TestQuantizeDequantize:
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
atol = 1e-5 if output_unit == "m" else 1.0
np.testing.assert_allclose(
out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol
)
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
@pytest.mark.parametrize(
"input_shape,output_shape",
@@ -131,9 +130,7 @@ class TestQuantizeDequantize:
def test_invalid_log_params_raises(self):
with pytest.raises(ValueError, match=r"depth_min \+ shift must be positive"):
quantize_depth(
_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None
)
quantize_depth(_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
# ── 2. Image writer depth support ─────────────────────────────────────