Refactor env queue, Training diffusion works (Still not converging)

This commit is contained in:
Remi Cadene
2024-03-04 10:59:43 +00:00
parent fddd9f0311
commit cfc304e870
11 changed files with 96 additions and 111 deletions
+2 -2
View File
@@ -69,7 +69,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
elif cfg.env.name == "pusht":
offline_buffer = PushtExperienceReplay(
@@ -79,7 +79,7 @@ def make_offline_buffer(cfg, sampler=None):
sampler=sampler,
batch_size=batch_size,
pin_memory=pin_memory,
prefetch=prefetch,
prefetch=prefetch if isinstance(prefetch, int) else None,
)
else:
raise ValueError(cfg.env.name)
+12 -1
View File
@@ -143,13 +143,24 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
in_keys=[
# ("observation", "image"),
("observation", "state"),
# TODO(rcadene): for tdmpc, we might want image and state
# ("next", "observation", "image"),
("next", "observation", "state"),
# ("next", "observation", "state"),
("action"),
],
mode="min_max",
)
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, min_max_spec
transform.stats["observation", "state", "min"] = torch.tensor(
[13.456424, 32.938293], dtype=torch.float32
)
transform.stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
transform.stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
transform.stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
if writer is None:
writer = ImmutableDatasetWriter()
if collate_fn is None:
+4
View File
@@ -7,6 +7,8 @@ def make_env(cfg, transform=None):
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
# TODO(rcadene): do we want a specific eval_env_seed?
"seed": cfg.seed,
}
if cfg.env.name == "simxarm":
@@ -17,6 +19,8 @@ def make_env(cfg, transform=None):
elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht import PushtEnv
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
clsfunc = PushtEnv
else:
raise ValueError(cfg.env.name)
+18 -41
View File
@@ -101,14 +101,18 @@ class PushtEnv(EnvBase):
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
# remove all previous observations
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.clear()
self._prev_obs_image_queue = deque(
[obs["image"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.clear()
# copy the current observation n times
obs = self._stack_prev_obs(obs)
self._prev_obs_state_queue = deque(
[obs["state"]] * (self.num_prev_obs + 1), maxlen=(self.num_prev_obs + 1)
)
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
@@ -121,40 +125,6 @@ class PushtEnv(EnvBase):
raise NotImplementedError()
return td
def _stack_prev_obs(self, obs):
"""When the queue is empty, copy the current observation n times."""
assert self.num_prev_obs > 0
def stack_update_queue(prev_obs_queue, obs, num_prev_obs):
# get n most recent observations
prev_obs = list(prev_obs_queue)[-num_prev_obs:]
# if not enough observations, copy the oldest observation until we obtain n observations
if len(prev_obs) == 0:
prev_obs = [obs] * num_prev_obs # queue is empty when env reset
elif len(prev_obs) < num_prev_obs:
prev_obs = [prev_obs[0] for _ in range(num_prev_obs - len(prev_obs))] + prev_obs
# stack n most recent observations with the current observation
stacked_obs = torch.stack(prev_obs + [obs], dim=0)
# add current observation to the queue
# automatically remove oldest observation when queue is full
prev_obs_queue.appendleft(obs)
return stacked_obs
stacked_obs = {}
if "image" in obs:
stacked_obs["image"] = stack_update_queue(
self._prev_obs_image_queue, obs["image"], self.num_prev_obs
)
if "state" in obs:
stacked_obs["state"] = stack_update_queue(
self._prev_obs_state_queue, obs["state"], self.num_prev_obs
)
return stacked_obs
def _step(self, tensordict: TensorDict):
td = tensordict
action = td["action"].numpy()
@@ -176,7 +146,14 @@ class PushtEnv(EnvBase):
obs = self._format_raw_obs(raw_obs)
if self.num_prev_obs > 0:
obs = self._stack_prev_obs(obs)
stacked_obs = {}
if "image" in obs:
self._prev_obs_image_queue.append(obs["image"])
stacked_obs["image"] = torch.stack(list(self._prev_obs_image_queue))
if "state" in obs:
self._prev_obs_state_queue.append(obs["state"])
stacked_obs["state"] = torch.stack(list(self._prev_obs_state_queue))
obs = stacked_obs
td = TensorDict(
{
+2 -44
View File
@@ -1,51 +1,11 @@
import contextlib
import logging
import os
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
from termcolor import colored
def make_dir(dir_path):
"""Create directory if it does not already exist."""
with contextlib.suppress(OSError):
dir_path.mkdir(parents=True, exist_ok=True)
return dir_path
def print_run(cfg, reward=None):
"""Pretty-printing of run information. Call at start of training."""
prefix, color, attrs = " ", "green", ["bold"]
def limstr(s, maxlen=32):
return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s
def pprint(k, v):
print(
prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
limstr(v),
)
kvs = [
("task", cfg.env.task),
("offline_steps", f"{cfg.offline_steps}"),
("online_steps", f"{cfg.online_steps}"),
("action_repeat", f"{cfg.env.action_repeat}"),
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
# ('actions', cfg.action_dim),
# ('experiment', cfg.exp_name),
]
if reward is not None:
kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"])))
w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21
div = "-" * w
print(div)
for k, v in kvs:
pprint(k, v)
print(div)
def cfg_to_group(cfg, return_list=False):
"""Return a wandb-safe group name for logging. Optionally returns group name as list."""
# lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)]
@@ -71,13 +31,12 @@ class Logger:
self._seed = cfg.seed
self._cfg = cfg
self._eval = []
print_run(cfg)
project = cfg.get("wandb", {}).get("project")
entity = cfg.get("wandb", {}).get("entity")
enable_wandb = cfg.get("wandb", {}).get("enable", False)
run_offline = not enable_wandb or not project or not entity
if run_offline:
print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
@@ -134,7 +93,6 @@ class Logger:
self.save_buffer(buffer, identifier="buffer")
if self._wandb:
self._wandb.finish()
print_run(self._cfg, self._eval[-1][-1])
def log_dict(self, d, step, mode="train"):
assert mode in {"train", "eval"}
+26 -8
View File
@@ -4,10 +4,8 @@ import time
import hydra
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
from .multi_image_obs_encoder import MultiImageObsEncoder
@@ -39,8 +37,8 @@ class DiffusionPolicy(nn.Module):
super().__init__()
self.cfg = cfg
noise_scheduler = DDPMScheduler(**cfg_noise_scheduler)
rgb_model = get_resnet(**cfg_rgb_model)
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
@@ -127,16 +125,36 @@ class DiffusionPolicy(nn.Module):
# (t h) ... -> t h ...
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
# |o|o| observations: 2
# | |a|a|a|a|a|a|a|a| actions executed: 8
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
image = batch["observation", "image"]
state = batch["observation", "state"]
action = batch["action"]
assert image.shape[1] == horizon
assert state.shape[1] == horizon
assert action.shape[1] == horizon
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
raise NotImplementedError()
# keep first 2 observations of the slice corresponding to t=[-1,0]
image = image[:, : self.cfg.n_obs_steps]
state = state[:, : self.cfg.n_obs_steps]
out = {
"obs": {
"image": batch["observation", "image"].to(self.device, non_blocking=True),
"agent_pos": batch["observation", "state"].to(self.device, non_blocking=True),
"image": image.to(self.device, non_blocking=True),
"agent_pos": state.to(self.device, non_blocking=True),
},
"action": batch["action"].to(self.device, non_blocking=True),
"action": action.to(self.device, non_blocking=True),
}
return out
batch = replay_buffer.sample(batch_size) if self.cfg.balanced_sampling else replay_buffer.sample()
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time