mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-29 15:39:56 +00:00
refactor: several fixes
This commit is contained in:
@@ -70,10 +70,11 @@ def is_package_available(
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
logger = logging.getLogger(__name__)
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
|
||||
@@ -14,9 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
|
||||
|
||||
@@ -58,10 +61,18 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
||||
import av
|
||||
|
||||
with av.open(str(video_path), mode="w") as container:
|
||||
height, width = stacked_frames[0].shape[:2]
|
||||
# Ensure dimensions are even for yuv420p compatibility
|
||||
height = height if height % 2 == 0 else height - 1
|
||||
width = width if width % 2 == 0 else width - 1
|
||||
orig_height, orig_width = stacked_frames[0].shape[:2]
|
||||
# yuv420p requires even dimensions; crop by one pixel if needed
|
||||
height = orig_height if orig_height % 2 == 0 else orig_height - 1
|
||||
width = orig_width if orig_width % 2 == 0 else orig_width - 1
|
||||
if height != orig_height or width != orig_width:
|
||||
logger.warning(
|
||||
"Frame dimensions %dx%d are not even; cropping to %dx%d for yuv420p compatibility.",
|
||||
orig_width,
|
||||
orig_height,
|
||||
width,
|
||||
height,
|
||||
)
|
||||
stream = container.add_stream("libx264", rate=fps)
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
|
||||
@@ -22,11 +22,12 @@ import select
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -252,6 +253,27 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def cycle(iterable: Any) -> Iterator[Any]:
|
||||
"""Create a dataloader-safe cyclical iterator.
|
||||
|
||||
This is an equivalent of `itertools.cycle` but is safe for use with
|
||||
PyTorch DataLoaders with multiple workers.
|
||||
See https://github.com/pytorch/pytorch/issues/23900 for details.
|
||||
|
||||
Args:
|
||||
iterable: The iterable to cycle over.
|
||||
|
||||
Yields:
|
||||
Items from the iterable, restarting from the beginning when exhausted.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
class SuppressProgressBars:
|
||||
"""
|
||||
Context manager to suppress progress bars.
|
||||
|
||||
Reference in New Issue
Block a user