"""Single-image inference script for an "AI generated vs. not" classifier.

Loads the checkpoint named by ``CKPT`` into ``CustomModel`` and scores one
image at a time on the CPU.
"""
import functools

import cv2
from utils import read_image, get_valid_augs
import torch
import torch.nn.functional as F
from model import CustomModel

# Path of the weights file consumed by ``torch.load``.
CKPT = 'fold_0'
# Human-readable class names. A comma was previously missing here, so Python's
# implicit string concatenation collapsed the list into the single element
# 'Not AIAI Generated'.
Targets = ['Not AI', 'AI Generated']


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build ``CustomModel``, load ``CKPT`` onto the CPU and switch to eval mode.

    Cached so the checkpoint is read from disk once per process instead of on
    every prediction.
    """
    model = CustomModel()
    # NOTE(review): torch.load unpickles the file; CKPT is assumed to be a
    # local, trusted checkpoint — never point it at untrusted input.
    state_dict = torch.load(CKPT, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # CustomModel exposes load_state_dict, so it is presumably a torch
    # nn.Module; eval() makes layers such as dropout / batch-norm (if any)
    # use inference behaviour instead of training behaviour.
    model.eval()
    return model


def predict_one_image(path):
    """Score a single image and return per-class probabilities.

    Args:
        path: Image location, passed unchanged to ``utils.read_image``.

    Returns:
        ``{'Not AI': 1 - p, 'AI': p}`` where ``p`` is the sigmoid of the
        first element of the model's ``'label'`` output.
    """
    image = read_image(path)
    image = get_valid_augs()(image=image)['image']
    image = torch.tensor(image, dtype=torch.float)
    # NOTE(review): reshape assumes the augmentation already yields
    # channel-first 3x512x512 data; if it returns HWC data a permute is needed
    # rather than a reshape — confirm against utils.get_valid_augs.
    image = image.reshape((1, 3, 512, 512))

    model = _load_model()
    with torch.no_grad():
        outputs = model(image)
        # Take the first (only) sample of the batch; .item() converts the
        # single remaining value to a Python float and raises if more than
        # one value is left, matching the old float(...) conversion.
        proba = torch.sigmoid(outputs['label'])[0].item()
    return {'Not AI': 1 - proba, 'AI': proba}