LPX
commited on
Commit
·
36487f1
1
Parent(s):
76cdea1
♻️ refactor(pipeline): update model_6 initialization
Browse files- remove redundant model_6 pipeline initialization
- change model_6 as an AutoModelForImageClassification instead of Swinv2ForImageClassification
- update predict_image to align with model_6 changes
️ style(prediction): fix function predict
- modified line 116 confusion code replace clf_6 with model_6 on the prediction function
- fix indentation of the 118 line, reduce 1 line from indent at line 118 on the 5 model with prediction function
app.py
CHANGED
@@ -48,11 +48,6 @@ def load_models():
|
|
48 |
model_1 = model_1.to(device)
|
49 |
clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
|
50 |
|
51 |
-
image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
|
52 |
-
model_6 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_6"])
|
53 |
-
model_6 = model_6.to(device)
|
54 |
-
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
55 |
-
|
56 |
clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
|
57 |
|
58 |
feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
|
@@ -64,9 +59,12 @@ def load_models():
|
|
64 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
65 |
clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
|
66 |
|
67 |
-
|
|
|
|
|
|
|
68 |
|
69 |
-
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b,
|
70 |
|
71 |
@spaces.GPU(duration=10)
|
72 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
@@ -118,7 +116,7 @@ def predict_image(img, confidence_threshold):
|
|
118 |
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
119 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
120 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
121 |
-
label_6, result_6output = predict_with_model(img_pilvits,
|
122 |
|
123 |
combined_results = {
|
124 |
"SwinV2/detect": label_1,
|
|
|
48 |
model_1 = model_1.to(device)
|
49 |
clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
|
50 |
|
|
|
|
|
|
|
|
|
|
|
51 |
clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
|
52 |
|
53 |
feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
|
|
|
59 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
60 |
clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
|
61 |
|
62 |
+
feature_extractor_6 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_6"], device=device)
|
63 |
+
model_6 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
|
64 |
+
|
65 |
+
return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, model_6, feature_extractor_6
|
66 |
|
67 |
+
clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, feature_extractor_6, model_6 = load_models()
|
68 |
|
69 |
@spaces.GPU(duration=10)
|
70 |
def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
|
|
|
116 |
label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
|
117 |
label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
|
118 |
label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
|
119 |
+
label_6, result_6output = predict_with_model(img_pilvits, model_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 6, feature_extractor_6)
|
120 |
|
121 |
combined_results = {
|
122 |
"SwinV2/detect": label_1,
|