feat: enhance model registration by adding architecture and dataset metadata to the ModelEntry class
Browse files- app.py +19 -14
- utils/registry.py +6 -3
app.py
CHANGED
@@ -86,11 +86,8 @@ def postprocess_logits(outputs, class_names):
|
|
86 |
probabilities = softmax(logits)
|
87 |
return {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
88 |
|
89 |
-
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path):
|
90 |
-
entry = ModelEntry(model, preprocess, postprocess, class_names)
|
91 |
-
entry.display_name = display_name
|
92 |
-
entry.contributor = contributor
|
93 |
-
entry.model_path = model_path
|
94 |
MODEL_REGISTRY[model_id] = entry
|
95 |
|
96 |
# Load and register models (copied from app_mcp.py)
|
@@ -99,13 +96,15 @@ model_1 = Swinv2ForImageClassification.from_pretrained(MODEL_PATHS["model_1"]).t
|
|
99 |
clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
|
100 |
register_model_with_metadata(
|
101 |
"model_1", clf_1, preprocess_resize_256, postprocess_pipeline, CLASS_NAMES["model_1"],
|
102 |
-
display_name="
|
|
|
103 |
)
|
104 |
|
105 |
clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
|
106 |
register_model_with_metadata(
|
107 |
"model_2", clf_2, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_2"],
|
108 |
-
display_name="
|
|
|
109 |
)
|
110 |
|
111 |
feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
|
@@ -125,7 +124,8 @@ def model3_infer(image):
|
|
125 |
return outputs
|
126 |
register_model_with_metadata(
|
127 |
"model_3", model3_infer, preprocess_256, postprocess_logits_model3, CLASS_NAMES["model_3"],
|
128 |
-
display_name="
|
|
|
129 |
)
|
130 |
|
131 |
feature_extractor_4 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_4"], device=device)
|
@@ -141,13 +141,15 @@ def postprocess_logits_model4(outputs, class_names):
|
|
141 |
return {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
142 |
register_model_with_metadata(
|
143 |
"model_4", model4_infer, preprocess_256, postprocess_logits_model4, CLASS_NAMES["model_4"],
|
144 |
-
display_name="
|
|
|
145 |
)
|
146 |
|
147 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
148 |
register_model_with_metadata(
|
149 |
"model_5", clf_5, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5"],
|
150 |
-
display_name="
|
|
|
151 |
)
|
152 |
|
153 |
image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
|
@@ -155,7 +157,8 @@ model_6 = SwinForImageClassification.from_pretrained(MODEL_PATHS["model_6"]).to(
|
|
155 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
156 |
register_model_with_metadata(
|
157 |
"model_6", clf_6, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_6"],
|
158 |
-
display_name="
|
|
|
159 |
)
|
160 |
|
161 |
image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
|
@@ -163,7 +166,8 @@ model_7 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_7"]
|
|
163 |
clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
|
164 |
register_model_with_metadata(
|
165 |
"model_7", clf_7, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_7"],
|
166 |
-
display_name="
|
|
|
167 |
)
|
168 |
|
169 |
def preprocess_simple_prediction(image):
|
@@ -196,7 +200,8 @@ register_model_with_metadata(
|
|
196 |
["AI", "REAL"],
|
197 |
display_name="Community Forensics",
|
198 |
contributor="Jeongsoo Park",
|
199 |
-
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT"
|
|
|
200 |
)
|
201 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
202 |
"""Predict using a specific model.
|
@@ -449,7 +454,7 @@ detection_model_eval_playground = gr.Interface(
|
|
449 |
gr.Gallery(label="Post Processed Images", visible=True, columns=[4], rows=[2], container=False, height="auto", object_fit="contain", elem_id="post-gallery"),
|
450 |
gr.Dataframe(
|
451 |
label="Model Predictions",
|
452 |
-
headers=["
|
453 |
datatype=["str", "str", "number", "number", "str"]
|
454 |
),
|
455 |
gr.JSON(label="Raw Model Results", visible=False),
|
|
|
86 |
probabilities = softmax(logits)
|
87 |
return {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
88 |
|
89 |
+
def register_model_with_metadata(model_id, model, preprocess, postprocess, class_names, display_name, contributor, model_path, architecture=None, dataset=None):
|
90 |
+
entry = ModelEntry(model, preprocess, postprocess, class_names, display_name=display_name, contributor=contributor, model_path=model_path, architecture=architecture, dataset=dataset)
|
|
|
|
|
|
|
91 |
MODEL_REGISTRY[model_id] = entry
|
92 |
|
93 |
# Load and register models (copied from app_mcp.py)
|
|
|
96 |
clf_1 = pipeline(model=model_1, task="image-classification", image_processor=image_processor_1, device=device)
|
97 |
register_model_with_metadata(
|
98 |
"model_1", clf_1, preprocess_resize_256, postprocess_pipeline, CLASS_NAMES["model_1"],
|
99 |
+
display_name="SWIN1", contributor="haywoodsloan", model_path=MODEL_PATHS["model_1"],
|
100 |
+
architecture="SwinV2", dataset="TBA"
|
101 |
)
|
102 |
|
103 |
clf_2 = pipeline("image-classification", model=MODEL_PATHS["model_2"], device=device)
|
104 |
register_model_with_metadata(
|
105 |
"model_2", clf_2, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_2"],
|
106 |
+
display_name="VIT2", contributor="Heem2", model_path=MODEL_PATHS["model_2"],
|
107 |
+
architecture="ViT", dataset="TBA"
|
108 |
)
|
109 |
|
110 |
feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
|
|
|
124 |
return outputs
|
125 |
register_model_with_metadata(
|
126 |
"model_3", model3_infer, preprocess_256, postprocess_logits_model3, CLASS_NAMES["model_3"],
|
127 |
+
display_name="SDXL3", contributor="Organika", model_path=MODEL_PATHS["model_3"],
|
128 |
+
architecture="VIT", dataset="SDXL"
|
129 |
)
|
130 |
|
131 |
feature_extractor_4 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_4"], device=device)
|
|
|
141 |
return {class_names[i]: probabilities[i] for i in range(len(class_names))}
|
142 |
register_model_with_metadata(
|
143 |
"model_4", model4_infer, preprocess_256, postprocess_logits_model4, CLASS_NAMES["model_4"],
|
144 |
+
display_name="XLFLUX4", contributor="cmckinle", model_path=MODEL_PATHS["model_4"],
|
145 |
+
architecture="VIT", dataset="SDXL, FLUX"
|
146 |
)
|
147 |
|
148 |
clf_5 = pipeline("image-classification", model=MODEL_PATHS["model_5"], device=device)
|
149 |
register_model_with_metadata(
|
150 |
"model_5", clf_5, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_5"],
|
151 |
+
display_name="VIT5", contributor="prithivMLmods", model_path=MODEL_PATHS["model_5"],
|
152 |
+
architecture="VIT", dataset="TBA"
|
153 |
)
|
154 |
|
155 |
image_processor_6 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_6"], use_fast=True)
|
|
|
157 |
clf_6 = pipeline(model=model_6, task="image-classification", image_processor=image_processor_6, device=device)
|
158 |
register_model_with_metadata(
|
159 |
"model_6", clf_6, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_6"],
|
160 |
+
display_name="SWIN6", contributor="ideepankarsharma2003", model_path=MODEL_PATHS["model_6"],
|
161 |
+
architecture="SWINv1", dataset="SDXL, Midjourney"
|
162 |
)
|
163 |
|
164 |
image_processor_7 = AutoImageProcessor.from_pretrained(MODEL_PATHS["model_7"], use_fast=True)
|
|
|
166 |
clf_7 = pipeline(model=model_7, task="image-classification", image_processor=image_processor_7, device=device)
|
167 |
register_model_with_metadata(
|
168 |
"model_7", clf_7, preprocess_resize_224, postprocess_pipeline, CLASS_NAMES["model_7"],
|
169 |
+
display_name="VIT7", contributor="date3k2", model_path=MODEL_PATHS["model_7"],
|
170 |
+
architecture="VIT", dataset="TBA"
|
171 |
)
|
172 |
|
173 |
def preprocess_simple_prediction(image):
|
|
|
200 |
["AI", "REAL"],
|
201 |
display_name="Community Forensics",
|
202 |
contributor="Jeongsoo Park",
|
203 |
+
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
|
204 |
+
architecture="ViT", dataset="GOAT"
|
205 |
)
|
206 |
def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75) -> dict:
|
207 |
"""Predict using a specific model.
|
|
|
454 |
gr.Gallery(label="Post Processed Images", visible=True, columns=[4], rows=[2], container=False, height="auto", object_fit="contain", elem_id="post-gallery"),
|
455 |
gr.Dataframe(
|
456 |
label="Model Predictions",
|
457 |
+
headers=["Arch / Dataset", "By", "AI", "Real", "Label"],
|
458 |
datatype=["str", "str", "number", "number", "str"]
|
459 |
),
|
460 |
gr.JSON(label="Raw Model Results", visible=False),
|
utils/registry.py
CHANGED
@@ -2,7 +2,8 @@ from typing import Callable, Dict, Any, List, Optional
|
|
2 |
|
3 |
class ModelEntry:
|
4 |
def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
|
5 |
-
display_name: Optional[str] = None, contributor: Optional[str] = None, model_path: Optional[str] = None
|
|
|
6 |
self.model = model
|
7 |
self.preprocess = preprocess
|
8 |
self.postprocess = postprocess
|
@@ -10,8 +11,10 @@ class ModelEntry:
|
|
10 |
self.display_name = display_name
|
11 |
self.contributor = contributor
|
12 |
self.model_path = model_path
|
|
|
|
|
13 |
|
14 |
MODEL_REGISTRY: Dict[str, ModelEntry] = {}
|
15 |
|
16 |
-
def register_model(model_id: str, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str]):
|
17 |
-
MODEL_REGISTRY[model_id] = ModelEntry(model, preprocess, postprocess, class_names)
|
|
|
2 |
|
3 |
class ModelEntry:
|
4 |
def __init__(self, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str],
|
5 |
+
display_name: Optional[str] = None, contributor: Optional[str] = None, model_path: Optional[str] = None,
|
6 |
+
architecture: Optional[str] = None, dataset: Optional[str] = None):
|
7 |
self.model = model
|
8 |
self.preprocess = preprocess
|
9 |
self.postprocess = postprocess
|
|
|
11 |
self.display_name = display_name
|
12 |
self.contributor = contributor
|
13 |
self.model_path = model_path
|
14 |
+
self.architecture = architecture
|
15 |
+
self.dataset = dataset
|
16 |
|
17 |
MODEL_REGISTRY: Dict[str, ModelEntry] = {}
|
18 |
|
19 |
+
def register_model(model_id: str, model: Any, preprocess: Callable, postprocess: Callable, class_names: List[str], architecture: Optional[str] = None, dataset: Optional[str] = None):
|
20 |
+
MODEL_REGISTRY[model_id] = ModelEntry(model, preprocess, postprocess, class_names, architecture=architecture, dataset=dataset)
|