MinhNH
Initial commit
48c5871
import cv2
import torch
import numpy as np
from env import Paint
from utils.util import *
from DRL.ddpg import decode
# Run on the first visible CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class fastenv():
    """Batched wrapper around ``Paint`` adding distance logging and image dumps.

    The wrapped ``Paint`` environment is constructed and its data loaded in
    ``__init__``. When a ``writer`` is supplied:
    - ``step`` logs a per-sample distance scalar at the end of each training
      episode;
    - ``save_image`` writes canvas/target images.
    With ``writer=None`` these logging calls are skipped instead of raising
    ``AttributeError``.
    """

    def __init__(self, max_episode_length=10, env_batch=64, writer=None):
        """Create the underlying ``Paint`` environment and load its data.

        Args:
            max_episode_length: Steps per episode; forwarded to ``Paint``.
            env_batch: Number of environments run as one batch; forwarded to
                ``Paint``.
            writer: Optional logger exposing ``add_scalar(tag, value, step)``
                and ``add_image(tag, img, step)`` (presumably a TensorBoard-
                style writer); ``None`` disables logging.
        """
        self.max_episode_length = max_episode_length
        self.env_batch = env_batch
        self.env = Paint(self.env_batch, self.max_episode_length)
        self.env.load_data()
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        self.writer = writer
        self.test = False
        # Step index used for logged scalars; advanced once per finished
        # training (non-test) episode.
        self.log = 0
        # Per-sample distances from the most recent finished training
        # episode; None until one has finished.
        self.dist = None

    @staticmethod
    def _to_rgb(img):
        """Permute ``img`` with ``permute(1, 2, 0)`` and swap BGR->RGB via OpenCV.

        Presumably ``img`` is a channel-first (C, H, W) image tensor, giving an
        (H, W, C) numpy array — confirm against ``Paint``.
        """
        return cv2.cvtColor(to_numpy(img.permute(1, 2, 0)), cv2.COLOR_BGR2RGB)

    def save_image(self, log, step):
        """Write intermediate and final images for a subset of samples.

        At every call, the canvas of each sample whose ``imgid <= 10`` is
        written. On the final step (``step == max_episode_length``), the target
        (``gt``) and canvas of each sample whose ``imgid < 50`` are also
        written. Does nothing when no writer was given.

        Args:
            log: Global step passed to ``writer.add_image``.
            step: Step within the current episode; used in the tag name and to
                detect the final step.
        """
        if self.writer is None:
            return
        for i in range(self.env_batch):
            if self.env.imgid[i] <= 10:
                canvas = self._to_rgb(self.env.canvas[i])
                self.writer.add_image('{}/canvas_{}.png'.format(str(self.env.imgid[i]), str(step)), canvas, log)
        if step == self.max_episode_length:
            for i in range(self.env_batch):
                if self.env.imgid[i] < 50:
                    gt = self._to_rgb(self.env.gt[i])
                    canvas = self._to_rgb(self.env.canvas[i])
                    self.writer.add_image(str(self.env.imgid[i]) + '/_target.png', gt, log)
                    self.writer.add_image(str(self.env.imgid[i]) + '/_canvas.png', canvas, log)

    def step(self, action):
        """Apply ``action`` to the batched environment (without autograd).

        When the episode ends during training (``test`` is False):
        - the per-sample distance is stored in ``self.dist``;
        - if a writer is set, each sample's distance is logged under
          ``'train/dist'``;
        - ``self.log`` is incremented.

        Returns:
            The ``(observation, reward, done, info)`` tuple from ``Paint.step``.
        """
        with torch.no_grad():
            ob, r, d, info = self.env.step(torch.tensor(action).to(device))
        # Only d[0] is checked; presumably every env in the batch terminates on
        # the same step — confirm against Paint.
        if d[0] and not self.test:
            self.dist = self.get_dist()
            if self.writer is not None:
                for i in range(self.env_batch):
                    self.writer.add_scalar('train/dist', self.dist[i], self.log)
            self.log += 1
        return ob, r, d, info

    def get_dist(self):
        """Return the per-sample mean squared difference between target and canvas.

        Both tensors are cast to float and divided by 255 (presumably 0-255
        pixel values) before squaring. The result is averaged over dims 1-3,
        leaving one value per batch entry, and returned via ``to_numpy``.
        """
        diff = (self.env.gt.float() - self.env.canvas.float()) / 255
        return to_numpy((diff ** 2).mean(dim=(1, 2, 3)))

    def reset(self, test=False, episode=0):
        """Reset the batched environment and return its initial observation.

        Args:
            test: Stored on ``self.test`` and forwarded to ``Paint.reset``;
                when True, ``step`` skips distance logging.
            episode: Forwarded to ``Paint.reset`` as ``episode * env_batch``.
                Presumably this is the index of the first sample in the batch;
                confirm against ``Paint``.
        """
        self.test = test
        return self.env.reset(self.test, episode * self.env_batch)