mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
fix(from_video_info): fixing early validation issue in from_video_info
This commit is contained in:
@@ -19,8 +19,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any, ClassVar
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar, Self
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
@@ -112,9 +112,9 @@ class VideoEncoderConfig:
|
||||
self.validate()
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> VideoEncoderConfig:
|
||||
"""Reconstruct a :class:`VideoEncoderConfig` from a video feature's ``info`` block.
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Parse the ``video.*`` keys of a feature ``info`` block into
|
||||
constructor kwargs.
|
||||
"""
|
||||
video_info = video_info or {}
|
||||
kwargs: dict[str, Any] = {}
|
||||
@@ -133,7 +133,15 @@ class VideoEncoderConfig:
|
||||
continue
|
||||
kwargs[field_name] = value
|
||||
|
||||
return cls(**kwargs)
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> Self:
|
||||
"""Reconstruct an encoder config from a video feature's ``info`` block.
|
||||
|
||||
Missing or ``None`` values fall back to the class defaults.
|
||||
"""
|
||||
return cls(**cls._kwargs_from_video_info(video_info))
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of available encoders based on the specified video backend.
|
||||
@@ -291,23 +299,18 @@ class DepthEncoderConfig(VideoEncoderConfig):
|
||||
_DEFAULT_CHANNELS: ClassVar[int] = 1
|
||||
|
||||
@classmethod
|
||||
def from_video_info(cls, video_info: dict | None) -> DepthEncoderConfig:
|
||||
"""Reconstruct a :class:`DepthEncoderConfig` from a depth feature's ``info`` block.
|
||||
|
||||
Reuses :meth:`VideoEncoderConfig.from_video_info` for the base
|
||||
codec/tuning fields and then layers the depth-specific tuning
|
||||
(``depth_min`` / ``depth_max`` / ``shift`` / ``use_log``) on top.
|
||||
Missing keys fall back to the class defaults.
|
||||
def _kwargs_from_video_info(cls, video_info: dict | None) -> dict[str, Any]:
|
||||
"""Layer the depth-specific tuning (``depth_min`` / ``depth_max`` /
|
||||
``shift`` / ``use_log``) on top of the base parser. Missing keys
|
||||
fall back to the class defaults.
|
||||
"""
|
||||
base = VideoEncoderConfig.from_video_info(video_info)
|
||||
kwargs: dict[str, Any] = {f.name: getattr(base, f.name) for f in fields(base) if f.init}
|
||||
|
||||
kwargs = super()._kwargs_from_video_info(video_info)
|
||||
video_info = video_info or {}
|
||||
for name in DEPTH_ENCODER_INFO_FIELD_NAMES:
|
||||
value = video_info.get(f"video.{name}")
|
||||
if value is not None:
|
||||
kwargs[name] = value
|
||||
return cls(**kwargs)
|
||||
return kwargs
|
||||
|
||||
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
|
||||
Reference in New Issue
Block a user