Spaces:
Runtime error
Runtime error
import cv2 | |
from utils import read_image,get_valid_augs | |
import torch | |
import torch.nn.functional as F | |
from model import CustomModel | |
# Checkpoint file handed to torch.load() in predict_one_image.
CKPT = 'fold_0'
# Human-readable class names. The comma matters: without it Python joins the
# adjacent literals into the single string 'Not AIAI Generated'.
# NOTE(review): presumably index 0 = negative class, index 1 = positive class
# -- confirm against the training labels.
Targets = ['Not AI', 'AI Generated']
def predict_one_image(path):
    """Score a single image as AI-generated or not.

    Reads the image at ``path`` with ``read_image``, applies the transform
    returned by ``get_valid_augs``, runs it through a ``CustomModel``
    restored from the ``CKPT`` checkpoint on CPU, and converts the model's
    ``'label'`` output into a probability with a sigmoid.

    Args:
        path: Image location, passed unchanged to ``read_image``.

    Returns:
        dict: ``{'Not AI': 1 - p, 'AI': p}`` where ``p`` is the sigmoid of
        the ``'label'`` output as a Python float.
    """
    image = read_image(path)
    image = get_valid_augs()(image=image)['image']
    image = torch.tensor(image, dtype=torch.float)
    # NOTE(review): reshape only reinterprets the element order, it does not
    # transpose. This assumes the transform already yields a channel-first
    # 3x512x512 image -- TODO confirm against get_valid_augs.
    image = image.reshape((1, 3, 512, 512))

    model = CustomModel()
    # Restore the checkpoint weights onto CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. The model is also rebuilt on every call -- consider
    # caching it if this function is invoked repeatedly.
    model.load_state_dict(torch.load(CKPT, map_location=torch.device('cpu')))
    # Inference mode: without this, dropout / batch-norm layers (if the
    # model has any) would keep their training-time behavior.
    model.eval()

    with torch.no_grad():
        outputs = model(image)
        # torch.sigmoid replaces the deprecated F.sigmoid; .item() pulls the
        # single probability out as a Python float and raises if the output
        # holds more than one element.
        proba = torch.sigmoid(outputs['label']).item()

    return {'Not AI': 1 - proba, 'AI': proba}