diff --git a/pyproject.toml b/pyproject.toml index c4b1c547e..e5431ada3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ dependencies = [ "pyserial>=3.5,<4.0", "wandb>=0.24.0,<0.25.0", - "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency - "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency - "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency + "torch>=2.2.1,<2.11.0", # TODO: Bump dependency + "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency + "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=1.1.1,<2.0.0", diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 1e566a6ba..224613416 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import torch +from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR from lerobot.optim.schedulers import ( @@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict