mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
chore(format): formatting code
This commit is contained in:
@@ -139,7 +139,7 @@ class VideoEncoderConfig:
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
|
||||
|
||||
@@ -531,9 +531,12 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = 255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
normalization_factor = (
|
||||
255.0 if not (features[key].get("info") or {}).get("is_depth_map", False) else 1.0
|
||||
)
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0)
|
||||
for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
@@ -36,7 +36,13 @@ import pyarrow.parquet as pq
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import VideoEncoderConfig, camera_encoder_defaults, DepthEncoderConfig, encoder_config_from_video_info, depth_encoder_defaults
|
||||
from lerobot.configs import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
encoder_config_from_video_info,
|
||||
)
|
||||
from lerobot.configs.video import DEPTH_ENCODER_INFO_FIELD_NAMES
|
||||
from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.utils import flatten_dict
|
||||
@@ -1614,7 +1620,7 @@ def recompute_stats(
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
all_episode_stats = []
|
||||
#TODO: enable image and video stats re-computation
|
||||
# TODO: enable image and video stats re-computation
|
||||
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
|
||||
@@ -1714,9 +1720,7 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(
|
||||
f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}"
|
||||
)
|
||||
logging.info(f"RGB video encoder: {camera_encoder}, depth video encoder: {depth_encoder}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1970,7 +1974,11 @@ def reencode_dataset(
|
||||
return dataset
|
||||
logging.info(f"Re-encoding {sum(len(paths) for paths in video_keys_paths_dict.values())} video file(s).")
|
||||
|
||||
worker_args = [(path, encoder, encoder_threads) for video_key, encoder in video_keys_encoders_dict.items() for path in video_keys_paths_dict[video_key]]
|
||||
worker_args = [
|
||||
(path, encoder, encoder_threads)
|
||||
for video_key, encoder in video_keys_encoders_dict.items()
|
||||
for path in video_keys_paths_dict[video_key]
|
||||
]
|
||||
if num_workers and num_workers > 1:
|
||||
with ProcessPoolExecutor(max_workers=num_workers) as pool:
|
||||
futures = [pool.submit(_reencode_video_worker, args) for args in worker_args]
|
||||
@@ -1989,9 +1997,7 @@ def reencode_dataset(
|
||||
depth_preserve_keys = {"is_depth_map", *(f"video.{n}" for n in DEPTH_ENCODER_INFO_FIELD_NAMES)}
|
||||
for video_key, encoder in video_keys_encoders_dict.items():
|
||||
preserve_keys = depth_preserve_keys if video_key in meta.depth_keys else None
|
||||
meta.update_video_info(
|
||||
video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys
|
||||
)
|
||||
meta.update_video_info(video_key=video_key, video_encoder=encoder, preserve_keys=preserve_keys)
|
||||
|
||||
write_info(meta.info, meta.root)
|
||||
logging.info("Dataset metadata updated.")
|
||||
|
||||
@@ -303,7 +303,9 @@ class DatasetWriter:
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
k: v
|
||||
if k == "count"
|
||||
else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
|
||||
@@ -207,7 +207,6 @@ def dequantize_depth(
|
||||
|
||||
# ── Torch path: stay on the input device, single fp32 allocation. ────────
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
|
||||
if quantized.ndim >= 3:
|
||||
# Drop the single-channel dimension so the math runs on (..., H, W).
|
||||
quantized = quantized.squeeze(-3) if quantized.shape[-3] == 1 else quantized.squeeze(-1)
|
||||
|
||||
@@ -82,17 +82,23 @@ def decode_video_frames(
|
||||
if backend != "pyav" and is_depth:
|
||||
logger.warning("Decoding depth maps is only supported with the 'pyav' backend.")
|
||||
# We do not actually return uint8 here, but we avoid the 255 normalization step.
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=False, is_depth=True
|
||||
)
|
||||
|
||||
if backend is None:
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend == "pyav":
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
elif backend == "video_reader":
|
||||
logger.warning("backend='video_reader' is deprecated and now aliases to 'pyav'.")
|
||||
return decode_video_frames_pyav(video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth)
|
||||
return decode_video_frames_pyav(
|
||||
video_path, timestamps, tolerance_s, return_uint8=return_uint8, is_depth=is_depth
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
@@ -122,7 +128,7 @@ def decode_video_frames_pyav(
|
||||
tolerance_s: Allowed deviation in seconds between a queried timestamp and the closest
|
||||
decoded frame.
|
||||
log_loaded_timestamps: When True, log every decoded frame's timestamp at INFO level.
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
return_uint8: For RGB videos, if True return raw uint8 frames (C, H, W).
|
||||
Otherwise, return float32 in [0, 1] range.
|
||||
is_depth: Set to True if the video is a depth map (1 channel, uint12).
|
||||
|
||||
@@ -477,13 +483,13 @@ def encode_video_frames(
|
||||
with Image.open(input_data) as input_image:
|
||||
if is_depth:
|
||||
input_frame = quantize_depth(
|
||||
np.array(input_image),
|
||||
np.array(input_image),
|
||||
depth_min=video_encoder.depth_min,
|
||||
depth_max=video_encoder.depth_max,
|
||||
shift=video_encoder.shift,
|
||||
use_log=video_encoder.use_log,
|
||||
pix_fmt=video_encoder.pix_fmt,
|
||||
video_backend="pyav"
|
||||
video_backend="pyav",
|
||||
)
|
||||
else:
|
||||
input_image = input_image.convert("RGB")
|
||||
|
||||
@@ -30,12 +30,11 @@ from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.feature_utils import features_equal_for_merge
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import (
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||
assert aggr_ds.num_episodes == expected_episodes, (
|
||||
@@ -212,9 +211,7 @@ def assert_depth_keys_preserved(aggr_ds, ds_0, ds_1):
|
||||
)
|
||||
for key in expected_depth_keys:
|
||||
info = aggr_ds.meta.info.features[key].get("info") or {}
|
||||
assert info.get("is_depth_map") is True, (
|
||||
f"Depth marker lost on feature {key!r} after aggregation"
|
||||
)
|
||||
assert info.get("is_depth_map") is True, f"Depth marker lost on feature {key!r} after aggregation"
|
||||
|
||||
|
||||
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||
|
||||
@@ -1524,7 +1524,7 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264_qsv" in VALID_VIDEO_CODECS
|
||||
assert "hevc_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "hevc_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -26,15 +26,16 @@ from lerobot.datasets.image_writer import image_array_to_pil_image, write_image
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_CHW,
|
||||
DUMMY_DEPTH_CAMERA_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_CAMERA_FEATURES_WITH_DEPTH,
|
||||
DUMMY_CHW
|
||||
)
|
||||
from tests.fixtures.dataset_factories import add_frames
|
||||
|
||||
_, H, W = DUMMY_CHW
|
||||
|
||||
|
||||
def _depth_metres_ramp() -> np.ndarray:
|
||||
"""Linearly-spaced float32 depth in metres covering the default range."""
|
||||
return np.linspace(DEFAULT_DEPTH_MIN, DEFAULT_DEPTH_MAX, H * W, dtype=np.float32).reshape(H, W)
|
||||
@@ -98,9 +99,7 @@ class TestQuantizeDequantize:
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
np.testing.assert_allclose(
|
||||
out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol
|
||||
)
|
||||
np.testing.assert_allclose(out.cpu().numpy().astype(np.float64), ref.astype(np.float64), atol=atol)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
@@ -131,9 +130,7 @@ class TestQuantizeDequantize:
|
||||
|
||||
def test_invalid_log_params_raises(self):
|
||||
with pytest.raises(ValueError, match=r"depth_min \+ shift must be positive"):
|
||||
quantize_depth(
|
||||
_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None
|
||||
)
|
||||
quantize_depth(_depth_metres_ramp(), depth_min=1.0, shift=-2.0, use_log=True, video_backend=None)
|
||||
|
||||
|
||||
# ── 2. Image writer depth support ─────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user