From 4eb7694d471e96a8abcd4a75121fc49722144ae7 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Fri, 31 Oct 2025 18:39:40 +0100 Subject: [PATCH] test(rerun audio): adding tests for audio visualization with rerun --- tests/utils/test_visualization_utils.py | 123 +++++++++++++++++++----- 1 file changed, 101 insertions(+), 22 deletions(-) diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 408f636cb..7c855d037 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -37,6 +37,14 @@ def mock_rerun(monkeypatch): def __init__(self, value): self.value = float(value) + @staticmethod + def columns(scalars): + return DummyScalarsColumn(scalars) + + class DummyScalarsColumn: + def __init__(self, values): + self.values = values + class DummyImage: def __init__(self, arr): self.arr = arr @@ -47,12 +55,46 @@ def mock_rerun(monkeypatch): obj = kwargs.pop("entity") calls.append((key, obj, kwargs)) + def dummy_send_columns(key, indexes, columns, **kwargs): + calls.append((key, columns, kwargs)) + + def dummy_time_column(timeline, timestamp): + return timestamp + + def dummy_set_time(timeline, timestamp): + return None + + class DummyTimeSeriesView: + def __call__(self, origin, plot_legend=None): + return None + + class DummySpatial2DView: + def __call__(self, origin): + return None + + class DummyGrid: + def __call__(self, *args): + return None + + class DummyPlotLegend: + def __call__(self, visible=True): + return None + dummy_rr = SimpleNamespace( Scalars=DummyScalar, Image=DummyImage, log=dummy_log, + TimeColumn=dummy_time_column, + send_columns=dummy_send_columns, + set_time=dummy_set_time, init=lambda *a, **k: None, spawn=lambda *a, **k: None, + blueprint=SimpleNamespace( + TimeSeriesView=DummyTimeSeriesView, + Spatial2DView=DummySpatial2DView, + Grid=DummyGrid, + PlotLegend=DummyPlotLegend, + ), ) # Inject fake module into sys.modules @@ -87,7 +129,7 @@ def _kwargs_for(calls, key): raise KeyError(f"Key {key} not found in calls: {calls}") -def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): +def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun): vu, calls = mock_rerun # Build EnvTransition dict @@ -95,6 +137,8 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): f"{OBS_STATE}.temperature": np.float32(25.0), # CHW image should be converted to HWC for rr.Image "observation.camera": np.zeros((3, 10, 20), dtype=np.uint8), + # Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns + "observation.audio": np.zeros((100, 2), dtype=np.float32), } act = { "action.throttle": 0.7, @@ -117,25 +161,27 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): # - action.throttle -> Scalars # - action.vector_0, action.vector_1 -> Scalars expected_keys = { - f"{OBS_STATE}.temperature", + "data/" + f"{OBS_STATE}.temperature", "observation.camera", - "action.throttle", - "action.vector_0", - "action.vector_1", + "data/action.throttle", + "data/action.vector_0", + "data/action.vector_1", + "audio/observation.audio_channel_0", + "audio/observation.audio_channel_1", } assert set(_keys(calls)) == expected_keys # Check scalar types and values - temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature") + temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature") assert type(temp_obj).__name__ == "DummyScalar" assert temp_obj.value == pytest.approx(25.0) - throttle_obj = _obj_for(calls, "action.throttle") + throttle_obj = _obj_for(calls, "data/action.throttle") assert type(throttle_obj).__name__ == "DummyScalar" assert throttle_obj.value == pytest.approx(0.7) - v0 = _obj_for(calls, "action.vector_0") - v1 = _obj_for(calls, "action.vector_1") + v0 = _obj_for(calls, "data/action.vector_0") + v1 = _obj_for(calls, "data/action.vector_1") assert type(v0).__name__ == "DummyScalar" assert type(v1).__name__ == "DummyScalar" assert v0.value == pytest.approx(1.0) @@ -147,6 +193,14 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun): assert img_obj.arr.shape == (10, 20, 3) # transposed assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images + # Check audio handling: split channels + rr.Scalars.columns + audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0") + audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1") + assert type(audio_obj_0).__name__ == "DummyScalarsColumn" + assert type(audio_obj_1).__name__ == "DummyScalarsColumn" + assert audio_obj_0.values.shape == (100,) + assert audio_obj_1.values.shape == (100,) + def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): vu, calls = mock_rerun @@ -157,6 +211,8 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): "temp": 1.5, # Already HWC image => should stay as-is "img": np.zeros((5, 6, 3), dtype=np.uint8), + # Multiple channels audio data should be split into separate channels + "audio": np.zeros((100, 2), dtype=np.float32), "none": None, # should be skipped } act_plain = { @@ -170,22 +226,24 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): # Expected keys with auto-prefixes expected = { - "observation.temp", + "data/observation.temp", "observation.img", - "action.throttle", - "action.vec_0", - "action.vec_1", - "action.vec_2", + "data/action.throttle", + "data/action.vec_0", + "data/action.vec_1", + "data/action.vec_2", + "audio/observation.audio_channel_0", + "audio/observation.audio_channel_1", } logged = set(_keys(calls)) assert logged == expected # Scalars - t = _obj_for(calls, "observation.temp") + t = _obj_for(calls, "data/observation.temp") assert type(t).__name__ == "DummyScalar" assert t.value == pytest.approx(1.5) - throttle = _obj_for(calls, "action.throttle") + throttle = _obj_for(calls, "data/action.throttle") assert type(throttle).__name__ == "DummyScalar" assert throttle.value == pytest.approx(0.3) @@ -197,25 +255,39 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun): # Vectors for i, val in enumerate([9, 8, 7]): - o = _obj_for(calls, f"action.vec_{i}") + o = _obj_for(calls, f"data/action.vec_{i}") assert type(o).__name__ == "DummyScalar" assert o.value == pytest.approx(val) + # Audio + audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0") + audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1") + assert type(audio_obj_0).__name__ == "DummyScalarsColumn" + assert type(audio_obj_1).__name__ == "DummyScalarsColumn" + assert audio_obj_0.values.shape == (100,) + assert audio_obj_1.values.shape == (100,) + def test_log_rerun_data_kwargs_only(mock_rerun): vu, calls = mock_rerun vu.log_rerun_data( - observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)}, + observation={ + "observation.temp": 10.0, + "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8), + "observation.audio": np.zeros((100, 2), dtype=np.float32), + }, action={"action.a": 1.0}, ) keys = set(_keys(calls)) - assert "observation.temp" in keys + assert "data/observation.temp" in keys assert "observation.gray" in keys - assert "action.a" in keys + assert "data/action.a" in keys + assert "audio/observation.audio_channel_0" in keys + assert "audio/observation.audio_channel_1" in keys - temp = _obj_for(calls, "observation.temp") + temp = _obj_for(calls, "data/observation.temp") assert type(temp).__name__ == "DummyScalar" assert temp.value == pytest.approx(10.0) @@ -224,6 +296,13 @@ def test_log_rerun_data_kwargs_only(mock_rerun): assert img.arr.shape == (8, 8, 1) # remains HWC assert _kwargs_for(calls, "observation.gray").get("static", False) is True - a = _obj_for(calls, "action.a") + a = _obj_for(calls, "data/action.a") assert type(a).__name__ == "DummyScalar" assert a.value == pytest.approx(1.0) + + audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0") + audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1") + assert type(audio_obj_0).__name__ == "DummyScalarsColumn" + assert type(audio_obj_1).__name__ == "DummyScalarsColumn" + assert audio_obj_0.values.shape == (100,) + assert audio_obj_1.values.shape == (100,)