xcurvnubaim commited on
Commit
890b83b
·
1 Parent(s): 020a381

fix: fix bytes

Browse files
Files changed (1) hide show
  1. main.py +10 -14
main.py CHANGED
@@ -1,8 +1,8 @@
1
  import numpy as np
2
  from fastapi import FastAPI, File, UploadFile
3
  import tensorflow as tf
4
- from io import StringIO
5
  from PIL import Image
 
6
 
7
  app = FastAPI()
8
 
@@ -12,23 +12,19 @@ with open("labels.txt") as f:
12
  for line in f:
13
  labels.append(line.replace('\n', ''))
14
 
15
- def classify_image(inp):
16
- # Create a copy of the input array to avoid reference issues
17
- inp_copy = np.copy(inp)
18
  # Resize the input image to the expected shape (224, 224)
19
- inp_copy = Image.fromarray(inp_copy)
20
- inp_copy = inp_copy.resize((224, 224))
21
- inp_copy = np.array(inp_copy)
22
- inp_copy = inp_copy.reshape((-1, 224, 224, 3))
23
- inp_copy = tf.keras.applications.efficientnet.preprocess_input(inp_copy)
24
- prediction = model.predict(inp_copy).flatten()
25
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
26
  return confidences
27
 
28
- @app.post("/predict/")
29
  async def predict(file: UploadFile = File(...)):
30
  contents = await file.read()
31
- img = Image.open(StringIO(contents))
32
- img = np.array(img)
33
  confidences = classify_image(img)
34
- return confidences
 
1
  import numpy as np
2
  from fastapi import FastAPI, File, UploadFile
3
  import tensorflow as tf
 
4
  from PIL import Image
5
+ from io import BytesIO
6
 
7
  app = FastAPI()
8
 
 
12
  for line in f:
13
  labels.append(line.replace('\n', ''))
14
 
15
+ def classify_image(img):
 
 
16
  # Resize the input image to the expected shape (224, 224)
17
+ img_resized = img.resize((224, 224))
18
+ img_array = np.array(img_resized)
19
+ img_array = img_array.reshape((-1, 224, 224, 3))
20
+ img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
21
+ prediction = model.predict(img_array).flatten()
 
22
  confidences = {labels[i]: float(prediction[i]) for i in range(90)}
23
  return confidences
24
 
25
+ @app.post("/predict")
26
  async def predict(file: UploadFile = File(...)):
27
  contents = await file.read()
28
+ img = Image.open(BytesIO(contents))
 
29
  confidences = classify_image(img)
30
+ return confidences