mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fix diffusion tests ci
This commit is contained in:
@@ -11,6 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
@@ -23,8 +24,10 @@ from lerobot.optim.schedulers import (
|
|||||||
save_scheduler_state,
|
save_scheduler_state,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import SCHEDULER_STATE
|
from lerobot.utils.constants import SCHEDULER_STATE
|
||||||
|
from lerobot.utils.import_utils import is_package_available
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed")
|
||||||
def test_diffuser_scheduler(optimizer):
|
def test_diffuser_scheduler(optimizer):
|
||||||
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
|
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
|
||||||
scheduler = config.build(optimizer, num_training_steps=100)
|
scheduler = config.build(optimizer, num_training_steps=100)
|
||||||
|
|||||||
@@ -43,12 +43,21 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
|||||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.utils.feature_utils import dataset_to_policy_features
|
from lerobot.utils.feature_utils import dataset_to_policy_features
|
||||||
|
from lerobot.utils.import_utils import is_package_available
|
||||||
from lerobot.utils.random_utils import seeded_context
|
from lerobot.utils.random_utils import seeded_context
|
||||||
from lerobot.utils.utils import cycle
|
from lerobot.utils.utils import cycle
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
|
# Policies that require optional heavy dependencies to instantiate
|
||||||
|
_POLICY_REQUIRED_PACKAGES: dict[str, tuple[str, ...]] = {
|
||||||
|
"diffusion": ("diffusers",),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ALL_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||||
|
AVAILABLE_POLICIES = [
|
||||||
|
p for p in _ALL_POLICIES if all(is_package_available(pkg) for pkg in _POLICY_REQUIRED_PACKAGES.get(p, ()))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -344,7 +353,7 @@ def test_multikey_construction(multikey: bool):
|
|||||||
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
# to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
|
||||||
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
# that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.
|
||||||
# Thus, we deactivate this test for now.
|
# Thus, we deactivate this test for now.
|
||||||
(
|
pytest.param(
|
||||||
"lerobot/pusht",
|
"lerobot/pusht",
|
||||||
"diffusion",
|
"diffusion",
|
||||||
{
|
{
|
||||||
@@ -353,6 +362,7 @@ def test_multikey_construction(multikey: bool):
|
|||||||
"down_dims": [128, 256, 512],
|
"down_dims": [128, 256, 512],
|
||||||
},
|
},
|
||||||
"",
|
"",
|
||||||
|
marks=pytest.mark.skipif(not is_package_available("diffusers"), reason="diffusers not installed"),
|
||||||
),
|
),
|
||||||
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
("lerobot/aloha_sim_insertion_human", "act", {"n_action_steps": 10}, ""),
|
||||||
(
|
(
|
||||||
|
|||||||
Reference in New Issue
Block a user