LPX55 commited on
Commit
4a03e59
·
verified ·
1 Parent(s): 6ee8df1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -6
app.py CHANGED
@@ -1,10 +1,12 @@
1
  import spaces
2
  import gradio as gr
3
- from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification
4
  from torchvision import transforms
5
  import torch
6
  from PIL import Image
 
7
  import warnings
 
8
 
9
  # Suppress warnings
10
  warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
@@ -20,13 +22,24 @@ clf_1 = pipeline(model=model_1, task="image-classification", image_processor=ima
20
 
21
  # Load the second model
22
  model_2_path = "Heem2/AI-vs-Real-Image-Detection"
23
- clf_2 = pipeline("image-classification", model=model_2_path, device=device)
24
 
25
- # Define class names for both models
 
 
 
 
 
26
  class_names_1 = ['artificial', 'real']
27
- class_names_2 = ['AI Image', 'Real Image'] # Adjust if the second model has different classes
 
 
 
 
 
 
28
 
29
- @spaces.GPU(duration=30)
30
  def predict_image(img, confidence_threshold):
31
  # Ensure the image is a PIL Image
32
  if not isinstance(img, Image.Image):
@@ -81,10 +94,56 @@ def predict_image(img, confidence_threshold):
81
  except Exception as e:
82
  label_2 = f"Error: {str(e)}"
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  # Combine results
85
  combined_results = {
86
  "SwinV2": label_1,
87
- "AI-vs-Real-Image-Detection": label_2
 
 
88
  }
89
 
90
  return combined_results
 
1
  import spaces
2
  import gradio as gr
3
+ from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
4
  from torchvision import transforms
5
  import torch
6
  from PIL import Image
7
+ import pandas as pd
8
  import warnings
9
+ import math
10
 
11
  # Suppress warnings
12
  warnings.filterwarnings("ignore", category=UserWarning, message="Using a slow image processor as `use_fast` is unset")
 
22
 
23
  # Load the second model
24
  model_2_path = "Heem2/AI-vs-Real-Image-Detection"
25
+ clf_2 = pipeline("image-classification", model=model_2_path)
26
 
27
+ # Load additional models
28
+ models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
29
+ pipe0 = pipeline("image-classification", model=models[0])
30
+ pipe1 = pipeline("image-classification", model=models[1])
31
+
32
+ # Define class names for all models
33
  class_names_1 = ['artificial', 'real']
34
+ class_names_2 = ['AI Image', 'Real Image']
35
+ class_names_3 = ['AI', 'Real']
36
+ class_names_4 = ['AI', 'Real']
37
+
38
+ def softmax(vector):
39
+ e = math.exp(vector - vector.max()) # for numerical stability
40
+ return e / e.sum()
41
 
42
+ @spaces.GPU(duration=10)
43
  def predict_image(img, confidence_threshold):
44
  # Ensure the image is a PIL Image
45
  if not isinstance(img, Image.Image):
 
94
  except Exception as e:
95
  label_2 = f"Error: {str(e)}"
96
 
97
+ # Predict using the third model
98
+ try:
99
+ prediction_3 = pipe0(img_pil)
100
+ result_3 = {}
101
+ for idx, result in enumerate(prediction_3):
102
+ result_3[class_names_3[idx]] = float(result['score'])
103
+
104
+ # Ensure the result dictionary contains all class names
105
+ for class_name in class_names_3:
106
+ if class_name not in result_3:
107
+ result_3[class_name] = 0.0
108
+
109
+ # Check if either class meets the confidence threshold
110
+ if result_3['AI'] >= confidence_threshold:
111
+ label_3 = f"Label: AI, Confidence: {result_3['AI']:.4f}"
112
+ elif result_3['Real'] >= confidence_threshold:
113
+ label_3 = f"Label: Real, Confidence: {result_3['Real']:.4f}"
114
+ else:
115
+ label_3 = "Uncertain Classification"
116
+ except Exception as e:
117
+ label_3 = f"Error: {str(e)}"
118
+
119
+ # Predict using the fourth model
120
+ try:
121
+ prediction_4 = pipe1(img_pil)
122
+ result_4 = {}
123
+ for idx, result in enumerate(prediction_4):
124
+ result_4[class_names_4[idx]] = float(result['score'])
125
+
126
+ # Ensure the result dictionary contains all class names
127
+ for class_name in class_names_4:
128
+ if class_name not in result_4:
129
+ result_4[class_name] = 0.0
130
+
131
+ # Check if either class meets the confidence threshold
132
+ if result_4['AI'] >= confidence_threshold:
133
+ label_4 = f"Label: AI, Confidence: {result_4['AI']:.4f}"
134
+ elif result_4['Real'] >= confidence_threshold:
135
+ label_4 = f"Label: Real, Confidence: {result_4['Real']:.4f}"
136
+ else:
137
+ label_4 = "Uncertain Classification"
138
+ except Exception as e:
139
+ label_4 = f"Error: {str(e)}"
140
+
141
  # Combine results
142
  combined_results = {
143
  "SwinV2": label_1,
144
+ "AI-vs-Real-Image-Detection": label_2,
145
+ "Organika/sdxl-detector": label_3,
146
+ "cmckinle/sdxl-flux-detector": label_4
147
  }
148
 
149
  return combined_results