Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions
+12 -1
View File
@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
in_keys=[
# ("observation", "image"),
("observation", "state"),
# TODO(rcadene): for tdmpc, we might want image and state
# ("next", "observation", "image"),
("next", "observation", "state"),
# ("next", "observation", "state"),
("action"),
],
mode="min_max",
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
transform.stats["observation", "state", "min"] = torch.tensor(
[13.456424, 32.938293], dtype=torch.float32
)
transform.stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None: