refactor: remove unnecessary preprocessing function and streamline image handling in simple_prediction
Browse files
app.py
CHANGED
@@ -172,13 +172,6 @@ register_model_with_metadata(
|
|
172 |
architecture="VIT", dataset="TBA"
|
173 |
)
|
174 |
|
175 |
-
def preprocess_simple_prediction(image):
|
176 |
-
print(type(image))
|
177 |
-
im = load_image(image)
|
178 |
-
print(type(im))
|
179 |
-
# The simple_prediction function expects a PIL image (filepath is handled internally)
|
180 |
-
return image
|
181 |
-
|
182 |
def postprocess_simple_prediction(result, class_names):
|
183 |
scores = {name: 0.0 for name in class_names}
|
184 |
fake_prob = result.get("Fake Probability")
|
@@ -190,26 +183,21 @@ def postprocess_simple_prediction(result, class_names):
|
|
190 |
|
191 |
def simple_prediction(img):
|
192 |
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
193 |
-
|
194 |
-
|
195 |
-
img_byte_arr = io.BytesIO()
|
196 |
-
img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
|
197 |
-
img_byte_arr.seek(0) # Rewind to the beginning of the stream
|
198 |
-
im = load_image(img)
|
199 |
-
|
200 |
result = client.predict(
|
201 |
-
|
202 |
-
|
203 |
)
|
204 |
return result
|
205 |
|
206 |
|
207 |
register_model_with_metadata(
|
208 |
-
"simple_prediction",
|
209 |
-
simple_prediction,
|
210 |
-
|
211 |
-
postprocess_simple_prediction,
|
212 |
-
["AI", "REAL"],
|
213 |
display_name="Community Forensics",
|
214 |
contributor="Jeongsoo Park",
|
215 |
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
|
@@ -227,7 +215,7 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
|
|
227 |
dict: A dictionary containing the model details, classification scores, and label.
|
228 |
"""
|
229 |
entry = MODEL_REGISTRY[model_id]
|
230 |
-
img = entry.preprocess(image)
|
231 |
try:
|
232 |
result = entry.model(img)
|
233 |
scores = entry.postprocess(result, entry.class_names)
|
|
|
172 |
architecture="VIT", dataset="TBA"
|
173 |
)
|
174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
def postprocess_simple_prediction(result, class_names):
|
176 |
scores = {name: 0.0 for name in class_names}
|
177 |
fake_prob = result.get("Fake Probability")
|
|
|
183 |
|
184 |
def simple_prediction(img):
|
185 |
client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
|
186 |
+
client.view_api()
|
187 |
+
print(type(img))
|
|
|
|
|
|
|
|
|
|
|
188 |
result = client.predict(
|
189 |
+
handle_file(img),
|
190 |
+
api_name="simple_predict"
|
191 |
)
|
192 |
return result
|
193 |
|
194 |
|
195 |
register_model_with_metadata(
|
196 |
+
model_id="simple_prediction",
|
197 |
+
model=simple_prediction,
|
198 |
+
preprocess=None,
|
199 |
+
postprocess=postprocess_simple_prediction,
|
200 |
+
class_names=["AI", "REAL"],
|
201 |
display_name="Community Forensics",
|
202 |
contributor="Jeongsoo Park",
|
203 |
model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
|
|
|
215 |
dict: A dictionary containing the model details, classification scores, and label.
|
216 |
"""
|
217 |
entry = MODEL_REGISTRY[model_id]
|
218 |
+
img = entry.preprocess(image) if entry.preprocess else image
|
219 |
try:
|
220 |
result = entry.model(img)
|
221 |
scores = entry.postprocess(result, entry.class_names)
|