LPX commited on
Commit
a7d5234
·
1 Parent(s): c08bf6c

feat(model):

Browse files

- add support for new model (model_7) for image classification task

- extended CLASS_NAMES to support new model class labels

- added resource loading for the new model

- accessing the new model in the predict_image function

- updated the combined_results and combined_outputs to include new model output

- added a new tile in the HTML results for the new model

♻️ style(frontend):
- increase w-24 to w-30 in HTML CSS snippet

Note: Preferred writing "feat" for adding new model, even if it is just adding an URLs, and "chore" for any small tutorials added.

Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -27,7 +27,8 @@ MODEL_PATHS = {
27
  "model_4": "cmckinle/sdxl-flux-detector",
28
  "model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
29
  "model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
30
- "model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL"
 
31
  }
32
 
33
  CLASS_NAMES = {
@@ -38,6 +39,7 @@ CLASS_NAMES = {
38
  "model_5": ['Realism', 'Deepfake'],
39
  "model_5b": ['Real', 'Deepfake'],
40
  "model_6": ['ai_gen', 'human'],
 
41
 
42
  }
43
 
@@ -63,9 +65,13 @@ def load_models():
63
  model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
64
  clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
65
 
66
- return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
 
 
67
 
68
- clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6 = load_models()
 
 
69
 
70
  @spaces.GPU(duration=10)
71
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
@@ -118,6 +124,7 @@ def predict_image(img, confidence_threshold):
118
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
119
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
120
  label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
 
121
 
122
  combined_results = {
123
  "SwinV2/detect": label_1,
@@ -126,11 +133,12 @@ def predict_image(img, confidence_threshold):
126
  "Swin/SDXL-FLUX": label_4,
127
  "prithivMLmods": label_5,
128
  "prithivMLmods-2-22": label_5b,
129
- "SwinMidSDXL": label_6
 
130
  }
131
  print(combined_results)
132
 
133
- combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output]
134
  return img_pil, combined_outputs
135
  # Define a function to generate the HTML content
136
 
@@ -159,7 +167,7 @@ def generate_results_html(results):
159
  class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
160
  <span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
161
  <span
162
- class="flex w-24 mx-auto tracking-wide items-center justify-center rounded-full {header_colors[2]} px-1 py-0.5 {header_colors[3]}"
163
  >
164
  <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
165
  {'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
@@ -207,7 +215,8 @@ def generate_results_html(results):
207
  {generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
208
  {generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
209
  {generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
210
- {generate_tile_html(6, results[6], "Swin, Midjourney+SDXL", "ideepankarsharma2003", MODEL_PATHS["model_6"])}
 
211
  </div>
212
  </div>
213
  """
 
27
  "model_4": "cmckinle/sdxl-flux-detector",
28
  "model_5": "prithivMLmods/Deep-Fake-Detector-v2-Model",
29
  "model_5b": "prithivMLmods/Deepfake-Detection-Exp-02-22",
30
+ "model_6": "ideepankarsharma2003/AI_ImageClassification_MidjourneyV6_SDXL",
31
+ "model_7": "date3k2/vit-real-fake-classification-v4"
32
  }
33
 
34
  CLASS_NAMES = {
 
39
  "model_5": ['Realism', 'Deepfake'],
40
  "model_5b": ['Real', 'Deepfake'],
41
  "model_6": ['ai_gen', 'human'],
42
+ "model_7": ['Fake', 'Real'],
43
 
44
  }
45
 
 
65
  model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
66
  clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
67
 
68
+ image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
69
+ model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]).to(device)
70
+ clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
71
 
72
+ return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7
73
+
74
+ clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6, model_7, clf_7 = load_models()
75
 
76
  @spaces.GPU(duration=10)
77
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
 
124
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
125
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
126
  label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
127
+ label_7, result_7output = predict_with_model(img_pilvits, clf_7, CLASS_NAMES["model_7"], confidence_threshold, "Vit", 7)
128
 
129
  combined_results = {
130
  "SwinV2/detect": label_1,
 
133
  "Swin/SDXL-FLUX": label_4,
134
  "prithivMLmods": label_5,
135
  "prithivMLmods-2-22": label_5b,
136
+ "SwinMidSDXL": label_6,
137
+ "Vit": label_7
138
  }
139
  print(combined_results)
140
 
141
+ combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput, result_6output, result_7output]
142
  return img_pil, combined_outputs
143
  # Define a function to generate the HTML content
144
 
 
167
  class="-m-4 h-24 {header_colors[0]} rounded-sm rounded-b-none transition border group-hover:border-gray-100 group-hover:shadow-lg group-hover:{header_colors[4]}">
168
  <span class="text-gray-300 font-mono tracking-widest p-4 pb-3 block text-xs text-center">MODEL {index + 1}:</span>
169
  <span
170
+ class="flex w-30 mx-auto tracking-wide items-center justify-center rounded-full {header_colors[2]} px-1 py-0.5 {header_colors[3]}"
171
  >
172
  <svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="3" stroke="currentColor" class="w-4 h-4 mr-2 -ml-3 group-hover:animate group-hover:animate-pulse">
173
  {'<path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75 11.25 15 15 9.75M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />' if label == 'REAL' else '<path stroke-linecap="round" stroke-linejoin="round" d="m9.75 9.75 4.5 4.5m0-4.5-4.5 4.5M21 12a9 9 0 1 1-18 0 9 9 0 0 1 18 0Z" />'}
 
215
  {generate_tile_html(3, results[3], "SDXL + FLUX", "cmckinle", MODEL_PATHS["model_4"])}
216
  {generate_tile_html(4, results[4], "Vit Based", "prithivMLmods", MODEL_PATHS["model_5"])}
217
  {generate_tile_html(5, results[5], "Vit Based, Newer Dataset", "prithivMLmods", MODEL_PATHS["model_5b"])}
218
+ {generate_tile_html(6, results[6], "Swin, Midj + SDXL", "ideepankarsharma2003", MODEL_PATHS["model_6"])}
219
+ {generate_tile_html(7, results[7], "ViT", "temp", MODEL_PATHS["model_7"])}
220
  </div>
221
  </div>
222
  """