siriuszeina commited on
Commit
65ea629
·
verified ·
1 Parent(s): 1ed6dd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -43,11 +43,38 @@ def get_image(url) -> PIL.Image:
43
  response = requests.get(url)
44
  image = PIL.Image.open(BytesIO(response.content))
45
  return image
 
 
46
 
47
 
48
  model = load_model()
49
  labels = load_labels()
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
53
  _, height, width, _ = model.input_shape
@@ -104,7 +131,7 @@ with gr.Blocks(css="style.css") as demo:
104
  fn=predict,
105
  inputs=[url, score_threshold],
106
  outputs=[result, result_json, result_text],
107
- api_name="predict",
108
  )
109
 
110
  if __name__ == "__main__":
 
43
  response = requests.get(url)
44
  image = PIL.Image.open(BytesIO(response.content))
45
  return image
46
+
47
+
48
 
49
 
50
  model = load_model()
51
  labels = load_labels()
52
 
53
+ def predictx(url: str, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
54
+ _, height, width, _ = model.input_shape
55
+ response = requests.get(url)
56
+ image = PIL.Image.open(BytesIO(response.content))
57
+
58
+ image = np.asarray(image)
59
+ image = tf.image.resize(image, size=(height, width), method=tf.image.ResizeMethod.AREA, preserve_aspect_ratio=True)
60
+ image = image.numpy()
61
+ image = dd.image.transform_and_pad_image(image, width, height)
62
+ image = image / 255.0
63
+ probs = model.predict(image[None, ...])[0]
64
+ probs = probs.astype(float)
65
+
66
+ indices = np.argsort(probs)[::-1]
67
+ result_all = dict()
68
+ result_threshold = dict()
69
+ for index in indices:
70
+ label = labels[index]
71
+ prob = probs[index]
72
+ result_all[label] = prob
73
+ if prob < score_threshold:
74
+ break
75
+ result_threshold[label] = prob
76
+ result_text = ", ".join(result_all.keys())
77
+ return result_threshold, result_all, result_text
78
 
79
  def predict(image: PIL.Image.Image, score_threshold: float) -> tuple[dict[str, float], dict[str, float], str]:
80
  _, height, width, _ = model.input_shape
 
131
  fn=predict,
132
  inputs=[url, score_threshold],
133
  outputs=[result, result_json, result_text],
134
+ api_name="predictx",
135
  )
136
 
137
  if __name__ == "__main__":