test(rerun): update tests

This commit is contained in:
CarolinePascal
2026-06-10 15:12:58 +02:00
parent c4b5ef8eaf
commit 854bc8d48f
+33 -19
View File
@@ -37,7 +37,14 @@ def mock_rerun(monkeypatch):
class DummyScalar:
def __init__(self, value):
self.value = float(value)
arr = np.asarray(value, dtype=float)
# Keep a flat list of values plus a convenience `value` for scalar inputs.
self.values = arr.reshape(-1).tolist()
self.value = self.values[0] if arr.ndim == 0 else None
class DummySeriesLines:
def __init__(self, names=None):
self.names = names
class DummyImage:
def __init__(self, arr):
@@ -54,6 +61,7 @@ def mock_rerun(monkeypatch):
__package__="rerun",
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
Scalars=DummyScalar,
SeriesLines=DummySeriesLines,
Image=DummyImage,
log=dummy_log,
init=lambda *a, **k: None,
@@ -85,6 +93,14 @@ def _obj_for(calls, key):
raise KeyError(f"Key {key} not found in calls: {calls}")
def _obj_for_type(calls, key, type_name):
"""Find the first object of a given type name logged under a given key."""
for k, obj, _kw in calls:
if k == key and type(obj).__name__ == type_name:
return obj
raise KeyError(f"Key {key} with type {type_name} not found in calls: {calls}")
def _kwargs_for(calls, key):
for k, _obj, kw in calls:
if k == key:
@@ -103,7 +119,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
}
act = {
"action.throttle": 0.7,
# 1D array should log individual Scalars with suffix _i
# 1D array should log a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
@@ -120,13 +136,12 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector_0, action.vector_1 -> Scalars
# - action.vector -> single Scalars batch under one entity path
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector_0",
"action.vector_1",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
@@ -139,12 +154,13 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
assert type(throttle_obj).__name__ == "DummyScalar"
assert throttle_obj.value == pytest.approx(0.7)
v0 = _obj_for(calls, "action.vector_0")
v1 = _obj_for(calls, "action.vector_1")
assert type(v0).__name__ == "DummyScalar"
assert type(v1).__name__ == "DummyScalar"
assert v0.value == pytest.approx(1.0)
assert v1.value == pytest.approx(2.0)
# The full vector is logged as one batch so all dimensions share a single view
vector_obj = _obj_for_type(calls, "action.vector", "DummyScalar")
assert vector_obj.values == pytest.approx([1.0, 2.0])
# Series keep their `{key}_{i}` names via SeriesLines
vector_names = _obj_for_type(calls, "action.vector", "DummySeriesLines")
assert vector_names.names == ["action.vector_0", "action.vector_1"]
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
@@ -178,9 +194,7 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
"observation.temp",
"observation.img",
"action.throttle",
"action.vec_0",
"action.vec_1",
"action.vec_2",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
@@ -200,11 +214,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vectors
for i, val in enumerate([9, 8, 7]):
o = _obj_for(calls, f"action.vec_{i}")
assert type(o).__name__ == "DummyScalar"
assert o.value == pytest.approx(val)
# Vector logged as a single batch under one entity path, keeping per-dimension names
vec = _obj_for_type(calls, "action.vec", "DummyScalar")
assert vec.values == pytest.approx([9, 8, 7])
vec_names = _obj_for_type(calls, "action.vec", "DummySeriesLines")
assert vec_names.names == ["action.vec_0", "action.vec_1", "action.vec_2"]
def test_log_rerun_data_kwargs_only(mock_rerun):