Upload 13 files
- Dockerfile +11 -0
- README.md +54 -3
- __init__.py +0 -0
- app/__init__.py +0 -0
- app/main.py +34 -0
- app/model.py +41 -0
- app/saved_model.pb +3 -0
- convert_model.py +33 -0
- main.py +34 -0
- model.py +41 -0
- requirements.txt +7 -0
- saved_model.pb +3 -0
- streamlit_viz.py +71 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
FROM python:3.10

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY ./app /code/app

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8080"]
README.md
CHANGED
@@ -1,3 +1,54 @@
# Saliency Inference API Template

This is an API and Streamlit app for interacting with a saliency model. The API is built with FastAPI, the visualizer with Streamlit, and the API is designed to run in a Docker container.

## Setup

### Install dependencies

```bash
pip install -r requirements.txt
```

### Run the API

```bash
uvicorn main:app --reload --workers 1 --host 0.0.0.0 --port 8080
```

This will run the FastAPI server on port 8080.
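You can check that the server is up with a `curl` request. The command below is a sketch assuming a local image at `image.jpg` (a placeholder name; any JPEG or PNG works); it should return the predicted saliency map as a JSON-encoded array:

```bash
curl -X POST -F "file=@image.jpg" http://localhost:8080/predict
```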
### (Alternative) Run the API in a Docker container

```bash
docker build -t ds-api-template .
docker run -p 8080:8080 ds-api-template
```

You can test that this is running by executing the same `curl` command as above, which should return the same response.

NOTE: You will need Docker installed on your machine. To install Docker, follow the instructions [here](https://docs.docker.com/get-docker/).

## Run the Streamlit App

Once you've set up the API, you can run the Streamlit app to interact with it.

To run the Streamlit app, run the following command:

```bash
streamlit run streamlit_viz.py
```

You will need Streamlit installed on your machine. To install Streamlit, run the following command:

```bash
pip install streamlit
```

You will also need to create a `secrets.toml` file in a `.streamlit` directory at the root of the repo. This file should contain the following:

```toml
api_host = "http://localhost:8080/predict"
password = "<INSERT DESIRED PASSWORD HERE>"
```
__init__.py
ADDED
File without changes
app/__init__.py
ADDED
File without changes
app/main.py
ADDED
@@ -0,0 +1,34 @@
from fastapi import FastAPI, File
from fastapi.middleware.cors import CORSMiddleware

from .model import predict
import json

app = FastAPI()

# CORS
origins = [
    "http://localhost:8080",
    "http://localhost"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["POST"],
    allow_headers=["*"],
)

@app.post("/predict")
def predict_saliency(file: bytes = File(...)):
    """Run saliency inference on an uploaded image.

    Args:
        file (bytes): The image file in bytes format.
    Returns:
        str: The JSON representation of the predicted saliency map.
    """
    prediction = predict(file)
    # json.dumps returns a string, so the response body is a
    # JSON-encoded string and clients must decode it twice.
    return json.dumps(prediction.tolist())
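For a quick smoke test of the `/predict` route above, here is a minimal client sketch, assuming the API is running locally on port 8080 and a placeholder image at `test.jpg`:

```python
import json

import numpy as np
import requests

# Placeholder host and image path; adjust to your setup.
API_URL = "http://localhost:8080/predict"

with open("test.jpg", "rb") as f:
    response = requests.post(API_URL, files={"file": f.read()})

# The route returns a JSON string, so the body decodes twice:
# response.json() yields the string, json.loads yields the array.
saliency = np.asarray(json.loads(response.json()))
print(saliency.shape)
```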
app/model.py
ADDED
@@ -0,0 +1,41 @@
import tensorflow as tf
from PIL import Image
import io

# Load the exported SavedModel and grab its serving signature.
imported = tf.saved_model.load("./app")
imported = imported.signatures["serving_default"]

def get_image_from_bytes(binary_image: bytes) -> Image.Image:
    """Convert an image from bytes to PIL RGB format.

    Args:
        binary_image (bytes): The binary representation of the image

    Returns:
        PIL.Image.Image: The image in PIL RGB format
    """
    input_image = Image.open(io.BytesIO(binary_image)).convert("RGB")
    return input_image

def predict(input_image):
    """Decode an image and return the model's saliency prediction.

    Args:
        input_image (bytes): The encoded image to run inference on.

    Returns:
        numpy.ndarray: The predicted saliency map, resized back to the
        original image dimensions.
    """
    tensor = tf.io.decode_image(input_image, channels=3)

    inference_shape = (240, 320)
    original_shape = tensor.shape[:2]

    input_tensor = tf.expand_dims(tensor, axis=0)

    input_tensor = tf.image.resize(input_tensor, inference_shape,
                                   preserve_aspect_ratio=True)
    saliency = imported(input_tensor)["output"]

    saliency = tf.image.resize(saliency, original_shape)
    return saliency.numpy()[0]
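For reference, a sketch of calling `predict` directly, bypassing the API. It assumes you run it from the repo root (so the SavedModel in `app/` resolves) and that a placeholder image `test.jpg` exists:

```python
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from app.model import predict  # loads the SavedModel on import

with open("test.jpg", "rb") as f:
    image_bytes = f.read()

saliency = predict(image_bytes)  # predict expects raw encoded bytes

# Overlay the saliency map on the original image and save the result.
plt.imshow(Image.open("test.jpg"))
plt.imshow(np.squeeze(saliency), alpha=0.6)  # drop the channel axis
plt.axis("off")
plt.savefig("saliency_overlay.png", bbox_inches="tight")
```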
app/saved_model.pb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:646e0f343c4357e828f2569bef2f2bf288449fe68f7e4fb43e076f2e3b094e3d
size 99858975
convert_model.py
ADDED
@@ -0,0 +1,33 @@
# Use this script to convert any of the saved models at the link below
# to be compatible with TF2: https://drive.google.com/drive/folders/1GI7i6GpfI-FoklP3vCc6vxe3T9nk3V2n

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants, tag_constants

export_dir = "./app/"
# Update the line below to point at the desired model downloaded
# from the Google Drive link above.
graph_pb = "./app/model_salicon_cpu.pb"

# Read the frozen TF1 inference graph.
with tf.io.gfile.GFile(graph_pb, "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

sig = {}

builder = tf.compat.v1.saved_model.Builder(export_dir)

with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.import_graph_def(graph_def, name="")
    g = tf.compat.v1.get_default_graph()

    # Named to avoid shadowing the built-in input().
    input_tensor = g.get_tensor_by_name("input:0")
    output_tensor = g.get_tensor_by_name("output:0")

    sig_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    sig[sig_key] = tf.compat.v1.saved_model.predict_signature_def(
        {"input": input_tensor}, {"output": output_tensor})
    builder.add_meta_graph_and_variables(sess,
                                         [tag_constants.SERVING],
                                         signature_def_map=sig)
    builder.save()
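After running the conversion, it is worth checking the export before the API tries to load it. A quick verification sketch, assuming the export landed in `./app`:

```python
import tensorflow as tf

# Load the freshly exported SavedModel and inspect its serving
# signature; model.py expects an "input" tensor in and an "output"
# tensor back.
loaded = tf.saved_model.load("./app")
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)
print(infer.structured_outputs)
```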
main.py
ADDED
@@ -0,0 +1,34 @@
from fastapi import FastAPI, File
from fastapi.middleware.cors import CORSMiddleware

# model.py sits next to this file, so import it as a top-level module
# (a relative import would fail when running `uvicorn main:app`).
from model import predict
import json

app = FastAPI()

# CORS
origins = [
    "http://localhost:8080",
    "http://localhost"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["POST"],
    allow_headers=["*"],
)

@app.post("/predict")
def predict_saliency(file: bytes = File(...)):
    """Run saliency inference on an uploaded image.

    Args:
        file (bytes): The image file in bytes format.
    Returns:
        str: The JSON representation of the predicted saliency map.
    """
    prediction = predict(file)
    # json.dumps returns a string, so the response body is a
    # JSON-encoded string and clients must decode it twice.
    return json.dumps(prediction.tolist())
model.py
ADDED
@@ -0,0 +1,41 @@
import tensorflow as tf
from PIL import Image
import io

# The exported SavedModel lives in ./app, so run this from the repo root.
imported = tf.saved_model.load("./app")
imported = imported.signatures["serving_default"]

def get_image_from_bytes(binary_image: bytes) -> Image.Image:
    """Convert an image from bytes to PIL RGB format.

    Args:
        binary_image (bytes): The binary representation of the image

    Returns:
        PIL.Image.Image: The image in PIL RGB format
    """
    input_image = Image.open(io.BytesIO(binary_image)).convert("RGB")
    return input_image

def predict(input_image):
    """Decode an image and return the model's saliency prediction.

    Args:
        input_image (bytes): The encoded image to run inference on.

    Returns:
        numpy.ndarray: The predicted saliency map, resized back to the
        original image dimensions.
    """
    tensor = tf.io.decode_image(input_image, channels=3)

    inference_shape = (240, 320)
    original_shape = tensor.shape[:2]

    input_tensor = tf.expand_dims(tensor, axis=0)

    input_tensor = tf.image.resize(input_tensor, inference_shape,
                                   preserve_aspect_ratio=True)
    saliency = imported(input_tensor)["output"]

    saliency = tf.image.resize(saliency, original_shape)
    return saliency.numpy()[0]
requirements.txt
ADDED
@@ -0,0 +1,7 @@
fastapi==0.103.2
uvicorn==0.23.2
tensorflow
python-multipart
Pillow
streamlit
matplotlib
saved_model.pb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:646e0f343c4357e828f2569bef2f2bf288449fe68f7e4fb43e076f2e3b094e3d
size 99858975
streamlit_viz.py
ADDED
@@ -0,0 +1,71 @@
"""App to visualize saliency maps for images.
To run, use:
streamlit run streamlit_viz.py
"""

import streamlit as st
import numpy as np
import requests
import hmac
import json
import matplotlib.pyplot as plt

from PIL import Image

st.set_option('deprecation.showPyplotGlobalUse', False)

def check_password():
    """Returns `True` if the user had the correct password."""

    def password_entered():
        """Checks whether a password entered by the user is correct."""
        if hmac.compare_digest(st.session_state["password"], st.secrets["password"]):
            st.session_state["password_correct"] = True
            del st.session_state["password"]  # Don't store the password.
        else:
            st.session_state["password_correct"] = False

    # Return True if the password is validated.
    if st.session_state.get("password_correct", False):
        return True

    # Show input for password.
    st.text_input(
        "Password", type="password", on_change=password_entered, key="password"
    )
    if "password_correct" in st.session_state:
        st.error("😕 Password incorrect")
    return False


if not check_password():
    st.stop()  # Do not continue if check_password is not True.

st.title("Saliency Map Visualizer")

st.markdown(
    """
    This is a demo of the Saliency Map Visualizer. To use it, upload an image
    below. Please note, it may take up to 20 seconds to visualise.
    """
)

# Get the API host from secrets.
api_host = st.secrets["api_host"]

uploaded_file = st.file_uploader("Choose an image...", type=(["jpg", "jpeg", "png"]))

if uploaded_file is not None:
    file = {'file': uploaded_file.read()}
    st.write("")
    st.write("Running inference...")
    response = requests.post(api_host, files=file)
    # The API double-encodes the prediction, so decode the body twice.
    arr = np.asarray(json.loads(response.json()))
    st.write("Done!")
    # Overlay the saliency map on the uploaded image.
    plt.imshow(Image.open(uploaded_file))
    # Squeeze out the trailing channel axis so imshow accepts the array.
    plt.imshow(np.squeeze(arr), alpha=0.6)
    plt.axis('off')
    st.pyplot()