LPX55 commited on
Commit
25f90cb
·
1 Parent(s): 0e2cdc4

feat: implement universal image loader to support various input types and update preprocessing functions accordingly

Browse files
Files changed (1) hide show
  1. app.py +52 -8
app.py CHANGED
@@ -5,6 +5,9 @@ import numpy as np
5
  import os
6
  import time
7
  import logging
 
 
 
8
 
9
  # Assuming these are available from your utils and agents directories
10
  # You might need to adjust paths or copy these functions/classes if they are not directly importable.
@@ -66,16 +69,57 @@ CLASS_NAMES = {
66
  "model_7": ['Fake', 'Real'],
67
  }
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def preprocess_resize_256(image):
 
70
  if image.mode != 'RGB':
71
  image = image.convert('RGB')
72
  return transforms.Resize((256, 256))(image)
73
 
74
  def preprocess_resize_224(image):
 
75
  if image.mode != 'RGB':
76
  image = image.convert('RGB')
77
  return transforms.Resize((224, 224))(image)
78
 
 
 
 
 
 
 
79
  def postprocess_pipeline(prediction, class_names):
80
  # Assumes HuggingFace pipeline output
81
  return {pred['label']: pred['score'] for pred in prediction}
@@ -109,10 +153,6 @@ register_model_with_metadata(
109
 
110
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
111
  model_3 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_3"]).to(device)
112
- def preprocess_256(image):
113
- if image.mode != 'RGB':
114
- image = image.convert('RGB')
115
- return transforms.Resize((256, 256))(image)
116
  def postprocess_logits_model3(outputs, class_names):
117
  logits = outputs.logits.cpu().numpy()[0]
118
  probabilities = softmax(logits)
@@ -171,8 +211,7 @@ register_model_with_metadata(
171
  )
172
 
173
  def preprocess_simple_prediction(image):
174
- # The simple_prediction function expects a PIL image (filepath is handled internally)
175
- return image
176
 
177
  def postprocess_simple_prediction(result, class_names):
178
  scores = {name: 0.0 for name in class_names}
@@ -184,10 +223,15 @@ def postprocess_simple_prediction(result, class_names):
184
  return scores
185
 
186
  def simple_prediction(img):
 
 
 
 
 
187
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
188
  result = client.predict(
189
- input_image=handle_file(img),
190
- api_name="/simple_predict"
191
  )
192
  return result
193
 
 
5
  import os
6
  import time
7
  import logging
8
+ import requests
9
+ import io
10
+ import tempfile
11
 
12
  # Assuming these are available from your utils and agents directories
13
  # You might need to adjust paths or copy these functions/classes if they are not directly importable.
 
69
  "model_7": ['Fake', 'Real'],
70
  }
71
 
72
+ # Universal image loader
73
+ def universal_image_loader(img_input):
74
+ """
75
+ Accepts a PIL Image, NumPy array, file path, or URL.
76
+ Returns a PIL Image.
77
+ """
78
+ if isinstance(img_input, Image.Image):
79
+ return img_input
80
+ if isinstance(img_input, np.ndarray):
81
+ return Image.fromarray(img_input)
82
+ if isinstance(img_input, str):
83
+ if img_input.startswith('http://') or img_input.startswith('https://'):
84
+ try:
85
+ response = requests.get(img_input)
86
+ response.raise_for_status()
87
+ return Image.open(io.BytesIO(response.content)).convert('RGB')
88
+ except Exception as e:
89
+ logger.error(f"Failed to load image from URL: {img_input} | Error: {e}")
90
+ raise
91
+ elif os.path.exists(img_input):
92
+ try:
93
+ return Image.open(img_input).convert('RGB')
94
+ except Exception as e:
95
+ logger.error(f"Failed to load image from file: {img_input} | Error: {e}")
96
+ raise
97
+ else:
98
+ logger.error(f"String input is not a valid file path or URL: {img_input}")
99
+ raise ValueError(f"Invalid image input: {img_input}")
100
+ logger.error(f"Unsupported image input type: {type(img_input)}")
101
+ raise TypeError(f"Unsupported image input type: {type(img_input)}")
102
+
103
+ # Update all preprocessors to use universal_image_loader
104
+
105
  def preprocess_resize_256(image):
106
+ image = universal_image_loader(image)
107
  if image.mode != 'RGB':
108
  image = image.convert('RGB')
109
  return transforms.Resize((256, 256))(image)
110
 
111
  def preprocess_resize_224(image):
112
+ image = universal_image_loader(image)
113
  if image.mode != 'RGB':
114
  image = image.convert('RGB')
115
  return transforms.Resize((224, 224))(image)
116
 
117
+ def preprocess_256(image):
118
+ image = universal_image_loader(image)
119
+ if image.mode != 'RGB':
120
+ image = image.convert('RGB')
121
+ return transforms.Resize((256, 256))(image)
122
+
123
  def postprocess_pipeline(prediction, class_names):
124
  # Assumes HuggingFace pipeline output
125
  return {pred['label']: pred['score'] for pred in prediction}
 
153
 
154
  feature_extractor_3 = AutoFeatureExtractor.from_pretrained(MODEL_PATHS["model_3"], device=device)
155
  model_3 = AutoModelForImageClassification.from_pretrained(MODEL_PATHS["model_3"]).to(device)
 
 
 
 
156
  def postprocess_logits_model3(outputs, class_names):
157
  logits = outputs.logits.cpu().numpy()[0]
158
  probabilities = softmax(logits)
 
211
  )
212
 
213
  def preprocess_simple_prediction(image):
214
+ return universal_image_loader(image)
 
215
 
216
  def postprocess_simple_prediction(result, class_names):
217
  scores = {name: 0.0 for name in class_names}
 
223
  return scores
224
 
225
  def simple_prediction(img):
226
+ img = universal_image_loader(img)
227
+ # Save PIL image to a temporary file
228
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
229
+ img.save(tmp, format="JPEG")
230
+ tmp_path = tmp.name
231
  client = Client("aiwithoutborders-xyz/OpenSight-Community-Forensics-Preview")
232
  result = client.predict(
233
+ input_image=tmp_path,
234
+ api_name="/simple_predict"
235
  )
236
  return result
237