mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
preparing for training adding some temporary debug code aswell to visualize model output
This commit is contained in:
@@ -49,6 +49,8 @@ def test_fastwam_is_registered_and_publicly_exported():
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
base_model_id=None,
|
||||
)
|
||||
|
||||
@@ -78,6 +80,8 @@ def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_pa
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(2, 2),
|
||||
device="cpu",
|
||||
toggle_action_dimensions=[-1],
|
||||
@@ -154,6 +158,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
proprio_dim=2,
|
||||
action_horizon=4,
|
||||
n_action_steps=2,
|
||||
num_video_frames=5,
|
||||
action_video_freq_ratio=1,
|
||||
image_size=(16, 16),
|
||||
input_features={
|
||||
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
|
||||
@@ -164,7 +170,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
)
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
output = policy.forward(
|
||||
loss, metrics = policy.forward(
|
||||
{
|
||||
"observation.images.image": torch.zeros(1, 3, 16, 16),
|
||||
OBS_STATE: torch.zeros(1, 2),
|
||||
@@ -186,8 +192,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
|
||||
}
|
||||
)
|
||||
|
||||
assert output["loss"].item() == 1.0
|
||||
assert output["loss_action"].item() == 1.0
|
||||
assert loss.item() == 1.0
|
||||
assert metrics["loss_action"] == 1.0
|
||||
assert action.shape == (2, 4, 3)
|
||||
assert action[:, 0, 0].tolist() == [1.0, 2.0]
|
||||
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
|
||||
@@ -218,7 +224,7 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
|
||||
|
||||
|
||||
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
|
||||
def build_core(self, config):
|
||||
core = CoreWithFrozenComponents()
|
||||
@@ -250,7 +256,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
|
||||
|
||||
|
||||
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
@@ -272,7 +278,7 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
|
||||
|
||||
|
||||
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
|
||||
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
|
||||
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
|
||||
policy = FastWAMPolicy(cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user