LPX
commited on
Commit
·
57b2083
1
Parent(s):
c65deae
fix(bug): feature extractor added back in
Browse files
app.py
CHANGED
@@ -61,10 +61,19 @@ def load_models():
|
|
61 |
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
|
62 |
|
63 |
@spaces.GPU(duration=10)
|
64 |
-
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id):
|
65 |
try:
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
|
69 |
logger.info(result_output)
|
70 |
for class_name in class_names:
|
@@ -97,8 +106,8 @@ def predict_image(img, confidence_threshold):
|
|
97 |
|
98 |
label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
|
99 |
label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
|
100 |
-
label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3)
|
101 |
-
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4)
|
102 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
103 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
104 |
|
@@ -113,7 +122,6 @@ def predict_image(img, confidence_threshold):
|
|
113 |
|
114 |
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
|
115 |
return img_pil, combined_outputs
|
116 |
-
|
117 |
# Define a function to generate the HTML content
|
118 |
def generate_results_html(results):
|
119 |
def get_header_color(label):
|
|
|
61 |
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b = load_models()
|
62 |
|
63 |
@spaces.GPU(duration=10)
|
64 |
+
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
65 |
try:
|
66 |
+
if feature_extractor:
|
67 |
+
inputs = feature_extractor(img_pil, return_tensors="pt").to(device)
|
68 |
+
with torch.no_grad():
|
69 |
+
outputs = clf(**inputs)
|
70 |
+
logits = outputs.logits
|
71 |
+
probabilities = softmax(logits.cpu().numpy()[0])
|
72 |
+
result = {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
73 |
+
else:
|
74 |
+
prediction = clf(img_pil)
|
75 |
+
result = {pred['label']: pred['score'] for pred in prediction}
|
76 |
+
|
77 |
result_output = [model_id, model_name, result.get(class_names[1], 0.0), result.get(class_names[0], 0.0)]
|
78 |
logger.info(result_output)
|
79 |
for class_name in class_names:
|
|
|
106 |
|
107 |
label_1, result_1output = predict_with_model(img_pil, clf_1, CLASS_NAMES["model_1"], confidence_threshold, "SwinV2-base", 1)
|
108 |
label_2, result_2output = predict_with_model(img_pilvits, clf_2, CLASS_NAMES["model_2"], confidence_threshold, "ViT-base Classifier", 2)
|
109 |
+
label_3, result_3output = predict_with_model(img_pil, model_3, CLASS_NAMES["model_3"], confidence_threshold, "SDXL-Trained", 3, feature_extractor_3)
|
110 |
+
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
111 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
112 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
113 |
|
|
|
122 |
|
123 |
combined_outputs = [result_1output, result_2output, result_3output, result_4output, result_5output, result_5boutput]
|
124 |
return img_pil, combined_outputs
|
|
|
125 |
# Define a function to generate the HTML content
|
126 |
def generate_results_html(results):
|
127 |
def get_header_color(label):
|