mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
test(update): update tests
This commit is contained in:
@@ -30,25 +30,46 @@ from lerobot.utils.constants import OBS_STATE
|
||||
@pytest.fixture
|
||||
def mock_rerun(monkeypatch):
|
||||
"""
|
||||
Provide a mock `rerun` module so tests don't depend on the real library.
|
||||
Also reload the module-under-test so it binds to this mock `rr`.
|
||||
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
|
||||
depend on the real library. Also reload the module-under-test so it binds to
|
||||
this mock `rr`.
|
||||
"""
|
||||
calls = []
|
||||
blueprints = []
|
||||
|
||||
class DummyScalar:
|
||||
def __init__(self, value):
|
||||
self.value = float(value)
|
||||
# Scalars may be built from a single float or from a 1D array batch.
|
||||
self.value = value
|
||||
|
||||
class DummyImage:
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
|
||||
def compress(self, *a, **k):
|
||||
return self
|
||||
|
||||
def dummy_log(key, obj=None, **kwargs):
|
||||
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
|
||||
if obj is None and "entity" in kwargs:
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
def dummy_send_blueprint(blueprint, *a, **k):
|
||||
blueprints.append(blueprint)
|
||||
|
||||
# Mock the `rerun.blueprint` submodule used to build the layout.
|
||||
dummy_rrb = SimpleNamespace(
|
||||
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
|
||||
kind="Spatial2DView", origin=origin, name=name
|
||||
),
|
||||
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
|
||||
kind="TimeSeriesView", name=name, contents=contents
|
||||
),
|
||||
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
|
||||
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
|
||||
)
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
__name__="rerun",
|
||||
__package__="rerun",
|
||||
@@ -56,20 +77,23 @@ def mock_rerun(monkeypatch):
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
send_blueprint=dummy_send_blueprint,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
blueprint=dummy_rrb,
|
||||
)
|
||||
|
||||
# Inject fake module into sys.modules
|
||||
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
|
||||
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
|
||||
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
|
||||
|
||||
# Now import and reload the module under test, to bind to our rerun mock
|
||||
import lerobot.utils.visualization_utils as vu
|
||||
|
||||
importlib.reload(vu)
|
||||
|
||||
# Expose both the reloaded module and the call recorder
|
||||
yield vu, calls
|
||||
# Expose the reloaded module, the call recorder and the captured blueprints
|
||||
yield vu, calls, blueprints
|
||||
|
||||
|
||||
def _keys(calls):
|
||||
@@ -92,8 +116,13 @@ def _kwargs_for(calls, key):
|
||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||
|
||||
|
||||
def _views_by_kind(blueprint, kind):
|
||||
"""Return the views of a given kind from the (single) blueprint's grid."""
|
||||
return [v for v in blueprint.root.views if v.kind == kind]
|
||||
|
||||
|
||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
# Build EnvTransition dict
|
||||
obs = {
|
||||
@@ -103,7 +132,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
}
|
||||
act = {
|
||||
"action.throttle": 0.7,
|
||||
# 1D array should log individual Scalars with suffix _i
|
||||
# 1D array should be logged as a single Scalars batch under one entity path
|
||||
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
|
||||
}
|
||||
transition = {
|
||||
@@ -120,31 +149,28 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
# - observation.state.temperature -> Scalars
|
||||
# - observation.camera -> Image (HWC) with static=True
|
||||
# - action.throttle -> Scalars
|
||||
# - action.vector_0, action.vector_1 -> Scalars
|
||||
# - action.vector -> single Scalars batch (no per-element suffix)
|
||||
expected_keys = {
|
||||
f"{OBS_STATE}.temperature",
|
||||
"observation.camera",
|
||||
"action.throttle",
|
||||
"action.vector_0",
|
||||
"action.vector_1",
|
||||
"action.vector",
|
||||
}
|
||||
assert set(_keys(calls)) == expected_keys
|
||||
|
||||
# Check scalar types and values
|
||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
||||
assert type(temp_obj).__name__ == "DummyScalar"
|
||||
assert temp_obj.value == pytest.approx(25.0)
|
||||
assert float(temp_obj.value) == pytest.approx(25.0)
|
||||
|
||||
throttle_obj = _obj_for(calls, "action.throttle")
|
||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||
assert throttle_obj.value == pytest.approx(0.7)
|
||||
assert float(throttle_obj.value) == pytest.approx(0.7)
|
||||
|
||||
v0 = _obj_for(calls, "action.vector_0")
|
||||
v1 = _obj_for(calls, "action.vector_1")
|
||||
assert type(v0).__name__ == "DummyScalar"
|
||||
assert type(v1).__name__ == "DummyScalar"
|
||||
assert v0.value == pytest.approx(1.0)
|
||||
assert v1.value == pytest.approx(2.0)
|
||||
# 1D vector logged as a single batched Scalars under one entity path
|
||||
vec = _obj_for(calls, "action.vector")
|
||||
assert type(vec).__name__ == "DummyScalar"
|
||||
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
|
||||
|
||||
# Check image handling: CHW -> HWC
|
||||
img_obj = _obj_for(calls, "observation.camera")
|
||||
@@ -152,9 +178,24 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||
|
||||
# A blueprint should have been built and sent exactly once, and cached on the function.
|
||||
assert len(blueprints) == 1
|
||||
assert vu.log_rerun_data.blueprint is blueprints[0]
|
||||
|
||||
bp = blueprints[0]
|
||||
# One spatial view per image path
|
||||
spatial_views = _views_by_kind(bp, "Spatial2DView")
|
||||
assert {v.origin for v in spatial_views} == {"observation.camera"}
|
||||
|
||||
# One time-series view each for observation and action scalars
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert set(ts_views) == {"observation", "action"}
|
||||
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
|
||||
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
|
||||
|
||||
|
||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
# First dict without prefixes treated as observation
|
||||
# Second dict without prefixes treated as action
|
||||
@@ -173,14 +214,12 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
# First dict was treated as observation, second as action
|
||||
vu.log_rerun_data(observation=obs_plain, action=act_plain)
|
||||
|
||||
# Expected keys with auto-prefixes
|
||||
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
|
||||
expected = {
|
||||
"observation.temp",
|
||||
"observation.img",
|
||||
"action.throttle",
|
||||
"action.vec_0",
|
||||
"action.vec_1",
|
||||
"action.vec_2",
|
||||
"action.vec",
|
||||
}
|
||||
logged = set(_keys(calls))
|
||||
assert logged == expected
|
||||
@@ -188,11 +227,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
# Scalars
|
||||
t = _obj_for(calls, "observation.temp")
|
||||
assert type(t).__name__ == "DummyScalar"
|
||||
assert t.value == pytest.approx(1.5)
|
||||
assert float(t.value) == pytest.approx(1.5)
|
||||
|
||||
throttle = _obj_for(calls, "action.throttle")
|
||||
assert type(throttle).__name__ == "DummyScalar"
|
||||
assert throttle.value == pytest.approx(0.3)
|
||||
assert float(throttle.value) == pytest.approx(0.3)
|
||||
|
||||
# Image stays HWC
|
||||
img = _obj_for(calls, "observation.img")
|
||||
@@ -200,15 +239,23 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
assert img.arr.shape == (5, 6, 3)
|
||||
assert _kwargs_for(calls, "observation.img").get("static", False) is True
|
||||
|
||||
# Vectors
|
||||
for i, val in enumerate([9, 8, 7]):
|
||||
o = _obj_for(calls, f"action.vec_{i}")
|
||||
assert type(o).__name__ == "DummyScalar"
|
||||
assert o.value == pytest.approx(val)
|
||||
# Vector logged as a single batched Scalars under one entity path
|
||||
vec = _obj_for(calls, "action.vec")
|
||||
assert type(vec).__name__ == "DummyScalar"
|
||||
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
|
||||
|
||||
# Blueprint sent once with the expected view layout
|
||||
assert len(blueprints) == 1
|
||||
bp = blueprints[0]
|
||||
spatial_views = _views_by_kind(bp, "Spatial2DView")
|
||||
assert {v.origin for v in spatial_views} == {"observation.img"}
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert ts_views["observation"].contents == ["observation.temp"]
|
||||
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
|
||||
|
||||
|
||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
vu.log_rerun_data(
|
||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
||||
@@ -222,7 +269,7 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
temp = _obj_for(calls, "observation.temp")
|
||||
assert type(temp).__name__ == "DummyScalar"
|
||||
assert temp.value == pytest.approx(10.0)
|
||||
assert float(temp.value) == pytest.approx(10.0)
|
||||
|
||||
img = _obj_for(calls, "observation.gray")
|
||||
assert type(img).__name__ == "DummyImage"
|
||||
@@ -231,4 +278,26 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
|
||||
a = _obj_for(calls, "action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
assert float(a.value) == pytest.approx(1.0)
|
||||
|
||||
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
|
||||
assert len(blueprints) == 1
|
||||
bp = blueprints[0]
|
||||
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
|
||||
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
|
||||
assert ts_views["observation"].contents == ["observation.temp"]
|
||||
assert ts_views["action"].contents == ["action.a"]
|
||||
|
||||
|
||||
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
|
||||
"""The blueprint is built from the first call and not resent on subsequent calls."""
|
||||
vu, calls, blueprints = mock_rerun
|
||||
|
||||
vu.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
|
||||
assert len(blueprints) == 1
|
||||
first_blueprint = vu.log_rerun_data.blueprint
|
||||
|
||||
vu.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
|
||||
# Still only one blueprint, and the cached one is unchanged.
|
||||
assert len(blueprints) == 1
|
||||
assert vu.log_rerun_data.blueprint is first_blueprint
|
||||
|
||||
Reference in New Issue
Block a user