MinhNH
Initial commit
48c5871
import sys
import json
import torch
import numpy as np
import argparse
import torchvision.transforms as transforms
import cv2
from DRL.ddpg import decode
from utils.util import *
from PIL import Image
from torchvision import transforms, utils
# Run tensors on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Training-time augmentation: wrap the numpy image as a PIL image, then
# randomly mirror it left-right (transform's default flip probability).
aug = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
])
# Side length, in pixels, of the square canvas and of every resized image.
width = 128
# Pixel count of the canvas (the "convas" spelling is kept for importers).
convas_area = width ** 2
# Module-level image caches and counters, filled by Paint.load_data().
img_train, img_test = [], []
train_num = test_num = 0
class Paint:
    """Batched painting environment over CelebA images.

    Each episode pairs a blank uint8 canvas with a batch of target images
    (``gt``). Every :meth:`step` renders an action onto the canvas via
    ``decode`` and rewards the reduction in mean squared error between
    canvas and target, normalized by the episode's initial error.
    """

    def __init__(self, batch_size, max_step):
        """Configure the environment.

        Args:
            batch_size: Number of canvases/targets processed in parallel.
            max_step: Step count at which :meth:`step` reports ``done``.
        """
        self.batch_size = batch_size
        self.max_step = max_step
        # Plain int (``(13)`` is not a tuple). NOTE(review): presumably the
        # per-sample action size consumed by ``decode`` -- confirm.
        self.action_space = (13)
        # NOTE(review): observation() returns channels-first tensors of shape
        # (B, 7, width, width); this tuple lists channels last. Kept as-is for
        # existing callers -- confirm which layout they expect.
        self.observation_space = (self.batch_size, width, width, 7)
        self.test = False

    def load_data(self):
        """Load and resize CelebA images into the module-level caches.

        Reads ``./data/img_align_celeba/000001.jpg`` .. ``200000.jpg``,
        resizes each to ``width x width`` and appends it to ``img_test``
        (loop indices 0..2000) or ``img_train`` (the rest), updating the
        global ``test_num`` / ``train_num`` counters.

        Raises:
            FileNotFoundError: if an image cannot be read. ``cv2.imread``
                returns ``None`` instead of raising, which would otherwise
                fail inside ``cv2.resize`` with a less descriptive error.
        """
        # CelebA
        global train_num, test_num
        for i in range(200000):
            img_id = '%06d' % (i + 1)
            path = './data/img_align_celeba/' + img_id + '.jpg'
            try:
                img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
                if img is None:
                    raise FileNotFoundError('could not read image {}'.format(path))
                img = cv2.resize(img, (width, width))
                # NOTE(review): ``> 2000`` puts indices 0..2000 (2001 images)
                # into the test split; kept unchanged to preserve the split.
                if i > 2000:
                    train_num += 1
                    img_train.append(img)
                else:
                    test_num += 1
                    img_test.append(img)
            finally:
                # Progress report; also runs when the read above fails.
                if (i + 1) % 10000 == 0:
                    print('loaded {} images'.format(i + 1))
        print('finish loading data, {} training images, {} testing images'.format(
            str(train_num), str(test_num)))

    def pre_data(self, id, test):
        """Return cached image ``id`` as a channels-first numpy array.

        Args:
            id: Index into ``img_test`` (if ``test``) or ``img_train``.
            test: Select the test split and skip augmentation.

        Returns:
            ``np.ndarray`` with the channel axis moved to the front.
        """
        if test:
            img = img_test[id]
        else:
            # Training images get the random horizontal flip from ``aug``,
            # which returns a PIL image; np.asarray below converts it back.
            img = aug(img_train[id])
        img = np.asarray(img)
        # HWC -> CHW
        return np.transpose(img, (2, 0, 1))

    def reset(self, test=False, begin_num=False):
        """Start a new episode with a fresh batch of target images.

        Args:
            test: If true, take targets sequentially from the test split;
                otherwise sample training images uniformly at random.
            begin_num: Integer offset of the first test image (the
                ``False`` default behaves as 0). Unused for training.

        Returns:
            The initial observation (see :meth:`observation`).

        Raises:
            RuntimeError: if the requested split is empty, e.g. because
                :meth:`load_data` has not been called.
        """
        available = test_num if test else train_num
        if self.batch_size and available == 0:
            raise RuntimeError('no {} images loaded; call load_data() first'.format(
                'test' if test else 'training'))
        self.test = test
        self.imgid = [0] * self.batch_size
        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8, device=device)
        for i in range(self.batch_size):
            if test:
                img_id = (i + begin_num) % test_num
            else:
                img_id = np.random.randint(train_num)
            self.imgid[i] = img_id
            self.gt[i] = torch.tensor(self.pre_data(img_id, test))
        # Per-sample mean squared intensity of the target, pixels in [0, 1].
        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
        self.stepnum = 0
        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8, device=device)
        # Distance of the blank canvas; later rewards are normalized by it.
        self.lastdis = self.ini_dis = self.cal_dis()
        return self.observation()

    def observation(self):
        """Return the current observation as one uint8 tensor.

        Concatenates along dim 1:
            canvas  B x 3 x width x width
            gt      B x 3 x width x width
            T       B x 1 x width x width, every value equal to ``stepnum``
        giving B x 7 x width x width.
        """
        # uint8 plane: assumes stepnum stays below 256 (max_step is small).
        T = torch.full([self.batch_size, 1, width, width], self.stepnum,
                       dtype=torch.uint8, device=device)
        return torch.cat((self.canvas, self.gt, T), 1)  # canvas, img, T

    def cal_trans(self, s, t):
        """Scale sample ``s[b]`` by ``t[b]`` for every batch index ``b``.

        Swapping dims 0 and 3 lets ``t`` broadcast along the batch axis.
        NOTE(review): assumes ``s`` is 4-D and ``t`` has shape (B,) -- confirm.
        """
        return (s.transpose(0, 3) * t).transpose(0, 3)

    def step(self, action):
        """Apply one batch of actions and advance the episode by one step.

        Returns:
            Tuple ``(observation, reward, done, info)``: the detached
            observation, the per-sample reward from :meth:`cal_reward`, a
            numpy bool array of length ``batch_size`` (all equal), and None.
        """
        # decode receives the canvas scaled to [0, 1]; its output is scaled
        # back to 0..255 and truncated to uint8.
        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()
        self.stepnum += 1
        ob = self.observation()
        done = (self.stepnum == self.max_step)
        reward = self.cal_reward()
        return ob.detach(), reward, np.array([done] * self.batch_size), None

    def cal_dis(self):
        """Per-sample mean squared error between canvas and gt, pixels in [0, 1]."""
        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)

    def cal_reward(self):
        """Return the distance decrease since the previous step.

        The decrease is divided by the episode's initial distance; the 1e-8
        term avoids division by zero when the canvas already matches gt.
        """
        dis = self.cal_dis()
        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
        self.lastdis = dis
        return to_numpy(reward)