mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
test(rerun audio): adding tests for audio visualization with rerun
This commit is contained in:
@@ -37,6 +37,14 @@ def mock_rerun(monkeypatch):
|
|||||||
def __init__(self, value):
|
def __init__(self, value):
|
||||||
self.value = float(value)
|
self.value = float(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def columns(scalars):
|
||||||
|
return DummyScalarsColumn(scalars)
|
||||||
|
|
||||||
|
class DummyScalarsColumn:
|
||||||
|
def __init__(self, values):
|
||||||
|
self.values = values
|
||||||
|
|
||||||
class DummyImage:
|
class DummyImage:
|
||||||
def __init__(self, arr):
|
def __init__(self, arr):
|
||||||
self.arr = arr
|
self.arr = arr
|
||||||
@@ -47,12 +55,46 @@ def mock_rerun(monkeypatch):
|
|||||||
obj = kwargs.pop("entity")
|
obj = kwargs.pop("entity")
|
||||||
calls.append((key, obj, kwargs))
|
calls.append((key, obj, kwargs))
|
||||||
|
|
||||||
|
def dummy_send_columns(key, indexes, columns, **kwargs):
|
||||||
|
calls.append((key, columns, kwargs))
|
||||||
|
|
||||||
|
def dummy_time_column(timeline, timestamp):
|
||||||
|
return timestamp
|
||||||
|
|
||||||
|
def dummy_set_time(timeline, timestamp):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyTimeSeriesView:
|
||||||
|
def __call__(self, origin, plot_legend=None):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummySpatial2DView:
|
||||||
|
def __call__(self, origin):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyGrid:
|
||||||
|
def __call__(self, *args):
|
||||||
|
return None
|
||||||
|
|
||||||
|
class DummyPlotLegend:
|
||||||
|
def __call__(self, visible=True):
|
||||||
|
return None
|
||||||
|
|
||||||
dummy_rr = SimpleNamespace(
|
dummy_rr = SimpleNamespace(
|
||||||
Scalars=DummyScalar,
|
Scalars=DummyScalar,
|
||||||
Image=DummyImage,
|
Image=DummyImage,
|
||||||
log=dummy_log,
|
log=dummy_log,
|
||||||
|
TimeColumn=dummy_time_column,
|
||||||
|
send_columns=dummy_send_columns,
|
||||||
|
set_time=dummy_set_time,
|
||||||
init=lambda *a, **k: None,
|
init=lambda *a, **k: None,
|
||||||
spawn=lambda *a, **k: None,
|
spawn=lambda *a, **k: None,
|
||||||
|
blueprint=SimpleNamespace(
|
||||||
|
TimeSeriesView=DummyTimeSeriesView,
|
||||||
|
Spatial2DView=DummySpatial2DView,
|
||||||
|
Grid=DummyGrid,
|
||||||
|
PlotLegend=DummyPlotLegend,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Inject fake module into sys.modules
|
# Inject fake module into sys.modules
|
||||||
@@ -87,7 +129,7 @@ def _kwargs_for(calls, key):
|
|||||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
|
|
||||||
# Build EnvTransition dict
|
# Build EnvTransition dict
|
||||||
@@ -95,6 +137,8 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
f"{OBS_STATE}.temperature": np.float32(25.0),
|
f"{OBS_STATE}.temperature": np.float32(25.0),
|
||||||
# CHW image should be converted to HWC for rr.Image
|
# CHW image should be converted to HWC for rr.Image
|
||||||
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
||||||
|
# Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns
|
||||||
|
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
}
|
}
|
||||||
act = {
|
act = {
|
||||||
"action.throttle": 0.7,
|
"action.throttle": 0.7,
|
||||||
@@ -117,25 +161,27 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
# - action.throttle -> Scalars
|
# - action.throttle -> Scalars
|
||||||
# - action.vector_0, action.vector_1 -> Scalars
|
# - action.vector_0, action.vector_1 -> Scalars
|
||||||
expected_keys = {
|
expected_keys = {
|
||||||
f"{OBS_STATE}.temperature",
|
"data/" + f"{OBS_STATE}.temperature",
|
||||||
"observation.camera",
|
"observation.camera",
|
||||||
"action.throttle",
|
"data/action.throttle",
|
||||||
"action.vector_0",
|
"data/action.vector_0",
|
||||||
"action.vector_1",
|
"data/action.vector_1",
|
||||||
|
"audio/observation.audio_channel_0",
|
||||||
|
"audio/observation.audio_channel_1",
|
||||||
}
|
}
|
||||||
assert set(_keys(calls)) == expected_keys
|
assert set(_keys(calls)) == expected_keys
|
||||||
|
|
||||||
# Check scalar types and values
|
# Check scalar types and values
|
||||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature")
|
||||||
assert type(temp_obj).__name__ == "DummyScalar"
|
assert type(temp_obj).__name__ == "DummyScalar"
|
||||||
assert temp_obj.value == pytest.approx(25.0)
|
assert temp_obj.value == pytest.approx(25.0)
|
||||||
|
|
||||||
throttle_obj = _obj_for(calls, "action.throttle")
|
throttle_obj = _obj_for(calls, "data/action.throttle")
|
||||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||||
assert throttle_obj.value == pytest.approx(0.7)
|
assert throttle_obj.value == pytest.approx(0.7)
|
||||||
|
|
||||||
v0 = _obj_for(calls, "action.vector_0")
|
v0 = _obj_for(calls, "data/action.vector_0")
|
||||||
v1 = _obj_for(calls, "action.vector_1")
|
v1 = _obj_for(calls, "data/action.vector_1")
|
||||||
assert type(v0).__name__ == "DummyScalar"
|
assert type(v0).__name__ == "DummyScalar"
|
||||||
assert type(v1).__name__ == "DummyScalar"
|
assert type(v1).__name__ == "DummyScalar"
|
||||||
assert v0.value == pytest.approx(1.0)
|
assert v0.value == pytest.approx(1.0)
|
||||||
@@ -147,6 +193,14 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
|||||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||||
|
|
||||||
|
# Check audio handling: split channels + rr.Scalars.columns
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
@@ -157,6 +211,8 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
"temp": 1.5,
|
"temp": 1.5,
|
||||||
# Already HWC image => should stay as-is
|
# Already HWC image => should stay as-is
|
||||||
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
||||||
|
# Multiple channels audio data should be split into separate channels
|
||||||
|
"audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
"none": None, # should be skipped
|
"none": None, # should be skipped
|
||||||
}
|
}
|
||||||
act_plain = {
|
act_plain = {
|
||||||
@@ -170,22 +226,24 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
|
|
||||||
# Expected keys with auto-prefixes
|
# Expected keys with auto-prefixes
|
||||||
expected = {
|
expected = {
|
||||||
"observation.temp",
|
"data/observation.temp",
|
||||||
"observation.img",
|
"observation.img",
|
||||||
"action.throttle",
|
"data/action.throttle",
|
||||||
"action.vec_0",
|
"data/action.vec_0",
|
||||||
"action.vec_1",
|
"data/action.vec_1",
|
||||||
"action.vec_2",
|
"data/action.vec_2",
|
||||||
|
"audio/observation.audio_channel_0",
|
||||||
|
"audio/observation.audio_channel_1",
|
||||||
}
|
}
|
||||||
logged = set(_keys(calls))
|
logged = set(_keys(calls))
|
||||||
assert logged == expected
|
assert logged == expected
|
||||||
|
|
||||||
# Scalars
|
# Scalars
|
||||||
t = _obj_for(calls, "observation.temp")
|
t = _obj_for(calls, "data/observation.temp")
|
||||||
assert type(t).__name__ == "DummyScalar"
|
assert type(t).__name__ == "DummyScalar"
|
||||||
assert t.value == pytest.approx(1.5)
|
assert t.value == pytest.approx(1.5)
|
||||||
|
|
||||||
throttle = _obj_for(calls, "action.throttle")
|
throttle = _obj_for(calls, "data/action.throttle")
|
||||||
assert type(throttle).__name__ == "DummyScalar"
|
assert type(throttle).__name__ == "DummyScalar"
|
||||||
assert throttle.value == pytest.approx(0.3)
|
assert throttle.value == pytest.approx(0.3)
|
||||||
|
|
||||||
@@ -197,25 +255,39 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
|||||||
|
|
||||||
# Vectors
|
# Vectors
|
||||||
for i, val in enumerate([9, 8, 7]):
|
for i, val in enumerate([9, 8, 7]):
|
||||||
o = _obj_for(calls, f"action.vec_{i}")
|
o = _obj_for(calls, f"data/action.vec_{i}")
|
||||||
assert type(o).__name__ == "DummyScalar"
|
assert type(o).__name__ == "DummyScalar"
|
||||||
assert o.value == pytest.approx(val)
|
assert o.value == pytest.approx(val)
|
||||||
|
|
||||||
|
# Audio
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|
||||||
|
|
||||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||||
vu, calls = mock_rerun
|
vu, calls = mock_rerun
|
||||||
|
|
||||||
vu.log_rerun_data(
|
vu.log_rerun_data(
|
||||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
observation={
|
||||||
|
"observation.temp": 10.0,
|
||||||
|
"observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
|
||||||
|
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||||
|
},
|
||||||
action={"action.a": 1.0},
|
action={"action.a": 1.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = set(_keys(calls))
|
keys = set(_keys(calls))
|
||||||
assert "observation.temp" in keys
|
assert "data/observation.temp" in keys
|
||||||
assert "observation.gray" in keys
|
assert "observation.gray" in keys
|
||||||
assert "action.a" in keys
|
assert "data/action.a" in keys
|
||||||
|
assert "audio/observation.audio_channel_0" in keys
|
||||||
|
assert "audio/observation.audio_channel_1" in keys
|
||||||
|
|
||||||
temp = _obj_for(calls, "observation.temp")
|
temp = _obj_for(calls, "data/observation.temp")
|
||||||
assert type(temp).__name__ == "DummyScalar"
|
assert type(temp).__name__ == "DummyScalar"
|
||||||
assert temp.value == pytest.approx(10.0)
|
assert temp.value == pytest.approx(10.0)
|
||||||
|
|
||||||
@@ -224,6 +296,13 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
|||||||
assert img.arr.shape == (8, 8, 1) # remains HWC
|
assert img.arr.shape == (8, 8, 1) # remains HWC
|
||||||
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
||||||
|
|
||||||
a = _obj_for(calls, "action.a")
|
a = _obj_for(calls, "data/action.a")
|
||||||
assert type(a).__name__ == "DummyScalar"
|
assert type(a).__name__ == "DummyScalar"
|
||||||
assert a.value == pytest.approx(1.0)
|
assert a.value == pytest.approx(1.0)
|
||||||
|
|
||||||
|
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||||
|
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||||
|
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||||
|
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||||
|
assert audio_obj_0.values.shape == (100,)
|
||||||
|
assert audio_obj_1.values.shape == (100,)
|
||||||
|
|||||||
Reference in New Issue
Block a user