LPX55 commited on
Commit
2be263d
·
1 Parent(s): 48e0567

refactor: remove unnecessary preprocessing function and streamline image handling in simple_prediction

Browse files
Files changed (1) hide show
  1. app.py +10 -22
app.py CHANGED
@@ -172,13 +172,6 @@ register_model_with_metadata(
172
  architecture="VIT", dataset="TBA"
173
  )
174
 
175
- def preprocess_simple_prediction(image):
176
- print(type(image))
177
- im = load_image(image)
178
- print(type(im))
179
- # The simple_prediction function expects a PIL image (filepath is handled internally)
180
- return image
181
-
182
  def postprocess_simple_prediction(result, class_names):
183
  scores = {name: 0.0 for name in class_names}
184
  fake_prob = result.get("Fake Probability")
@@ -190,26 +183,21 @@ def postprocess_simple_prediction(result, class_names):
190
 
191
  def simple_prediction(img):
192
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
193
-
194
- # Convert PIL Image to a file-like object in memory
195
- img_byte_arr = io.BytesIO()
196
- img.save(img_byte_arr, format='PNG') # Using PNG for lossless conversion, can be JPEG if preferred
197
- img_byte_arr.seek(0) # Rewind to the beginning of the stream
198
- im = load_image(img)
199
-
200
  result = client.predict(
201
- input_image=handle_file(img),
202
- api_name="/simple_predict"
203
  )
204
  return result
205
 
206
 
207
  register_model_with_metadata(
208
- "simple_prediction",
209
- simple_prediction,
210
- preprocess_simple_prediction,
211
- postprocess_simple_prediction,
212
- ["AI", "REAL"],
213
  display_name="Community Forensics",
214
  contributor="Jeongsoo Park",
215
  model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
@@ -227,7 +215,7 @@ def infer(image: Image.Image, model_id: str, confidence_threshold: float = 0.75)
227
  dict: A dictionary containing the model details, classification scores, and label.
228
  """
229
  entry = MODEL_REGISTRY[model_id]
230
- img = entry.preprocess(image)
231
  try:
232
  result = entry.model(img)
233
  scores = entry.postprocess(result, entry.class_names)
 
172
  architecture="VIT", dataset="TBA"
173
  )
174
 
 
 
 
 
 
 
 
175
  def postprocess_simple_prediction(result, class_names):
176
  scores = {name: 0.0 for name in class_names}
177
  fake_prob = result.get("Fake Probability")
 
183
 
184
  def simple_prediction(img):
185
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
186
+ client.view_api()
187
+ print(type(img))
 
 
 
 
 
188
  result = client.predict(
189
+ handle_file(img),
190
+ api_name="simple_predict"
191
  )
192
  return result
193
 
194
 
195
  register_model_with_metadata(
196
+ model_id="simple_prediction",
197
+ model=simple_prediction,
198
+ preprocess=None,
199
+ postprocess=postprocess_simple_prediction,
200
+ class_names=["AI", "REAL"],
201
  display_name="Community Forensics",
202
  contributor="Jeongsoo Park",
203
  model_path="aiwithoutborders-xyz/CommunityForensics-DeepfakeDet-ViT",
 
215
  dict: A dictionary containing the model details, classification scores, and label.
216
  """
217
  entry = MODEL_REGISTRY[model_id]
218
+ img = entry.preprocess(image) if entry.preprocess else image
219
  try:
220
  result = entry.model(img)
221
  scores = entry.postprocess(result, entry.class_names)