from gui.ui_win import Ui_Form
from gui.ui_draw import *
from PIL import Image, ImageQt
import numpy as np
import random
import io
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from util import task, util
from dataloader.image_folder import make_dataset
from dataloader.data_loader import get_transform
from model import create_model


class ui_model(QtWidgets.QWidget, Ui_Form):
    """Main widget of the image-completion demo.

    Wires the controls created by ``Ui_Form.setupUi`` to the actions of this
    class: choosing a model, loading an image, choosing/drawing a mask,
    running the networks and showing/saving the results.

    Mask convention used by the tensors stored on the instance
    (``self.mask``): 1 marks pixels that are kept from the input image and
    0 marks pixels that are replaced by generated content (see how
    ``fill_image`` blends ``self.img_truth`` and the generator output).
    """

    # Current drawing shape and pen width; mirrored onto the paint panel by
    # draw_mask() / change_thickness().
    # NOTE(review): presumably also read by `painter` when a new panel is
    # constructed - confirm in gui.ui_draw.
    shape = 'line'
    CurrentWidth = 1

    def __init__(self, opt):
        """Build the widget and connect every control to its handler.

        Args:
            opt: option object shared with the model code.  Its
                ``preprocess`` attribute is overwritten with
                ``'scale_shortside'`` here.
        """
        super().__init__()
        self.setupUi(self)
        self.opt = opt
        self.show_result_flag = False  # becomes True once fill_image() has run
        self.mask_type = None          # None | 'none' | 'draw' | 'center' | 'external'
        self.img_power = None          # transformed PIL image currently shown
        self.model_names = ['celeba', 'ffhq', 'imagenet', 'places2']
        self.img_root = './examples/'
        self.img_files = ['celeba/img', 'ffhq/img', 'imagenet/img', 'places2/img']

        self.show_logo()

        self.comboBox.activated.connect(self.load_model)  # select model
        self.pushButton_2.clicked.connect(self.select_image)  # manually select an image
        self.pushButton_3.clicked.connect(self.random_image)  # randomly select an image
        self.pushButton_4.clicked.connect(self.load_mask)  # manually select a mask
        self.pushButton_5.clicked.connect(self.random_mask)  # randomly select a mask

        # draw/erasure the mask
        self.radioButton.toggled.connect(lambda: self.draw_mask('line'))  # draw the line
        self.radioButton_2.toggled.connect(lambda: self.draw_mask('rectangle'))  # draw the rectangle
        self.radioButton_3.toggled.connect(lambda: self.draw_mask('center'))  # center mask
        self.spinBox.valueChanged.connect(self.change_thickness)
        self.pushButton.clicked.connect(self.clear_mask)

        # fill image
        self.pushButton_6.clicked.connect(self.fill_image)
        self.comboBox_2.activated.connect(self.show_result)
        self.pushButton_7.clicked.connect(self.save_result)

        opt.preprocess = 'scale_shortside'
        self.transform_o = get_transform(opt, convert=False, augment=False)
        self.pil2tensor = transforms.ToTensor()

    # ------------------------------------------------------------------
    # small internal helpers
    # ------------------------------------------------------------------
    def _require_image(self):
        """Raise if no image has been loaded yet."""
        if self.img_power is None:
            raise NotImplementedError("Please load an image first")

    def _require_result(self):
        """Raise if fill_image() has not produced a result yet."""
        if not self.show_result_flag:
            raise NotImplementedError("Please fill the image first")

    def _center_box(self):
        """Return (left, top, right, bottom) of the centred hole in pixels.

        The box spans the middle half of the displayed image along each
        axis.  ``self.pw`` is the width and ``self.ph`` the height (they are
        unpacked from ``PIL.Image.size`` in load_image()).
        """
        return (int(self.pw / 4), int(self.ph / 4),
                int(3 * self.pw / 4), int(3 * self.ph / 4))

    # ------------------------------------------------------------------
    # display
    # ------------------------------------------------------------------
    def show_logo(self):
        """Show the logo of NTU and BTC"""
        logos = (
            ("./gui/logo/NTU_logo.jpg", 1000, 140),
            ("./gui/logo/BTC_logo.png", 1200, 70),
        )
        for path, x, size in logos:
            label = QtWidgets.QLabel(self)
            label.setGeometry(x, 10, size, 50)
            pixmap = QtGui.QPixmap(path)  # read examples
            pixmap = pixmap.scaled(size, size, QtCore.Qt.KeepAspectRatio,
                                   QtCore.Qt.SmoothTransformation)
            label.setPixmap(pixmap)
            label.show()

    def show_image(self, img):
        """Show ``img`` in a fresh paint panel, with the current mask overlaid.

        Args:
            img: PIL image to display.  It is copied, never modified.

        For ``mask_type == 'center'`` a white box is pasted over the same
        region that set_input() marks as the hole.  For ``'external'`` the
        mask file ``self.mname`` is resized to the displayed image and
        composited on top, so bright mask pixels hide the image.
        """
        show_img = img.copy()
        if self.mask_type == 'center':
            left, top, right, bottom = self._center_box()
            # Width comes from pw and height from ph, so the preview matches
            # the tensor mask on non-square images as well.
            white = Image.new('RGB', (right - left, bottom - top), (255, 255, 255))
            show_img.paste(white, box=(left, top))
        elif self.mask_type == 'external':
            # Open the file once and derive both views from the same pixels.
            resized = Image.open(self.mname).resize(self.img_power.size)
            show_img = Image.composite(resized.convert('RGB'), show_img,
                                       resized.convert('L'))
        self.new_painter(ImageQt.ImageQt(show_img))

    def show_result(self):
        """Show the view selected in comboBox_2.

        Index 0 shows the loaded image, 1 the masked input (removed pixels
        drawn as ones, i.e. white after ``tensor2im``), 2 the completed
        output (the refined one when ``opt.coarse_or_refine`` contains
        ``'refine'``).  Any other index does nothing.

        Raises:
            NotImplementedError: index 0 without a loaded image, or index
                1/2 before fill_image() has run.
        """
        value = self.comboBox_2.currentIndex()
        if value == 0:
            self._require_image()
            self.new_painter(ImageQt.ImageQt(self.img_power))
        elif value == 1:
            self._require_result()
            masked = torch.where(self.mask > 0, self.img_m, torch.ones_like(self.img_m))
            masked_img = Image.fromarray(util.tensor2im(masked.detach()))
            self.new_painter(ImageQt.ImageQt(masked_img))
        elif value == 2:
            self._require_result()
            if 'refine' in self.opt.coarse_or_refine:
                result = self.img_ref_out
            else:
                result = self.img_out
            img_out = Image.fromarray(util.tensor2im(result.detach()))
            self.new_painter(ImageQt.ImageQt(img_out))

    def save_result(self):
        """Save the original, mask, masked input and outputs to ``opt.results_dir``.

        File names are ``<kind>_<data>_<iteration>_<image file name>`` (the
        original image omits the iteration), where ``<data>`` is the last
        component of ``opt.img_file`` up to its first dot and ``<iteration>``
        is ``self.PaintPanel.iteration``.

        Raises:
            NotImplementedError: if fill_image() has not run yet.
        """
        self._require_result()
        util.mkdir(self.opt.results_dir)
        # basename() copes with OS-native separators, unlike split('/').
        img_name = os.path.basename(self.fname)
        data_name = os.path.basename(self.opt.img_file).split('.')[0]
        iteration = self.PaintPanel.iteration

        def save(kind, tensor):
            name = '%s_%s_%d_%s' % (kind, data_name, iteration, img_name)
            util.save_image(util.tensor2im(tensor), os.path.join(self.opt.results_dir, name))

        # save the original image
        original_name = '%s_%s_%s' % ('original', data_name, img_name)
        util.save_image(util.tensor2im(self.img_truth),
                        os.path.join(self.opt.results_dir, original_name))
        # save the mask; it is inverted so removed pixels are the bright ones
        save('mask', 1 - self.mask.repeat(1, 3, 1, 1))
        # save masked image (removed pixels replaced by ones)
        save('masked_img', torch.where(self.mask < 0.2,
                                       torch.ones_like(self.img_truth), self.img_truth))
        # save the generated results
        save('g', self.img_g)
        # save the results
        save('out', self.img_out)
        # save the refined results
        if 'tc' in self.opt.model and 'refine' in self.opt.coarse_or_refine:
            save('ref', self.img_ref_out)

    # ------------------------------------------------------------------
    # model / image / mask selection
    # ------------------------------------------------------------------
    def load_model(self):
        """Create and set up the model selected in comboBox.

        Entry 0 of the combo box is a placeholder; entry ``i`` selects
        ``self.model_names[i - 1]`` and the matching example image folder.

        Raises:
            NotImplementedError: if the placeholder entry is selected.
        """
        value = self.comboBox.currentIndex()
        if value == 0:
            raise NotImplementedError("Please choose a model")
        index = value - 1
        # define the model type and dataset type
        self.opt.name = self.model_names[index]
        self.opt.img_file = self.img_root + self.img_files[index % len(self.img_files)]
        self.model = create_model(self.opt)
        self.model.setup(self.opt)

    def load_image(self, fname):
        """Open ``fname`` as RGB and apply the display transform.

        Stores the untouched image in ``self.img_o`` (size ``ow`` x ``oh``)
        and the transformed one in ``self.img_power`` (size ``pw`` x ``ph``).

        Returns:
            The transformed PIL image.
        """
        self.img_o = Image.open(fname).convert('RGB')
        self.ow, self.oh = self.img_o.size
        self.img_power = self.transform_o(self.img_o)
        self.pw, self.ph = self.img_power.size
        return self.img_power

    def select_image(self):
        """Let the user pick an image file and display it.

        Cancelling the dialog leaves the current image and state untouched.
        """
        fname, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'select the image', self.opt.img_file, '*')
        if not fname:  # dialog cancelled
            return
        img = self.load_image(fname)
        self.fname = fname
        self.mask_type = 'none'
        self.show_image(img)

    def random_image(self):
        """Load and display a random image from ``opt.img_file``.

        Raises:
            ValueError: if the folder contains no images.
        """
        image_paths, image_size = make_dataset(self.opt.img_file)
        if image_size < 1:
            raise ValueError("No images found in %s" % self.opt.img_file)
        item = random.randint(0, image_size - 1)
        self.fname = image_paths[item]
        img = self.load_image(self.fname)
        self.mask_type = 'none'
        self.show_image(img)

    def load_mask(self):
        """Let the user pick a mask file and preview it on the image.

        The mask type only switches to ``'external'`` once a file has
        actually been chosen; cancelling the dialog changes nothing.

        Raises:
            NotImplementedError: if no image has been loaded yet.
        """
        self._require_image()
        mname, _ = QtWidgets.QFileDialog.getOpenFileName(
            self, 'select the mask', self.opt.mask_file, '*')
        if not mname:  # dialog cancelled
            return
        self.mname = mname
        self.mask_type = 'external'
        self.show_image(self.img_power)

    def random_mask(self):
        """Pick a random mask from ``opt.mask_file`` and preview it.

        Raises:
            NotImplementedError: if ``opt.mask_file`` is ``'none'`` or no
                image has been loaded yet.
            ValueError: if the mask folder contains no files.
        """
        if self.opt.mask_file == 'none':
            raise NotImplementedError("Please input the mask path")
        self._require_image()
        mask_paths, mask_size = make_dataset(self.opt.mask_file)
        if mask_size < 1:
            raise ValueError("No masks found in %s" % self.opt.mask_file)
        item = random.randint(0, mask_size - 1)
        self.mname = mask_paths[item]
        self.mask_type = 'external'
        self.show_image(self.img_power)

    def read_mask(self):
        """Render the drawn shapes and return the painter's map as a PIL image.

        The map is serialised to PNG in an in-memory Qt buffer and reopened
        with PIL.
        """
        self.PaintPanel.saveDraw()
        buffer = QtCore.QBuffer()
        buffer.open(QtCore.QBuffer.ReadWrite)
        self.PaintPanel.map.save(buffer, 'PNG')
        png = io.BytesIO(buffer.data())  # BytesIO keeps its own copy of the bytes
        buffer.close()
        pil_im = Image.open(png)
        return pil_im

    # ------------------------------------------------------------------
    # paint panel
    # ------------------------------------------------------------------
    def new_painter(self, image=None):
        """Create a new paint panel for ``image`` and make it the visible page.

        Args:
            image: Qt image to paint on, or None.  When given, the stacked
                widget is resized and repositioned to the image size.
        """
        # painter
        # NOTE(review): every call inserts another page into stackedWidget and
        # the previous panel is never removed - consider cleaning it up.
        self.PaintPanel = painter(self, image)
        self.PaintPanel.close()
        if image is not None:
            w, h = image.size().width(), image.size().height()
            self.stackedWidget.setGeometry(
                QtCore.QRect(250 + int(512 - w / 2), 100 + int(128 - h / 8), w, h))
        self.stackedWidget.insertWidget(0, self.PaintPanel)
        self.stackedWidget.setCurrentWidget(self.PaintPanel)

    def change_thickness(self, num):
        """Change the width of the painter"""
        self.CurrentWidth = num
        panel = getattr(self, 'PaintPanel', None)
        if panel is not None:  # no panel exists before the first image is shown
            panel.CurrentWidth = num

    def draw_mask(self, masktype):
        """Select how the mask is produced.

        Args:
            masktype: ``'center'`` for the fixed centred hole (previewed
                immediately when an image is loaded); any other value
                (``'line'``, ``'rectangle'``) is a drawing shape for the
                paint panel.
        """
        if masktype == 'center':
            self.mask_type = 'center'
            if self.img_power is not None:
                self.show_image(self.img_power)
        else:
            self.mask_type = 'draw'
            self.shape = masktype
            panel = getattr(self, 'PaintPanel', None)
            if panel is not None:  # shape is still remembered on self
                panel.shape = masktype

    def clear_mask(self):
        """Switch to drawn masks and toggle the panel's ``Brush`` flag."""
        self.mask_type = 'draw'
        panel = getattr(self, 'PaintPanel', None)
        if panel is not None:
            panel.Brush = not panel.Brush

    # ------------------------------------------------------------------
    # inference
    # ------------------------------------------------------------------
    def set_input(self):
        """Build the network inputs from the loaded image and current mask.

        Sets ``self.mask`` (1 = keep, 0 = fill), ``self.img_org`` and
        ``self.img_truth`` (original / displayed image scaled to [-1, 1]) and
        ``self.img_m`` (``img_truth`` with removed pixels zeroed).  Tensors
        are moved to ``opt.gpu_ids[0]`` when GPU ids are configured.

        Raises:
            NotImplementedError: if no image has been loaded yet.
        """
        self._require_image()
        img_o = self.pil2tensor(self.img_o).unsqueeze(0)
        img = self.pil2tensor(self.img_power).unsqueeze(0)

        # `hole` is bright (>= 0.5) where pixels should be removed.
        if self.mask_type == 'center':
            left, top, right, bottom = self._center_box()
            hole = torch.zeros_like(img)[:, 0:1, :, :]
            # Tensor layout is (N, C, H, W): rows use the height-based
            # bounds, columns the width-based ones.
            hole[:, :, top:bottom, left:right] = 1
        elif self.mask_type == 'external':
            hole = self.pil2tensor(
                Image.open(self.mname).resize((self.pw, self.ph)).convert('L')).unsqueeze(0)
        else:
            # 'draw', and also 'none' (state right after loading an image):
            # get the test mask from painter
            hole = self.pil2tensor(self.read_mask()).unsqueeze(0)[:, 0:1, :, :]
        # NOTE(review): the inversion is applied to every mask source; this
        # matches the white-means-removed previews in show_image() - confirm
        # against the painter's map colours in gui.ui_draw.
        mask = (hole < 0.5).float()

        if len(self.opt.gpu_ids) > 0:
            img = img.cuda(self.opt.gpu_ids[0])
            mask = mask.cuda(self.opt.gpu_ids[0])
            img_o = img_o.cuda(self.opt.gpu_ids[0])

        self.mask = mask
        self.img_org = img_o * 2 - 1
        self.img_truth = img * 2 - 1
        self.img_m = self.mask * self.img_truth

    def fill_image(self):
        """Run the networks on the masked image and show the result.

        The masked image and mask are resized to ``opt.fixed_size`` and
        passed through ``netE``, ``netT`` and ``netG``; the generated image
        is resized back and blended with the input using ``self.mask``.
        With ``'refine'`` in ``opt.coarse_or_refine`` the blend is further
        passed through ``netG_Ref``.  Results are stored in ``self.img_g``,
        ``self.img_out`` and (when refining) ``self.img_ref_out``.
        """
        self.set_input()
        # NOTE(review): extent of this guard was reconstructed from collapsed
        # source; the flag update and display below are assumed to be
        # unconditional.
        if self.PaintPanel.iteration < 1:
            with torch.no_grad():
                fixed = [self.opt.fixed_size, self.opt.fixed_size]
                fixed_img = F.interpolate(self.img_m, size=fixed, mode='bicubic',
                                          align_corners=True).clamp(-1, 1)
                fixed_mask = (F.interpolate(self.mask, size=fixed, mode='bicubic',
                                            align_corners=True) > 0.9).type_as(fixed_img)
                out, enc_mask = self.model.netE(fixed_img, mask=fixed_mask, return_mask=True)
                out = self.model.netT(out, enc_mask, bool_mask=False)
                self.img_g = self.model.netG(out)
                img_g_org = F.interpolate(self.img_g, size=self.img_truth.size()[2:],
                                          mode='bicubic', align_corners=True).clamp(-1, 1)
                self.img_out = self.mask * self.img_truth + (1 - self.mask) * img_g_org
                if 'refine' in self.opt.coarse_or_refine:
                    img_ref = self.model.netG_Ref(self.img_out, mask=self.mask)
                    self.img_ref_out = self.mask * self.img_truth + (1 - self.mask) * img_ref
            print('finish the completion')
        self.show_result_flag = True
        self.show_result()