LPX commited on
Commit
b8128c0
·
1 Parent(s): d0dfcb4

🛠️ chore(model): update Swin model and refactor model loading

Browse files

-improved Swin model import
-replaced redundant variable: model_6 with clf_6

🐛 fix(model): fix prediction for Swin model in `predict_image` method

-removed redundant feature_extractor usage in `predict_with_model`

Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import spaces
2
  import gradio as gr
3
- from transformers import pipeline, AutoImageProcessor, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
4
  from torchvision import transforms
5
  import torch
6
  from PIL import Image
@@ -59,12 +59,13 @@ def load_models():
59
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
60
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
61
 
62
- feature_extractor_6 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_6"], device=device)
63
- model_6 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
 
64
 
65
- return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, model_6, feature_extractor_6
66
 
67
- clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, feature_extractor_6, model_6 = load_models()
68
 
69
  @spaces.GPU(duration=10)
70
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
@@ -116,7 +117,7 @@ def predict_image(img, confidence_threshold):
116
  label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
117
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
118
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
119
- label_6, result_6output = predict_with_model(img_pilvits, model_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7, feature_extractor_6)
120
 
121
  combined_results = {
122
  "SwinV2/detect": label_1,
 
1
  import spaces
2
  import gradio as gr
3
+ from transformers import pipeline, AutoImageProcessor, SwinForImageClassification, Swinv2ForImageClassification, AutoFeatureExtractor, AutoModelForImageClassification
4
  from torchvision import transforms
5
  import torch
6
  from PIL import Image
 
59
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
60
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
61
 
62
+ image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
63
+ model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
64
+ clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
65
 
66
+ return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
67
 
68
+ clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6 = load_models()
69
 
70
  @spaces.GPU(duration=10)
71
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
 
117
  label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
118
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
119
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
120
+ label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 7)
121
 
122
  combined_results = {
123
  "SwinV2/detect": label_1,