LPX
committed on
Commit
·
a7d5234
1
Parent(s):
c08bf6c
feat(model):
Browse files
- add support for a new model (model_7) for the image classification task
- extend CLASS_NAMES to support the new model's class labels
- add resource loading for the new model
- access the new model in the predict_image function
- update combined_results and combined_outputs to include the new model's output
- add a new tile in the HTML results for the new model
♻️ style(frontend):
- increase w-24 to w-30 in HTML CSS snippet
Note: I prefer to use "feat" when adding a new model, even if the change only adds a URL, and "chore" for any small tutorials added.
app.py
CHANGED
@@ -27,7 +27,8 @@ MODEL_PATHS = {
|
|
27 |
"model_4": "cmckinle/sdxl-flux-detector",
|
28 |
"model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
|
29 |
"model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
|
30 |
-
"model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL"
|
|
|
31 |
}
|
32 |
|
33 |
CLASS_NAMES = {
|
@@ -38,6 +39,7 @@ CLASS_NAMES = {
|
|
38 |
"model_5": ['Realism', 'Deepfake'],
|
39 |
"model_5b": ['Real', 'Deepfake'],
|
40 |
"model_6": ['ai_gen', 'human'],
|
|
|
41 |
|
42 |
}
|
43 |
|
@@ -63,9 +65,13 @@ def load_models():
|
|
63 |
model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
64 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
65 |
|
66 |
-
|
|
|
|
|
67 |
|
68 |
-
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
|
|
|
|
|
69 |
|
70 |
@spaces.GPU(duration=10)
|
71 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
@@ -118,6 +124,7 @@ def predict_image(img, confidence_threshold):
|
|
118 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
119 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
120 |
label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
|
|
|
121 |
|
122 |
combined_results = {
|
123 |
"SwinV2/detect": label_1,
|
@@ -126,11 +133,12 @@ def predict_image(img, confidence_threshold):
|
|
126 |
"Swin/SDXL-FLUX": label_4,
|
127 |
"prithivMLmods": label_5,
|
128 |
"prithivMLmods-2-22": label_5b,
|
129 |
-
"SwinMidSDXL": label_6
|
|
|
130 |
}
|
131 |
print(combined_results)
|
132 |
|
133 |
-
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output]
|
134 |
return img_pil, combined_outputs
|
135 |
# Define a function to generate the HTML content
|
136 |
|
@@ -159,7 +167,7 @@ def generate_results_html(results):
|
|
159 |
class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
|
160 |
<span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
|
161 |
<span
|
162 |
-
class="flex w-
|
163 |
>
|
164 |
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
|
165 |
{'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
|
@@ -207,7 +215,8 @@ def generate_results_html(results):
|
|
207 |
{generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
|
208 |
{generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
|
209 |
{generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
|
210 |
-
{generate_tile_html(6, results[6], "Swin,
|
|
|
211 |
</div>
|
212 |
</div>
|
213 |
"""
|
|
|
27 |
"model_4": "cmckinle/sdxl-flux-detector",
|
28 |
"model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
|
29 |
"model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
|
30 |
+
"model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL",
|
31 |
+
"model_7": "date3k2/vit-real-fake-classification-v4"
|
32 |
}
|
33 |
|
34 |
CLASS_NAMES = {
|
|
|
39 |
"model_5": ['Realism', 'Deepfake'],
|
40 |
"model_5b": ['Real', 'Deepfake'],
|
41 |
"model_6": ['ai_gen', 'human'],
|
42 |
+
"model_7": ['Fake', 'Real'],
|
43 |
|
44 |
}
|
45 |
|
|
|
65 |
model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
66 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
67 |
|
68 |
+
image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
|
69 |
+
model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
|
70 |
+
clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
|
71 |
|
72 |
+
return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7
|
73 |
+
|
74 |
+
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7 = load_models()
|
75 |
|
76 |
@spaces.GPU(duration=10)
|
77 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
|
|
124 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
125 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
126 |
label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
|
127 |
+
label_7, result_7output = predict_with_model(img_pilvits, clf_7, CLASS_NAMES["model_7"], confidence_threshold, "Vit", 7)
|
128 |
|
129 |
combined_results = {
|
130 |
"SwinV2/detect": label_1,
|
|
|
133 |
"Swin/SDXL-FLUX": label_4,
|
134 |
"prithivMLmods": label_5,
|
135 |
"prithivMLmods-2-22": label_5b,
|
136 |
+
"SwinMidSDXL": label_6,
|
137 |
+
"Vit": label_7
|
138 |
}
|
139 |
print(combined_results)
|
140 |
|
141 |
+
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output, result_7output]
|
142 |
return img_pil, combined_outputs
|
143 |
# Define a function to generate the HTML content
|
144 |
|
|
|
167 |
class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
|
168 |
<span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
|
169 |
<span
|
170 |
+
class="flex w-30 mx-auto tracking-wide items-center justify-center rounded-full {header_colors[2]} px-1 py-0.5 {header_colors[3]}"
|
171 |
>
|
172 |
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
|
173 |
{'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
|
|
|
215 |
{generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
|
216 |
{generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
|
217 |
{generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
|
218 |
+
{generate_tile_html(6, results[6], "Swin, Midj + SDXL", "ideepankarsharma2003", MODEL_PATHS["model_6"])}
|
219 |
+
{generate_tile_html(7, results[7], "ViT", "temp", MODEL_PATHS["model_7"])}
|
220 |
</div>
|
221 |
</div>
|
222 |
"""
|