40 Commits

Author SHA1 Message Date
Alexander Soare 342f429f1c Add test to make sure policy dataclass configs match yaml configs (#292) 2024-06-26 09:09:40 +01:00
Radek Osmulski 504d2aaf48 add EpisodeAwareSampler (#217) 2024-05-31 13:43:47 +01:00
  Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Alexander Soare e3b9f1c19b Add resume training (#205) 2024-05-28 12:04:23 +01:00
  Co-authored-by: Remi <re.cadene@gmail.com>
Alexander Soare 473345fdf6 Fix stats override in ACT config (#161) 2024-05-09 15:16:47 +01:00
Akshay Kashyap 460df2ccea Support for DDIMScheduler in Diffusion Policy (#146) 2024-05-08 18:05:16 +01:00
Alexander Soare a8e245fb31 Remove loss masking from diffusion policy (#135) 2024-05-06 07:27:01 +01:00
Alexander Soare f3bba0270d Remove EMA model from Diffusion Policy (#134) 2024-05-05 11:26:12 +01:00
Alexander Soare bccee745c3 Refactor eval.py (#127) 2024-05-03 17:33:16 +01:00
Alexander Soare a4891095e4 Use PytorchModelHubMixin to save models as safetensors (#125) 2024-05-01 16:17:18 +01:00
  Co-authored-by: Remi <re.cadene@gmail.com>
Alexander Soare 9d60dce6f3 Tidy up yaml configs (#121) 2024-04-30 16:08:59 +01:00
Simon Alibert 791506dfb8 Remove warnings (#111) 2024-04-29 00:31:33 +02:00
  - Replace `use_pretrained_backbone` with `pretrained_backbone_weights`
  - Bump diffusers' minimum version `0.26.3` -> `0.27.2`
  - Add ignore flags in CI's pytest
  - Change Box observation spaces in simulation environments
  - Set `version_base="1.2"` in Hydra initializations
  - Bump einops' minimum version `0.7.0` -> `0.8.0`
Alexander Soare 45f351c618 Make sure targets are normalized too (#106) 2024-04-26 11:18:39 +01:00
Remi e760e4cd63 Move normalization to policy for act and diffusion (#90) 2024-04-25 11:47:38 +02:00
  Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Alexander Soare 03b08eb74e backup wip 2024-04-16 12:51:32 +01:00
Alexander Soare 5608e659e6 backup wip 2024-04-15 19:06:44 +01:00
Alexander Soare 976a197f98 backup wip 2024-04-11 17:51:35 +01:00
Cadene 73dfa3c8e3 tests for tdmpc and diffusion policy are passing 2024-04-09 02:50:32 +00:00
Cadene 70aaf1c4cb test_datasets.py are passing! 2024-04-08 14:16:57 +00:00
Cadene 4371a5570d Remove latency, tdmpc policy passes tests (TODO: make it work with online RL) 2024-04-07 16:01:22 +00:00
Alexander Soare dc745e3037 Remove unused part of diffusion policy config 2024-03-27 13:05:13 +00:00
Alexander Soare acf1174447 ready for review 2024-03-21 10:18:50 +00:00
Alexander Soare d323993569 backup wip 2024-03-20 15:01:27 +00:00
Alexander Soare 32e3f71dd1 backup wip 2024-03-20 09:49:16 +00:00
Alexander Soare 896a11f60e backup wip 2024-03-19 18:50:04 +00:00
Alexander Soare ea17f4ce50 backup wip 2024-03-19 16:02:09 +00:00
Alexander Soare 88347965c2 revert dp changes, make act and tdmpc batch friendly 2024-03-18 19:18:21 +00:00
Alexander Soare 98484ac68e ready for review 2024-03-12 21:59:01 +00:00
Alexander Soare 87fcc536f9 wip - still need to verify full training run 2024-03-11 18:45:21 +00:00
Alexander Soare 2a01487494 early training loss as expected 2024-03-11 13:34:04 +00:00
Simon Alibert 6c867d78ef Integrate pusht env from diffusion 2024-03-10 16:33:03 +01:00
Remi Cadene a027f4edfb Add cfg.offline_prioritized_sampler 2024-03-04 23:08:52 +00:00
Remi Cadene e29fbb50e8 Fix grad_clip_norm 0 -> 10, Fix normalization min_max to be per channel 2024-03-04 17:26:34 +00:00
Remi Cadene cfc304e870 Refactor env queue, Training diffusion works (Still not converging) 2024-03-04 11:00:51 +00:00
Remi Cadene 0f2fa4d9ef Add obs queue to pusht, Set n_obs_steps=2 for diffusion (Not fully tested) 2024-03-03 13:21:31 +00:00
Remi Cadene 661bda45ea imagenet_norm: False 2024-03-02 17:18:58 +00:00
Cadene 0b9027f05e Clean logging, Refactor 2024-02-29 23:21:27 +00:00
Cadene ac90b9c3ee Fix diffusion (rm transpose), Add prefetch 2024-02-28 17:45:01 +00:00
Cadene cf5063e50e Add diffusion policy (train and eval works, TODO: reproduce results) 2024-02-28 15:21:42 +00:00
Cadene 7df542445c Small fix and improve logging message 2024-02-27 11:44:26 +00:00
Cadene 21670dce90 Refactor train, eval_policy, logger, Add diffusion.yaml (WIP) 2024-02-26 01:10:09 +00:00