Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -6,28 +6,24 @@ import tensorflow.keras as keras
|
|
6 |
from gradio import inputs, outputs
|
7 |
|
8 |
SIZE = 256
|
9 |
-
DEVICE = "/CPU:0"
|
10 |
-
|
11 |
|
12 |
with open("./tags.json", "rt", encoding="utf-8") as f:
|
13 |
tags = json.load(f)
|
14 |
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
model.load_weights("tf_model.h5")
|
30 |
-
|
31 |
|
32 |
@tf.function
|
33 |
def process_data(content):
|
@@ -38,12 +34,10 @@ def process_data(content):
|
|
38 |
|
39 |
|
40 |
def predict(img, size):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
data = tf.expand_dims(data, 0)
|
46 |
-
out = model(data)[0]
|
47 |
return dict((tags[i], out[i].numpy()) for i in range(len(tags)))
|
48 |
|
49 |
|
|
|
6 |
from gradio import inputs, outputs
|
7 |
|
8 |
SIZE = 256
|
|
|
|
|
9 |
|
10 |
with open("./tags.json", "rt", encoding="utf-8") as f:
|
11 |
tags = json.load(f)
|
12 |
|
13 |
|
14 |
+
base_model = keras.applications.resnet.ResNet50(
|
15 |
+
include_top=False, weights=None, input_shape=(SIZE, SIZE, 3)
|
16 |
+
)
|
17 |
+
model = keras.Sequential(
|
18 |
+
[
|
19 |
+
base_model,
|
20 |
+
keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
|
21 |
+
keras.layers.BatchNormalization(epsilon=1.001e-5),
|
22 |
+
keras.layers.GlobalAveragePooling2D(name="avg_pool"),
|
23 |
+
keras.layers.Activation("sigmoid"),
|
24 |
+
]
|
25 |
+
)
|
26 |
+
model.load_weights("tf_model.h5")
|
|
|
|
|
27 |
|
28 |
@tf.function
|
29 |
def process_data(content):
|
|
|
34 |
|
35 |
|
36 |
def predict(img, size):
|
37 |
+
img = tf.image.resize_with_pad(img, size, size)
|
38 |
+
img = tf.image.per_image_standardization(img)
|
39 |
+
data = tf.expand_dims(img, 0)
|
40 |
+
out,*_ = model(data)
|
|
|
|
|
41 |
return dict((tags[i], out[i].numpy()) for i in range(len(tags)))
|
42 |
|
43 |
|