LPX commited on
Commit
36487f1
·
1 Parent(s): 76cdea1

♻️ refactor(pipeline): update model_6 initialization

Browse files

- remove redundant model_6 pipeline initialization
- change model_6 as an AutoModelForImageClassification instead of Swinv2ForImageClassification
- update predict_image to align with model_6 changes

️ style(prediction): fix function predict

- modified line 116 confusion code replace clf_6 with model_6 on the prediction function
- fix indentation of the 118 line, reduce 1 line from indent at line 118 on the 5 model with prediction function

Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -48,11 +48,6 @@ def load_models():
48
  model_1 = model_1.to(device)
49
  clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
50
 
51
- image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
52
- model_6 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_6"])
53
- model_6 = model_6.to(device)
54
- clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
55
-
56
  clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
57
 
58
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
@@ -64,9 +59,12 @@ def load_models():
64
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
65
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
66
 
67
- return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6
 
 
 
68
 
69
- clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, clf_6 = load_models()
70
 
71
  @spaces.GPU(duration=10)
72
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
@@ -118,7 +116,7 @@ def predict_image(img, confidence_threshold):
118
  label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
119
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
120
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
121
- label_6, result_6output = predict_with_model(img_pilvits, clf_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 6)
122
 
123
  combined_results = {
124
  "SwinV2/detect": label_1,
 
48
  model_1 = model_1.to(device)
49
  clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
50
 
 
 
 
 
 
51
  clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
52
 
53
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
 
59
  clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
60
  clf_5b = pipeline("image-classification", model=MODEL_PATHS["model_5b"], device=device)
61
 
62
+ feature_extractor_6 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_6"], device=device)
63
+ model_6 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(device)
64
+
65
+ return clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, model_6, feature_extractor_6
66
 
67
+ clf_1, clf_2, feature_extractor_3, model_3, feature_extractor_4, model_4, clf_5, clf_5b, feature_extractor_6, model_6 = load_models()
68
 
69
  @spaces.GPU(duration=10)
70
  def predict_with_model(img_pil, clf, class_names, confidence_threshold, model_name, model_id, feature_extractor=None):
 
116
  label_4, result_4output = predict_with_model(img_pil, model_4, CLASS_NAMES["model_4"], confidence_threshold, "SDXL + FLUX", 4, feature_extractor_4)
117
  label_5, result_5output = predict_with_model(img_pilvits, clf_5, CLASS_NAMES["model_5"], confidence_threshold, "ViT-base Newcomer", 5)
118
  label_5b, result_5boutput = predict_with_model(img_pilvits, clf_5b, CLASS_NAMES["model_5b"], confidence_threshold, "ViT-base Newcomer", 6)
119
+ label_6, result_6output = predict_with_model(img_pilvits, model_6, CLASS_NAMES["model_6"], confidence_threshold, "Swin Midjourney/SDXL", 6, feature_extractor_6)
120
 
121
  combined_results = {
122
  "SwinV2/detect": label_1,