antonovmaxim commited on
Commit
c6876c4
·
1 Parent(s): 3a9a9e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -1
app.py CHANGED
@@ -1,6 +1,21 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  def classify(input_img):
3
- return 1
4
  def img_classify(input_img):
5
  s = "Вероятность того, что изображение сгенерировано нейросетью равна: " + str(classify(input_img))
6
  return s
 
1
  import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ model = torch.hub.load("pytorch/vision", "resnet101", pretrained=False)
5
+ model.fc = nn.Sequential(nn.Linear(2048, 500), nn.ReLU(), nn.Linear(500, 2), nn.Softmax(1))
6
+ state_dict = torch.load('model.pth')
7
+ model.load_state_dict(state_dict)
8
+
9
+ transform = transforms.Compose([
10
+ transforms.RandomHorizontalFlip(p=0.5),
11
+ transforms.Resize(256), # Resize the image to 256x256 pixels
12
+ transforms.CenterCrop(224), # Crop the center 224x224 pixels
13
+ transforms.ToTensor(), # Convert the image to a PyTorch tensor
14
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the image
15
+ ])
16
+
17
  def classify(input_img):
18
+ return model(transform(input_img))[0][1].item()
19
  def img_classify(input_img):
20
  s = "Вероятность того, что изображение сгенерировано нейросетью равна: " + str(classify(input_img))
21
  return s