PromptIQA / PromptIQA /run_promptIQA.py
Zevin2023's picture
refine
7cd8f31
raw
history blame
2.63 kB
import os
import random
import torchvision
import cv2
import torch
from PromptIQA.models import promptiqa
import numpy as np
from PromptIQA.utils.dataset.process import ToTensor, Normalize
from PromptIQA.utils.toolkit import *
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append(os.path.dirname(__file__))
def load_model(pkl_path):
    """Build a PromptIQA network and load its weights from a checkpoint.

    The checkpoint is read with ``torch.load`` onto the CPU and must contain a
    ``'state_dict'`` entry mapping parameter names to tensors.

    Args:
        pkl_path: Path of the checkpoint file to read.

    Returns:
        The ``promptiqa.PromptIQA`` instance with the checkpoint weights
        loaded. It is left on the CPU; the caller moves it to another device.

    Raises:
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the weights do not match the model (raised by
            ``load_state_dict``).
    """
    # Leading key component that is removed before loading. Presumably it is
    # the prefix added when the network was saved while wrapped in
    # DataParallel / DistributedDataParallel -- confirm against the training
    # script.
    wrapper_prefix = 'module.'

    model = promptiqa.PromptIQA()

    # NOTE(security): torch.load unpickles the file, so only trusted
    # checkpoints should be passed here.
    checkpoint = torch.load(pkl_path, map_location='cpu')

    dict_pkl = {}
    for key, value in checkpoint['state_dict'].items():
        # Strip the prefix only when it is really there; slicing a fixed number
        # of characters off every key would mangle the names of a checkpoint
        # that was saved without the wrapper.
        if key.startswith(wrapper_prefix):
            key = key[len(wrapper_prefix):]
        dict_pkl[key] = value

    model.load_state_dict(dict_pkl)
    print('Load Model From ', pkl_path)
    return model
class PromptIQA():
    """Inference wrapper that scores one image given example image/score pairs.

    The wrapper loads the network once, keeps it in eval mode on a single
    device, and exposes :meth:`run` to feed a set of prompt examples followed
    by the image to be scored.
    """

    def __init__(self, pkl_path="./PromptIQA/checkpoints/best_model.pth.tar", device=None) -> None:
        """Load the checkpoint and prepare the preprocessing pipeline.

        Args:
            pkl_path: Checkpoint file handed to ``load_model``. The default is
                relative to the current working directory.
            device: Device the model and inputs are moved to. When ``None``,
                CUDA is used if available, otherwise the CPU.
                NOTE(review): whether the underlying network itself runs on
                the CPU depends on its implementation -- confirm.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.pkl_path = pkl_path
        self.model = load_model(self.pkl_path).to(self.device)
        self.model.eval()
        # Normalize runs before ToTensor, i.e. on the array produced by
        # _load_image (project-local transforms; exact semantics live there).
        self.transform = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    @staticmethod
    def _load_image(img_path, size=224):
        """Read or accept an image and return it as a float32 CHW array.

        Args:
            img_path: Either a filesystem path (``str`` or ``os.PathLike``)
                read with ``cv2.imread``, or an already decoded image array.
                An array goes through the same BGR->RGB conversion as a file,
                so it is expected to be in OpenCV's BGR channel order.
            size: Side length of the square the image is resized to.

        Returns:
            ``numpy.ndarray`` of dtype float32, pixel values divided by 255,
            with axes reordered from (H, W, C) to (C, H, W).

        Raises:
            ValueError: If a path is given and OpenCV cannot read it.
        """
        if isinstance(img_path, (str, os.PathLike)):
            path = os.fspath(img_path)
            d_img = cv2.imread(path, cv2.IMREAD_COLOR)
            # cv2.imread signals failure by returning None instead of raising.
            if d_img is None:
                raise ValueError('Could not read image: {!r}'.format(path))
        else:
            d_img = img_path
        d_img = cv2.resize(d_img, (size, size), interpolation=cv2.INTER_CUBIC)
        d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
        d_img = np.array(d_img).astype('float32') / 255
        d_img = np.transpose(d_img, (2, 0, 1))
        return d_img

    def get_an_img_score(self, img_path, target=0):
        """Preprocess one image together with its score.

        Args:
            img_path: Image path or decoded image array (see ``_load_image``).
            target: Score stored under ``'gt'`` next to the image.

        Returns:
            The result of ``self.transform`` applied to
            ``{'img': <CHW float32 array>, 'gt': target}``; :meth:`run` reads
            its ``'img'`` and ``'gt'`` entries as tensors.
        """
        sample = self._load_image(img_path)
        samples = {'img': sample, 'gt': target}
        return self.transform(samples)

    def run(self, ISPP_I, ISPP_S, image):
        """Score ``image`` using the given prompt images and scores.

        Args:
            ISPP_I: Iterable of prompt images (paths or arrays).
            ISPP_S: Iterable of scores, paired position-wise with ``ISPP_I``
                via ``zip`` (extra items of the longer one are ignored).
            image: The image to score (path or array).

        Returns:
            The model's prediction as a Python float rounded to 4 decimals.

        Raises:
            ValueError: If no prompt image/score pair is supplied.
        """
        prompt_imgs, prompt_scores = [], []
        for isp_i, isp_s in zip(ISPP_I, ISPP_S):
            samples = self.get_an_img_score(isp_i, np.array(isp_s))
            prompt_imgs.append(samples['img'].unsqueeze(0))
            prompt_scores.append(samples['gt'].type(torch.FloatTensor).unsqueeze(0))

        if not prompt_imgs:
            raise ValueError('At least one prompt image/score pair is required.')

        # Concatenate once and keep the leading batch dimension: squeezing it
        # would turn a single-prompt batch into an unbatched tensor.
        img = torch.cat(prompt_imgs, dim=0).to(self.device)
        label = torch.cat(prompt_scores, dim=0).to(self.device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            self.model.forward_prompt(img, label.reshape(-1, 1), 'example')

            samples = self.get_an_img_score(image)
            img = samples['img'].unsqueeze(0).to(self.device)
            pred = self.model.inference(img, 'example')

        return round(pred.item(), 4)