Update app.py
Browse files
app.py
CHANGED
@@ -27,19 +27,24 @@ clf_2 = pipeline("image-classification", model=model_2_path)
|
|
27 |
|
28 |
# Load additional models
|
29 |
models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
# Define class names for all models
|
34 |
class_names_1 = ['artificial', 'real']
|
35 |
class_names_2 = ['AI Image', 'Real Image']
|
36 |
-
|
37 |
-
|
38 |
|
39 |
def softmax(vector):
|
40 |
e = np.exp(vector - np.max(vector)) # for numerical stability
|
41 |
return e / e.sum()
|
42 |
-
|
43 |
@spaces.GPU(duration=10)
|
44 |
def predict_image(img, confidence_threshold):
|
45 |
# Ensure the image is a PIL Image
|
@@ -97,14 +102,19 @@ def predict_image(img, confidence_threshold):
|
|
97 |
|
98 |
# Predict using the third model with softmax
|
99 |
try:
|
|
|
100 |
with torch.no_grad():
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# Ensure the result dictionary contains all class names
|
107 |
-
for class_name in
|
108 |
if class_name not in result_3:
|
109 |
result_3[class_name] = 0.0
|
110 |
|
@@ -120,14 +130,19 @@ def predict_image(img, confidence_threshold):
|
|
120 |
|
121 |
# Predict using the fourth model with softmax
|
122 |
try:
|
|
|
123 |
with torch.no_grad():
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Ensure the result dictionary contains all class names
|
130 |
-
for class_name in
|
131 |
if class_name not in result_4:
|
132 |
result_4[class_name] = 0.0
|
133 |
|
@@ -163,4 +178,4 @@ iface = gr.Interface(
|
|
163 |
outputs=label,
|
164 |
title="AI Generated Classification"
|
165 |
)
|
166 |
-
iface.launch()
|
|
|
27 |
|
28 |
# Load additional models
|
29 |
models = ["Organika/sdxl-detector", "cmckinle/sdxl-flux-detector"]
|
30 |
+
|
31 |
+
# Load the third and fourth models
|
32 |
+
feature_extractor_3 = AutoFeatureExtractor.from_pretrained(models[0])
|
33 |
+
model_3 = AutoModelForImageClassification.from_pretrained(models[0]).to(device)
|
34 |
+
|
35 |
+
feature_extractor_4 = AutoFeatureExtractor.from_pretrained(models[1])
|
36 |
+
model_4 = AutoModelForImageClassification.from_pretrained(models[1]).to(device)
|
37 |
|
38 |
# Define class names for all models
|
39 |
class_names_1 = ['artificial', 'real']
|
40 |
class_names_2 = ['AI Image', 'Real Image']
|
41 |
+
labels_3 = ['AI', 'Real']
|
42 |
+
labels_4 = ['AI', 'Real']
|
43 |
|
44 |
def softmax(vector):
|
45 |
e = np.exp(vector - np.max(vector)) # for numerical stability
|
46 |
return e / e.sum()
|
47 |
+
|
48 |
@spaces.GPU(duration=10)
|
49 |
def predict_image(img, confidence_threshold):
|
50 |
# Ensure the image is a PIL Image
|
|
|
102 |
|
103 |
# Predict using the third model with softmax
|
104 |
try:
|
105 |
+
inputs_3 = feature_extractor_3(img_pil, return_tensors="pt").to(device)
|
106 |
with torch.no_grad():
|
107 |
+
outputs_3 = model_3(**inputs_3)
|
108 |
+
logits_3 = outputs_3.logits
|
109 |
+
probabilities_3 = softmax(logits_3.cpu().numpy()[0])
|
110 |
+
|
111 |
+
result_3 = {
|
112 |
+
labels_3[0]: float(probabilities_3[0]), # AI
|
113 |
+
labels_3[1]: float(probabilities_3[1]) # Real
|
114 |
+
}
|
115 |
|
116 |
# Ensure the result dictionary contains all class names
|
117 |
+
for class_name in labels_3:
|
118 |
if class_name not in result_3:
|
119 |
result_3[class_name] = 0.0
|
120 |
|
|
|
130 |
|
131 |
# Predict using the fourth model with softmax
|
132 |
try:
|
133 |
+
inputs_4 = feature_extractor_4(img_pil, return_tensors="pt").to(device)
|
134 |
with torch.no_grad():
|
135 |
+
outputs_4 = model_4(**inputs_4)
|
136 |
+
logits_4 = outputs_4.logits
|
137 |
+
probabilities_4 = softmax(logits_4.cpu().numpy()[0])
|
138 |
+
|
139 |
+
result_4 = {
|
140 |
+
labels_4[0]: float(probabilities_4[0]), # AI
|
141 |
+
labels_4[1]: float(probabilities_4[1]) # Real
|
142 |
+
}
|
143 |
|
144 |
# Ensure the result dictionary contains all class names
|
145 |
+
for class_name in labels_4:
|
146 |
if class_name not in result_4:
|
147 |
result_4[class_name] = 0.0
|
148 |
|
|
|
178 |
outputs=label,
|
179 |
title="AI Generated Classification"
|
180 |
)
|
181 |
+
iface.launch()
|