Spaces:
Runtime error
Runtime error
File size: 701 Bytes
3a9d31e 8693834 3a9d31e d383e61 3a9d31e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import functools

import cv2
import torch
import torch.nn.functional as F

from model import CustomModel
from utils import read_image, get_valid_augs
# Checkpoint file holding the model's state_dict (loaded with torch.load and
# passed to load_state_dict); relative path, resolved against the working
# directory.
CKPT = 'fold_0'
# Human-readable class names. The comma matters: without it Python would
# concatenate the two literals into a single string.
# NOTE(review): order presumably matches the model's 0/1 label — confirm
# against the training code.
Targets = ['Not AI', 'AI Generated']
@functools.lru_cache(maxsize=1)
def _load_model(ckpt_path):
    """Build ``CustomModel``, load its weights from *ckpt_path* onto the CPU, and return it in eval mode.

    Cached so the checkpoint is read from disk once per process rather than
    on every prediction.
    """
    model = CustomModel()
    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference must run in eval mode; otherwise layers such as Dropout or
    # BatchNorm (if CustomModel contains any) keep their training behaviour
    # and predictions become wrong / non-deterministic.
    model.eval()
    return model


def predict_one_image(path):
    """Classify one image as AI-generated or not.

    Parameters
    ----------
    path :
        Image location, in whatever form ``utils.read_image`` accepts.

    Returns
    -------
    dict
        ``{'Not AI': 1 - p, 'AI': p}``, where ``p`` is the sigmoid of the
        model's ``'label'`` output for the (single) image.
    """
    image = read_image(path)
    image = get_valid_augs()(image=image)['image']
    image = torch.tensor(image, dtype=torch.float)
    # NOTE(review): reshape assumes the augmented image already holds
    # 3*512*512 values in channel-first order; an HWC array would be
    # scrambled rather than transposed. Confirm against get_valid_augs.
    image = image.reshape((1, 3, 512, 512))

    model = _load_model(CKPT)
    with torch.no_grad():
        outputs = model(image)
    # First (only) batch element as a Python float; .item() avoids numpy's
    # deprecated conversion of a 1-element array to a scalar.
    proba = torch.sigmoid(outputs['label'])[0].item()
    return {'Not AI': 1 - proba, 'AI': proba}
|