LPX
commited on
Commit
·
b8128c0
1
Parent(s):
d0dfcb4
🛠️ chore(model): update Swin model and refactor model loading
Browse files-improved Swin model import
-replaced redundant variable: model_6 with clf_6
🐛 fix(model): fix prediction for Swin model in `predict_image` method
-removed redundant feature_extractor usage in `predict_with_model`
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
-
from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
@@ -59,12 +59,13 @@ def load_models():
|
|
59 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
60 |
clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
|
61 |
|
62 |
-
|
63 |
-
model_6 =
|
|
|
64 |
|
65 |
-
return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b,
|
66 |
|
67 |
-
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b,
|
68 |
|
69 |
@spaces.GPU(duration=10)
|
70 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
@@ -116,7 +117,7 @@ def predict_image(img, confidence_threshold):
|
|
116 |
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
117 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
118 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
119 |
-
label_6, result_6output = predict_with_model(img_pilvits,
|
120 |
|
121 |
combined_results = {
|
122 |
"SwinV2/detect": label_1,
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
+
from transformers import pipeline, AutoImageProcessor, SwinForImageClassification, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
|
4 |
from torchvision import transforms
|
5 |
import torch
|
6 |
from PIL import Image
|
|
|
59 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
60 |
clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
|
61 |
|
62 |
+
image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
|
63 |
+
model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
64 |
+
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
65 |
|
66 |
+
return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
|
67 |
|
68 |
+
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6 = load_models()
|
69 |
|
70 |
@spaces.GPU(duration=10)
|
71 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
|
|
117 |
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
118 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
119 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
120 |
+
label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
|
121 |
|
122 |
combined_results = {
|
123 |
"SwinV2/detect": label_1,
|