# ai_image_gen_detector / predict.py
# Author: ReacherTN; commit d383e61 ("Update predict.py"); 701 bytes.
import functools

import cv2
import torch
import torch.nn.functional as F

from model import CustomModel
from utils import read_image,get_valid_augs
# Path of the checkpoint file passed to ``torch.load`` by the predictor.
CKPT = 'fold_0'
# Human-readable class names: index 0 is the "Not AI" class, index 1 the
# "AI Generated" class. The comma matters: without it Python concatenates
# adjacent string literals into a single element ('Not AIAI Generated').
Targets = ['Not AI', 'AI Generated']
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build ``CustomModel``, load the ``CKPT`` weights on CPU, and set eval mode.

    Cached so the checkpoint is read from disk once per process rather than
    on every prediction.
    """
    model = CustomModel()
    # NOTE(review): torch.load unpickles the file; only point CKPT at trusted
    # checkpoints. Consider weights_only=True once the minimum supported torch
    # version accepts that argument.
    model.load_state_dict(torch.load(CKPT, map_location=torch.device('cpu')))
    # Switch to inference mode. Without this the module stays in training mode,
    # which changes the behaviour of layers such as dropout / batch-norm (if the
    # model contains any) even inside torch.no_grad().
    model.eval()
    return model


def predict_one_image(path):
    """Run the detector on the image at *path*.

    Args:
        path: image location, passed unchanged to ``read_image``.

    Returns:
        dict ``{'Not AI': 1 - p, 'AI': p}`` where ``p`` is the sigmoid of the
        model's ``'label'`` output for this image.
    """
    image = read_image(path)
    image = get_valid_augs()(image=image)['image']
    # as_tensor accepts both numpy arrays and tensors; torch.tensor() warns
    # when handed an existing tensor.
    image = torch.as_tensor(image, dtype=torch.float)
    # NOTE(review): assumes the augmentation output is already channel-first
    # (3, 512, 512). If it were HWC, reshape would scramble pixels and a
    # permute would be required instead — confirm against get_valid_augs.
    image = image.reshape((1, 3, 512, 512))
    model = _load_model()
    with torch.no_grad():
        outputs = model(image)
    # First (only) item of the batch; float() raises if it is not a single value.
    proba = float(torch.sigmoid(outputs['label'])[0])
    return {'Not AI': 1 - proba, 'AI': proba}