mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
chore(format): formatting code
This commit is contained in:
@@ -90,16 +90,16 @@ class TestQuantizeDequantize:
|
||||
@pytest.mark.parametrize("output_unit", ["m", "mm"])
|
||||
def test_numpy_torch_agree(self, use_log, output_unit):
|
||||
"""Batched torch path produces the same values as the numpy path."""
|
||||
T = 3
|
||||
batch_size = 3
|
||||
per_frame = np.linspace(0, DEPTH_QMAX, H * W, dtype=np.uint16).reshape(H, W)
|
||||
batch_np = np.broadcast_to(per_frame[None, None, ...], (T, 1, H, W)).copy()
|
||||
batch_np = np.broadcast_to(per_frame[None, None, ...], (batch_size, 1, H, W)).copy()
|
||||
batch_t = torch.from_numpy(batch_np.astype(np.int32)) # torch.uint16 support is patchy.
|
||||
|
||||
ref = dequantize_depth(batch_np, use_log=use_log, output_unit=output_unit, output_tensor=False)
|
||||
out = dequantize_depth(batch_t, use_log=use_log, output_unit=output_unit, output_tensor=True)
|
||||
|
||||
assert isinstance(out, torch.Tensor)
|
||||
assert out.shape == (T, 1, H, W)
|
||||
assert out.shape == (batch_size, 1, H, W)
|
||||
# ``m``: float32 noise (~10 µm in log mode, after ``exp``) — still 200× below the ~2 mm quant step.
|
||||
# ``mm`` + tensor stays in float32 (no uint16 round-trip), so allow 1 mm slop.
|
||||
atol = 1e-5 if output_unit == "m" else 1.0
|
||||
|
||||
@@ -66,7 +66,7 @@ require_qsv = _require_encoder("h264_qsv")
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "encoded_videos"
|
||||
|
||||
|
||||
def _write_RGB_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
def _write_color_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
@@ -107,7 +107,7 @@ def _encode_video(
|
||||
_write_depth_frames(imgs_dir, num_frames=num_frames)
|
||||
cfg = cfg or DepthEncoderConfig()
|
||||
else:
|
||||
_write_RGB_frames(imgs_dir, num_frames=num_frames)
|
||||
_write_color_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, video_encoder=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
@@ -449,7 +449,7 @@ class TestEncodeVideoFrames:
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_RGB_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing content"
|
||||
video_path.write_bytes(sentinel)
|
||||
@@ -461,7 +461,7 @@ class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_RGB_frames(imgs_dir)
|
||||
_write_color_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user