mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 17:50:09 +00:00
test(rerun audio): adding tests for audio visualization with rerun
This commit is contained in:
@@ -37,6 +37,14 @@ def mock_rerun(monkeypatch):
|
||||
def __init__(self, value):
|
||||
self.value = float(value)
|
||||
|
||||
@staticmethod
|
||||
def columns(scalars):
|
||||
return DummyScalarsColumn(scalars)
|
||||
|
||||
class DummyScalarsColumn:
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
class DummyImage:
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
@@ -47,12 +55,46 @@ def mock_rerun(monkeypatch):
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
def dummy_send_columns(key, indexes, columns, **kwargs):
|
||||
calls.append((key, columns, kwargs))
|
||||
|
||||
def dummy_time_column(timeline, timestamp):
|
||||
return timestamp
|
||||
|
||||
def dummy_set_time(timeline, timestamp):
|
||||
return None
|
||||
|
||||
class DummyTimeSeriesView:
|
||||
def __call__(self, origin, plot_legend=None):
|
||||
return None
|
||||
|
||||
class DummySpatial2DView:
|
||||
def __call__(self, origin):
|
||||
return None
|
||||
|
||||
class DummyGrid:
|
||||
def __call__(self, *args):
|
||||
return None
|
||||
|
||||
class DummyPlotLegend:
|
||||
def __call__(self, visible=True):
|
||||
return None
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
TimeColumn=dummy_time_column,
|
||||
send_columns=dummy_send_columns,
|
||||
set_time=dummy_set_time,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
blueprint=SimpleNamespace(
|
||||
TimeSeriesView=DummyTimeSeriesView,
|
||||
Spatial2DView=DummySpatial2DView,
|
||||
Grid=DummyGrid,
|
||||
PlotLegend=DummyPlotLegend,
|
||||
),
|
||||
)
|
||||
|
||||
# Inject fake module into sys.modules
|
||||
@@ -87,7 +129,7 @@ def _kwargs_for(calls, key):
|
||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||
|
||||
|
||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
|
||||
# Build EnvTransition dict
|
||||
@@ -95,6 +137,8 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
f"{OBS_STATE}.temperature": np.float32(25.0),
|
||||
# CHW image should be converted to HWC for rr.Image
|
||||
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
||||
# Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns
|
||||
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||
}
|
||||
act = {
|
||||
"action.throttle": 0.7,
|
||||
@@ -117,25 +161,27 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
# - action.throttle -> Scalars
|
||||
# - action.vector_0, action.vector_1 -> Scalars
|
||||
expected_keys = {
|
||||
f"{OBS_STATE}.temperature",
|
||||
"data/" + f"{OBS_STATE}.temperature",
|
||||
"observation.camera",
|
||||
"action.throttle",
|
||||
"action.vector_0",
|
||||
"action.vector_1",
|
||||
"data/action.throttle",
|
||||
"data/action.vector_0",
|
||||
"data/action.vector_1",
|
||||
"audio/observation.audio_channel_0",
|
||||
"audio/observation.audio_channel_1",
|
||||
}
|
||||
assert set(_keys(calls)) == expected_keys
|
||||
|
||||
# Check scalar types and values
|
||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
||||
temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature")
|
||||
assert type(temp_obj).__name__ == "DummyScalar"
|
||||
assert temp_obj.value == pytest.approx(25.0)
|
||||
|
||||
throttle_obj = _obj_for(calls, "action.throttle")
|
||||
throttle_obj = _obj_for(calls, "data/action.throttle")
|
||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||
assert throttle_obj.value == pytest.approx(0.7)
|
||||
|
||||
v0 = _obj_for(calls, "action.vector_0")
|
||||
v1 = _obj_for(calls, "action.vector_1")
|
||||
v0 = _obj_for(calls, "data/action.vector_0")
|
||||
v1 = _obj_for(calls, "data/action.vector_1")
|
||||
assert type(v0).__name__ == "DummyScalar"
|
||||
assert type(v1).__name__ == "DummyScalar"
|
||||
assert v0.value == pytest.approx(1.0)
|
||||
@@ -147,6 +193,14 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||
|
||||
# Check audio handling: split channels + rr.Scalars.columns
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
|
||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
@@ -157,6 +211,8 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
"temp": 1.5,
|
||||
# Already HWC image => should stay as-is
|
||||
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
||||
# Multiple channels audio data should be split into separate channels
|
||||
"audio": np.zeros((100, 2), dtype=np.float32),
|
||||
"none": None, # should be skipped
|
||||
}
|
||||
act_plain = {
|
||||
@@ -170,22 +226,24 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
|
||||
# Expected keys with auto-prefixes
|
||||
expected = {
|
||||
"observation.temp",
|
||||
"data/observation.temp",
|
||||
"observation.img",
|
||||
"action.throttle",
|
||||
"action.vec_0",
|
||||
"action.vec_1",
|
||||
"action.vec_2",
|
||||
"data/action.throttle",
|
||||
"data/action.vec_0",
|
||||
"data/action.vec_1",
|
||||
"data/action.vec_2",
|
||||
"audio/observation.audio_channel_0",
|
||||
"audio/observation.audio_channel_1",
|
||||
}
|
||||
logged = set(_keys(calls))
|
||||
assert logged == expected
|
||||
|
||||
# Scalars
|
||||
t = _obj_for(calls, "observation.temp")
|
||||
t = _obj_for(calls, "data/observation.temp")
|
||||
assert type(t).__name__ == "DummyScalar"
|
||||
assert t.value == pytest.approx(1.5)
|
||||
|
||||
throttle = _obj_for(calls, "action.throttle")
|
||||
throttle = _obj_for(calls, "data/action.throttle")
|
||||
assert type(throttle).__name__ == "DummyScalar"
|
||||
assert throttle.value == pytest.approx(0.3)
|
||||
|
||||
@@ -197,25 +255,39 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
|
||||
# Vectors
|
||||
for i, val in enumerate([9, 8, 7]):
|
||||
o = _obj_for(calls, f"action.vec_{i}")
|
||||
o = _obj_for(calls, f"data/action.vec_{i}")
|
||||
assert type(o).__name__ == "DummyScalar"
|
||||
assert o.value == pytest.approx(val)
|
||||
|
||||
# Audio
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
|
||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
|
||||
vu.log_rerun_data(
|
||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
||||
observation={
|
||||
"observation.temp": 10.0,
|
||||
"observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
|
||||
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||
},
|
||||
action={"action.a": 1.0},
|
||||
)
|
||||
|
||||
keys = set(_keys(calls))
|
||||
assert "observation.temp" in keys
|
||||
assert "data/observation.temp" in keys
|
||||
assert "observation.gray" in keys
|
||||
assert "action.a" in keys
|
||||
assert "data/action.a" in keys
|
||||
assert "audio/observation.audio_channel_0" in keys
|
||||
assert "audio/observation.audio_channel_1" in keys
|
||||
|
||||
temp = _obj_for(calls, "observation.temp")
|
||||
temp = _obj_for(calls, "data/observation.temp")
|
||||
assert type(temp).__name__ == "DummyScalar"
|
||||
assert temp.value == pytest.approx(10.0)
|
||||
|
||||
@@ -224,6 +296,13 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
assert img.arr.shape == (8, 8, 1) # remains HWC
|
||||
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
||||
|
||||
a = _obj_for(calls, "action.a")
|
||||
a = _obj_for(calls, "data/action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
Reference in New Issue
Block a user