ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions
+1 -1
View File
@@ -192,7 +192,7 @@ class AlohaEnv(AbstractEnv):
{
"observation": TensorDict(obs, batch_size=[]),
"reward": torch.tensor([reward], dtype=torch.float32),
# succes and done are true when coverage > self.success_threshold in env
# success and done are true when coverage > self.success_threshold in env
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([success], dtype=torch.bool),
},
-24
View File
@@ -62,27 +62,3 @@ def make_env(cfg, transform=None):
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
# def make_env(env_name, frame_skip, device, is_test=False):
# env = GymEnv(
# env_name,
# frame_skip=frame_skip,
# from_pixels=True,
# pixels_only=False,
# device=device,
# )
# env = TransformedEnv(env)
# env.append_transform(NoopResetEnv(noops=30, random=True))
# if not is_test:
# env.append_transform(EndOfLifeTransform())
# env.append_transform(RewardClipping(-1, 1))
# env.append_transform(ToTensorImage())
# env.append_transform(GrayScale())
# env.append_transform(Resize(84, 84))
# env.append_transform(CatFrames(N=4, dim=-3))
# env.append_transform(RewardSum())
# env.append_transform(StepCounter(max_steps=4500))
# env.append_transform(DoubleToFloat())
# env.append_transform(VecNorm(in_keys=["pixels"]))
# return env
+23 -3
View File
@@ -3,6 +3,8 @@ import logging
from collections import deque
from typing import Optional
import cv2
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data.tensor_specs import (
@@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv):
self._env = PushTImageEnv(render_size=self.image_size)
def render(self, mode="rgb_array", width=384, height=384):
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
"""
with_marker adds a cursor showing the targeted action for the controller.
"""
if width != height:
raise NotImplementedError()
tmp = self._env.render_size
self._env.render_size = width
out = self._env.render(mode)
if width != self._env.render_size:
self._env.render_cache = None
self._env.render_size = width
out = self._env.render(mode).copy()
if with_marker and self._env.latest_action is not None:
action = np.array(self._env.latest_action)
coord = (action / 512 * self._env.render_size).astype(np.int32)
marker_size = int(8 / 96 * self._env.render_size)
thickness = int(1 / 96 * self._env.render_size)
cv2.drawMarker(
out,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
self._env.render_size = tmp
return out
@@ -27,20 +27,6 @@ class PushTImageEnv(PushTEnv):
img_obs = np.moveaxis(img, -1, 0)
obs = {"image": img_obs, "agent_pos": agent_pos}
# draw action
if self.latest_action is not None:
action = np.array(self.latest_action)
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
# cv2.drawMarker(
# img,
# coord,
# color=(255, 0, 0),
# markerType=cv2.MARKER_CROSS,
# markerSize=marker_size,
# thickness=thickness,
# )
self.render_cache = img
return obs