Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import tensorflow as tf
|
5 |
+
import tensorflow.keras as keras
|
6 |
+
from gradio import inputs, outputs
|
7 |
+
|
8 |
+
SIZE = 256
|
9 |
+
DEVICE = "/CPU:0"
|
10 |
+
|
11 |
+
|
12 |
+
with open("./tags.json", "rt", encoding="utf-8") as f:
|
13 |
+
tags = json.load(f)
|
14 |
+
|
15 |
+
|
16 |
+
with tf.device(DEVICE):
|
17 |
+
base_model = keras.applications.resnet.ResNet50(
|
18 |
+
include_top=False, weights=None, input_shape=(SIZE, SIZE, 3)
|
19 |
+
)
|
20 |
+
model = keras.Sequential(
|
21 |
+
[
|
22 |
+
base_model,
|
23 |
+
keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
|
24 |
+
keras.layers.BatchNormalization(epsilon=1.001e-5),
|
25 |
+
keras.layers.GlobalAveragePooling2D(name="avg_pool"),
|
26 |
+
keras.layers.Activation("sigmoid"),
|
27 |
+
]
|
28 |
+
)
|
29 |
+
model.load_weights("tf_model.h5")
|
30 |
+
|
31 |
+
|
32 |
+
@tf.function
|
33 |
+
def process_data(content):
|
34 |
+
img = tf.io.decode_jpeg(content, channels=3)
|
35 |
+
img = tf.image.resize_with_pad(img, SIZE, SIZE)
|
36 |
+
img = tf.image.per_image_standardization(img)
|
37 |
+
return img
|
38 |
+
|
39 |
+
|
40 |
+
def predict(img, size):
|
41 |
+
with tf.device(DEVICE):
|
42 |
+
img = tf.image.resize_with_pad(img, size, size)
|
43 |
+
img = tf.image.per_image_standardization(img)
|
44 |
+
data = process_data(image)
|
45 |
+
data = tf.expand_dims(data, 0)
|
46 |
+
out = model(data)[0]
|
47 |
+
return dict((tags[i], out[i].numpy()) for i in range(len(tags)))
|
48 |
+
|
49 |
+
|
50 |
+
image = inputs.Image(label="Upload your image here!")
|
51 |
+
size = inputs.Number(label="Image resize", default=SIZE)
|
52 |
+
|
53 |
+
labels = outputs.Label(label="Tags")
|
54 |
+
|
55 |
+
gr.Interface(predict, inputs=[image, size], outputs=[labels])
|