vchiang001 commited on
Commit
ed73811
·
1 Parent(s): dffdb39

Create new file

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import from_pretrained_keras
2
+ import gradio as gr
3
+ import tensorflow as tf
4
+ import numpy as np
5
+
6
+ model = from_pretrained_keras("keras-io/semantic-segmentation")
7
+
8
+ inputs = gr.inputs.Image()
9
+ output = gr.output.Image()
10
+
11
+
12
+ def predict(image_input):
13
+ pass
14
+
15
+ class PreTrainedPipeline():
16
+ def __init__(self, path: str):
17
+ # load the model
18
+ self.model = keras.models.load_model(os.path.join(path, "tf_model.h5"))
19
+
20
+ def __call__(self, inputs: "Image.Image")-> List[Dict[str, Any]]:
21
+
22
+ # convert img to numpy array, resize and normalize to make the prediction
23
+ img = np.array(inputs)
24
+
25
+ im = tf.image.resize(img, (128, 128))
26
+ im = tf.cast(im, tf.float32) / 255.0
27
+ pred_mask = self.model.predict(im[tf.newaxis, ...])
28
+
29
+ # take the best performing class for each pixel
30
+ # the output of argmax looks like this [[1, 2, 0], ...]
31
+ pred_mask_arg = tf.argmax(pred_mask, axis=-1)
32
+
33
+ labels = []
34
+
35
+ # convert the prediction mask into binary masks for each class
36
+ binary_masks = {}
37
+ mask_codes = {}
38
+
39
+ # when we take tf.argmax() over pred_mask, it becomes a tensor object
40
+ # the shape becomes TensorShape object, looking like this TensorShape([128])
41
+ # we need to take get shape, convert to list and take the best one
42
+
43
+ rows = pred_mask_arg[0][1].get_shape().as_list()[0]
44
+ cols = pred_mask_arg[0][2].get_shape().as_list()[0]
45
+
46
+ for cls in range(pred_mask.shape[-1]):
47
+
48
+ binary_masks[f"mask_{cls}"] = np.zeros(shape = (pred_mask.shape[1], pred_mask.shape[2])) #create masks for each class
49
+
50
+ for row in range(rows):
51
+
52
+ for col in range(cols):
53
+
54
+ if pred_mask_arg[0][row][col] == cls:
55
+
56
+ binary_masks[f"mask_{cls}"][row][col] = 1
57
+ else:
58
+ binary_masks[f"mask_{cls}"][row][col] = 0
59
+
60
+ mask = binary_masks[f"mask_{cls}"]
61
+ mask *= 255
62
+ img = Image.fromarray(mask.astype(np.int8), mode="L")
63
+
64
+ # we need to make it readable for the widget
65
+ with io.BytesIO() as out:
66
+ img.save(out, format="PNG")
67
+ png_string = out.getvalue()
68
+ mask = base64.b64encode(png_string).decode("utf-8")
69
+
70
+ mask_codes[f"mask_{cls}"] = mask
71
+