mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
test(rerun): update tests
This commit is contained in:
@@ -37,7 +37,14 @@ def mock_rerun(monkeypatch):
|
||||
|
||||
class DummyScalar:
|
||||
def __init__(self, value):
|
||||
self.value = float(value)
|
||||
arr = np.asarray(value, dtype=float)
|
||||
# Keep a flat list of values plus a convenience `value` for scalar inputs.
|
||||
self.values = arr.reshape(-1).tolist()
|
||||
self.value = self.values[0] if arr.ndim == 0 else None
|
||||
|
||||
class DummySeriesLines:
|
||||
def __init__(self, names=None):
|
||||
self.names = names
|
||||
|
||||
class DummyImage:
|
||||
def __init__(self, arr):
|
||||
@@ -54,6 +61,7 @@ def mock_rerun(monkeypatch):
|
||||
__package__="rerun",
|
||||
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
|
||||
Scalars=DummyScalar,
|
||||
SeriesLines=DummySeriesLines,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
init=lambda *a, **k: None,
|
||||
@@ -85,6 +93,14 @@ def _obj_for(calls, key):
|
||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||
|
||||
|
||||
def _obj_for_type(calls, key, type_name):
|
||||
"""Find the first object of a given type name logged under a given key."""
|
||||
for k, obj, _kw in calls:
|
||||
if k == key and type(obj).__name__ == type_name:
|
||||
return obj
|
||||
raise KeyError(f"Key {key} with type {type_name} not found in calls: {calls}")
|
||||
|
||||
|
||||
def _kwargs_for(calls, key):
|
||||
for k, _obj, kw in calls:
|
||||
if k == key:
|
||||
@@ -103,7 +119,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
}
|
||||
act = {
|
||||
"action.throttle": 0.7,
|
||||
# 1D array should log individual Scalars with suffix _i
|
||||
# 1D array should log a single Scalars batch under one entity path
|
||||
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
|
||||
}
|
||||
transition = {
|
||||
@@ -120,13 +136,12 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
# - observation.state.temperature -> Scalars
|
||||
# - observation.camera -> Image (HWC) with static=True
|
||||
# - action.throttle -> Scalars
|
||||
# - action.vector_0, action.vector_1 -> Scalars
|
||||
# - action.vector -> single Scalars batch under one entity path
|
||||
expected_keys = {
|
||||
f"{OBS_STATE}.temperature",
|
||||
"observation.camera",
|
||||
"action.throttle",
|
||||
"action.vector_0",
|
||||
"action.vector_1",
|
||||
"action.vector",
|
||||
}
|
||||
assert set(_keys(calls)) == expected_keys
|
||||
|
||||
@@ -139,12 +154,13 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||
assert throttle_obj.value == pytest.approx(0.7)
|
||||
|
||||
v0 = _obj_for(calls, "action.vector_0")
|
||||
v1 = _obj_for(calls, "action.vector_1")
|
||||
assert type(v0).__name__ == "DummyScalar"
|
||||
assert type(v1).__name__ == "DummyScalar"
|
||||
assert v0.value == pytest.approx(1.0)
|
||||
assert v1.value == pytest.approx(2.0)
|
||||
# The full vector is logged as one batch so all dimensions share a single view
|
||||
vector_obj = _obj_for_type(calls, "action.vector", "DummyScalar")
|
||||
assert vector_obj.values == pytest.approx([1.0, 2.0])
|
||||
|
||||
# Series keep their `{key}_{i}` names via SeriesLines
|
||||
vector_names = _obj_for_type(calls, "action.vector", "DummySeriesLines")
|
||||
assert vector_names.names == ["action.vector_0", "action.vector_1"]
|
||||
|
||||
# Check image handling: CHW -> HWC
|
||||
img_obj = _obj_for(calls, "observation.camera")
|
||||
@@ -178,9 +194,7 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
"observation.temp",
|
||||
"observation.img",
|
||||
"action.throttle",
|
||||
"action.vec_0",
|
||||
"action.vec_1",
|
||||
"action.vec_2",
|
||||
"action.vec",
|
||||
}
|
||||
logged = set(_keys(calls))
|
||||
assert logged == expected
|
||||
@@ -200,11 +214,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
assert img.arr.shape == (5, 6, 3)
|
||||
assert _kwargs_for(calls, "observation.img").get("static", False) is True
|
||||
|
||||
# Vectors
|
||||
for i, val in enumerate([9, 8, 7]):
|
||||
o = _obj_for(calls, f"action.vec_{i}")
|
||||
assert type(o).__name__ == "DummyScalar"
|
||||
assert o.value == pytest.approx(val)
|
||||
# Vector logged as a single batch under one entity path, keeping per-dimension names
|
||||
vec = _obj_for_type(calls, "action.vec", "DummyScalar")
|
||||
assert vec.values == pytest.approx([9, 8, 7])
|
||||
vec_names = _obj_for_type(calls, "action.vec", "DummySeriesLines")
|
||||
assert vec_names.names == ["action.vec_0", "action.vec_1", "action.vec_2"]
|
||||
|
||||
|
||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
Reference in New Issue
Block a user