diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 63ff76c77..760bef1bd 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -37,7 +37,14 @@ def mock_rerun(monkeypatch): class DummyScalar: def __init__(self, value): - self.value = float(value) + arr = np.asarray(value, dtype=float) + # Keep a flat list of values plus a convenience `value` for scalar inputs. + self.values = arr.reshape(-1).tolist() + self.value = self.values[0] if arr.ndim == 0 else None + + class DummySeriesLines: + def __init__(self, names=None): + self.names = names class DummyImage: def __init__(self, arr): @@ -54,6 +61,7 @@ def mock_rerun(monkeypatch): __package__="rerun", __spec__=SimpleNamespace(name="rerun", submodule_search_locations=None), Scalars=DummyScalar, + SeriesLines=DummySeriesLines, Image=DummyImage, log=dummy_log, init=lambda *a, **k: None, @@ -85,6 +93,14 @@ def _obj_for(calls, key): raise KeyError(f"Key {key} not found in calls: {calls}") +def _obj_for_type(calls, key, type_name): + """Find the first object of a given type name logged under a given key.""" + for k, obj, _kw in calls: + if k == key and type(obj).__name__ == type_name: + return obj + raise KeyError(f"Key {key} with type {type_name} not found in calls: {calls}") + + def _kwargs_for(calls, key): for k, _obj, kw in calls: if k == key: @@ -103,7 +119,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): } act = { "action.throttle": 0.7, - # 1D array should log individual Scalars with suffix _i + # 1D array should log a single Scalars batch under one entity path "action.vector": np.array([1.0, 2.0], dtype=np.float32), } transition = { @@ -120,13 +136,12 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # - observation.state.temperature -> Scalars # - observation.camera -> Image (HWC) with static=True # - action.throttle -> Scalars - # - action.vector_0, action.vector_1 -> Scalars + # - action.vector -> single Scalars batch under one entity path expected_keys = { f"{OBS_STATE}.temperature", "observation.camera", "action.throttle", - "action.vector_0", - "action.vector_1", + "action.vector", } assert set(_keys(calls)) == expected_keys @@ -139,12 +154,13 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): assert type(throttle_obj).__name__ == "DummyScalar" assert throttle_obj.value == pytest.approx(0.7) - v0 = _obj_for(calls, "action.vector_0") - v1 = _obj_for(calls, "action.vector_1") - assert type(v0).__name__ == "DummyScalar" - assert type(v1).__name__ == "DummyScalar" - assert v0.value == pytest.approx(1.0) - assert v1.value == pytest.approx(2.0) + # The full vector is logged as one batch so all dimensions share a single view + vector_obj = _obj_for_type(calls, "action.vector", "DummyScalar") + assert vector_obj.values == pytest.approx([1.0, 2.0]) + + # Series keep their `{key}_{i}` names via SeriesLines + vector_names = _obj_for_type(calls, "action.vector", "DummySeriesLines") + assert vector_names.names == ["action.vector_0", "action.vector_1"] # Check image handling: CHW -> HWC img_obj = _obj_for(calls, "observation.camera") @@ -178,9 +194,7 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): "observation.temp", "observation.img", "action.throttle", - "action.vec_0", - "action.vec_1", - "action.vec_2", + "action.vec", } logged = set(_keys(calls)) assert logged == expected @@ -200,11 +214,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): assert img.arr.shape == (5, 6, 3) assert _kwargs_for(calls, "observation.img").get("static", False) is True - # Vectors - for i, val in enumerate([9, 8, 7]): - o = _obj_for(calls, f"action.vec_{i}") - assert type(o).__name__ == "DummyScalar" - assert o.value == pytest.approx(val) + # Vector logged as a single batch under one entity path, keeping per-dimension names + vec = _obj_for_type(calls, "action.vec", "DummyScalar") + assert vec.values == pytest.approx([9, 8, 7]) + vec_names = _obj_for_type(calls, "action.vec", "DummySeriesLines") + assert vec_names.names == ["action.vec_0", "action.vec_1", "action.vec_2"] def test_log_rerun_data_kwargs_only(mock_rerun):