LPX commited on
Commit
57b2083
·
1 Parent(s): c65deae

fix(bug): feature extractor added back in

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -61,10 +61,19 @@ def load_models():
61
  clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
62
 
63
  @spaces.GPU(duration=10)
64
- def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id):
65
  try:
66
- prediction = clf(img_pil)
67
- result = {pred['label']: pred['score'] for pred in prediction}
 
 
 
 
 
 
 
 
 
68
  result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
69
  logger.info(result_output)
70
  for class_name in class_names:
@@ -97,8 +106,8 @@ def predict_image(img, confidence_threshold):
97
 
98
  label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
99
  label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
100
- label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3)
101
- label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4)
102
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
103
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
104
 
@@ -113,7 +122,6 @@ def predict_image(img, confidence_threshold):
113
 
114
  combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
115
  return img_pil, combined_outputs
116
-
117
  # Define a function to generate the HTML content
118
  def generate_results_html(results):
119
  def get_header_color(label):
 
61
  clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
62
 
63
  @spaces.GPU(duration=10)
64
+ def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
65
  try:
66
+ if feature_extractor:
67
+ inputs = feature_extractor(img_pil, return_tensors="pt").to(device)
68
+ with torch.no_grad():
69
+ outputs = clf(**inputs)
70
+ logits = outputs.logits
71
+ probabilities = softmax(logits.cpu().numpy()[0])
72
+ result = {class_names[i]: probabilities[i] for i in range(len(class_names))}
73
+ else:
74
+ prediction = clf(img_pil)
75
+ result = {pred['label']: pred['score'] for pred in prediction}
76
+
77
  result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
78
  logger.info(result_output)
79
  for class_name in class_names:
 
106
 
107
  label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
108
  label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
109
+ label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3, feature_extractor_3)
110
+ label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
111
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
112
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
113
 
 
122
 
123
  combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
124
  return img_pil, combined_outputs
 
125
  # Define a function to generate the HTML content
126
  def generate_results_html(results):
127
  def get_header_color(label):