Merge branch 'main' into chore/bump_transformers_v5

This commit is contained in:
Steven Palma
2026-03-02 15:50:04 +01:00
committed by GitHub
23 changed files with 159 additions and 42 deletions
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19eaaa85f66ba4aa6388dbb83819ffad6ea4363247208f871a8dc385689f6fc8
oid sha256:54aecbc1af72a4cd5e9261492f5e7601890517516257aacdf2a0ffb3ce281f1b
size 992
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:227296eaeeb54acdc3dae2eb8af3d4d08fb87e245337624447140b1e91cfd002
oid sha256:88a9c3775a2aa1e90a08850521970070a4fcf0f6b82aab43cd8ccc5cf77e0013
size 47424
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:271b00cb2f0cd5fd26b1d53463638e3d1a6e92692ec625fcffb420ca190869e5
oid sha256:91a2635e05a75fe187a5081504c5f35ce3417378813fa2deaf9ca4e8200e1819
size 68
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:778fddbbaa64248cee35cb377c02cc2b6076f7ce5855146de677128900617ddf
oid sha256:645bff922ac7bea63ad018ebf77c303c0e4cd2c1c0dc5ef3192865281bef3dc6
size 47424
+36
View File
@@ -24,6 +24,11 @@ def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
class MockAccelerator:
def __init__(self, num_processes: int):
self.num_processes = num_processes
def test_average_meter_initialization():
meter = AverageMeter("loss", ":.2f")
assert meter.name == "loss"
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
accelerator=MockAccelerator(num_processes=2),
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32 * 2
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_step_with_accelerator(mock_metrics):
tracker = MetricsTracker(
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
accelerator=MockAccelerator(num_processes=2),
)
tracker.step()
assert tracker.steps == 6
assert tracker.samples == (5 * 32 * 2) + (32 * 2)
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
assert tracker.loss == mock_metrics["loss"]