refactor: several fixes

This commit is contained in:
Steven Palma
2026-04-10 15:35:31 +02:00
parent e2381633cd
commit 882a6b0965
21 changed files with 407 additions and 356 deletions
+2 -1
View File
@@ -70,10 +70,11 @@ def is_package_available(
def get_safe_default_codec():
logger = logging.getLogger(__name__)
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
logger.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
+15 -4
View File
@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
@@ -58,10 +61,18 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
import av
with av.open(str(video_path), mode="w") as container:
height, width = stacked_frames[0].shape[:2]
# Ensure dimensions are even for yuv420p compatibility
height = height if height % 2 == 0 else height - 1
width = width if width % 2 == 0 else width - 1
orig_height, orig_width = stacked_frames[0].shape[:2]
# yuv420p requires even dimensions; crop by one pixel if needed
height = orig_height if orig_height % 2 == 0 else orig_height - 1
width = orig_width if orig_width % 2 == 0 else orig_width - 1
if height != orig_height or width != orig_width:
logger.warning(
"Frame dimensions %dx%d are not even; cropping to %dx%d for yuv420p compatibility.",
orig_width,
orig_height,
width,
height,
)
stream = container.add_stream("libx264", rate=fps)
stream.width = width
stream.height = height
+23 -1
View File
@@ -22,11 +22,12 @@ import select
import subprocess
import sys
import time
from collections.abc import Iterator
from copy import copy, deepcopy
from datetime import datetime
from pathlib import Path
from statistics import mean
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import numpy as np
@@ -252,6 +253,27 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
return outdict
def cycle(iterable: Any) -> Iterator[Any]:
"""Create a dataloader-safe cyclical iterator.
This is an equivalent of `itertools.cycle` but is safe for use with
PyTorch DataLoaders with multiple workers.
See https://github.com/pytorch/pytorch/issues/23900 for details.
Args:
iterable: The iterable to cycle over.
Yields:
Items from the iterable, restarting from the beginning when exhausted.
"""
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)
class SuppressProgressBars:
"""
Context manager to suppress progress bars.