mayhug commited on
Commit
6aca538
·
1 Parent(s): 4951d64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -31
app.py CHANGED
@@ -6,44 +6,45 @@ import tensorflow.keras as keras
6
  from gradio import inputs, outputs
7
 
8
  SIZE = 256
 
 
9
 
10
  with open("./tags.json", "rt", encoding="utf-8") as f:
11
  tags = json.load(f)
12
 
13
 
14
- base_model = keras.applications.resnet.ResNet50(
15
- include_top=False, weights=None, input_shape=(SIZE, SIZE, 3)
16
- )
17
- model = keras.Sequential(
18
- [
19
- base_model,
20
- keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
21
- keras.layers.BatchNormalization(epsilon=1.001e-5),
22
- keras.layers.GlobalAveragePooling2D(name="avg_pool"),
23
- keras.layers.Activation("sigmoid"),
24
- ]
25
- )
26
- model.load_weights("tf_model.h5")
27
-
28
- @tf.function
29
- def process_data(content):
30
- img = tf.io.decode_jpeg(content, channels=3)
31
- img = tf.image.resize_with_pad(img, SIZE, SIZE)
32
- img = tf.image.per_image_standardization(img)
33
- return img
34
-
35
-
36
- def predict(img, size):
37
- img = tf.image.resize_with_pad(img, size, size)
38
- img = tf.image.per_image_standardization(img)
39
- data = tf.expand_dims(img, 0)
40
- out,*_ = model(data)
41
- return dict((tags[i], out[i].numpy()) for i in range(len(tags)))
42
 
43
 
44
  image = inputs.Image(label="Upload your image here!")
45
- size = inputs.Number(label="Image resize", default=SIZE)
 
 
46
 
47
- labels = outputs.Label(label="Tags")
48
 
49
- gr.Interface(predict, inputs=[image, size], outputs=[labels])
 
 
6
  from gradio import inputs, outputs
7
 
8
  SIZE = 256
9
+ DEVICE = "/cpu:0"
10
+
11
 
12
  with open("./tags.json", "rt", encoding="utf-8") as f:
13
  tags = json.load(f)
14
 
15
 
16
+ with tf.device(DEVICE):
17
+ base_model = keras.applications.resnet.ResNet50(
18
+ include_top=False, weights=None, input_shape=(SIZE, SIZE, 3)
19
+ )
20
+ model = keras.Sequential(
21
+ [
22
+ base_model,
23
+ keras.layers.Conv2D(filters=len(tags), kernel_size=(1, 1), padding="same"),
24
+ keras.layers.BatchNormalization(epsilon=1.001e-5),
25
+ keras.layers.GlobalAveragePooling2D(name="avg_pool"),
26
+ keras.layers.Activation("sigmoid"),
27
+ ]
28
+ )
29
+ model.load_weights("tf_model.h5")
30
+
31
+
32
+ def predict(img, hide: float):
33
+ with tf.device(DEVICE):
34
+ img = tf.image.resize_with_pad(img, SIZE, SIZE)
35
+ img = tf.image.per_image_standardization(img)
36
+ data = tf.expand_dims(img, 0)
37
+ out, *_ = model(data)
38
+ labels = {tag: float(out[i].numpy()) for i, tag in enumerate(tags)}
39
+ return {k: v for k, v in labels.items() if v >= hide}
 
 
 
 
40
 
41
 
42
  image = inputs.Image(label="Upload your image here!")
43
+ hide_threshold = inputs.Slider(
44
+ label="Hide confidence lower than", default=0.5, maximum=1, minimum=0
45
+ )
46
 
47
+ labels = outputs.Label(label="Tags", type="confidences")
48
 
49
+ interface = gr.Interface(predict, inputs=[image, hide_threshold], outputs=[labels])
50
+ interface.launch()