TharunSiva commited on
Commit
8413e92
·
verified ·
1 Parent(s): 321a92c

application file

Browse files
Files changed (1) hide show
  1. app.py +221 -0
app.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ import matplotlib.pyplot as plt
5
+
6
+ import tensorflow as tf
7
+ import tensorflow.keras.backend as K
8
+ from keras.preprocessing import image
9
+
10
+ from ResUNet import *
11
+
12
+ from eff import *
13
+ from vit import *
14
+
15
+ # Define the image transformation
16
+ transform = transforms.Compose([
17
+ transforms.ToTensor(),
18
+ transforms.Resize((224, 224)),
19
+ ])
20
+
21
+ examples1 = [
22
+ f"examples/Eff_ViT/Classification_{i}.jpg" for i in range(0, 4)
23
+ ]
24
+
25
+ def classification(image):
26
+ input_tensor = transform(image).unsqueeze(0).to(CFG.DEVICE) # Add batch dimension
27
+
28
+ input_batch = input_tensor
29
+
30
+ # Perform inference
31
+ with torch.no_grad():
32
+ output1 = efficientnet_model(input_batch).to(CFG.DEVICE)
33
+ output2 = efficientnet_model(input_batch).to(CFG.DEVICE)
34
+ output3 = vit_model(input_batch).to(CFG.DEVICE)
35
+
36
+ # You can now use the 'output' tensor as needed (e.g., get predictions)
37
+ # print(output)
38
+ res1 = torch.softmax(output1, dim=1)
39
+ res2 = torch.softmax(output2, dim=1)
40
+ res3 = torch.softmax(output3, dim=1)
41
+
42
+ probs1 = {class_names[i]: float(res1[0][i]) for i in range(len(class_names))}
43
+ probs2 = {class_names[i]: float(res2[0][i]) for i in range(len(class_names))}
44
+ probs3 = {class_names[i]: float(res3[0][i]) for i in range(len(class_names))}
45
+
46
+ return probs1, probs2, probs3
47
+
48
+
49
+ classify = gr.Interface(
50
+ fn=classification,
51
+ inputs=[
52
+ gr.Image(label="Image"),
53
+ # gr.Radio(["EfficientNetB3", "EfficientNetV2", "ViT"], value="ViT")
54
+ ],
55
+ outputs=[
56
+ gr.Label(num_top_classes = 3, label = "EfficientNet-B3"),
57
+ gr.Label(num_top_classes = 3, label = "EfficientNet-V2"),
58
+ gr.Label(num_top_classes = 3, label = "ViT"),
59
+ ],
60
+ examples=examples1,
61
+ cache_examples=True
62
+ )
63
+
64
+ # ---------------------------------------------------------
65
+
66
+ seg_model = load_model()
67
+ seg_model.load_weights("ResUNet-segModel-weights.hdf5")
68
+
69
+
70
+ examples2 = [
71
+ f"examples/ResUNet/{i}.jpg" for i in range(5)
72
+ ]
73
+
74
+ def detection(img):
75
+ org_img = img
76
+
77
+ img = img *1./255.
78
+
79
+ #reshaping
80
+ img = cv2.resize(img, (256,256))
81
+
82
+ # converting img into array
83
+ img = np.array(img, dtype=np.float64)
84
+
85
+ #reshaping the image from 256,256,3 to 1,256,256,3
86
+ img = np.reshape(img, (1,256,256,3))
87
+
88
+
89
+ #Creating a empty array of shape 1,256,256,1
90
+ X = np.empty((1,256,256,3))
91
+
92
+ # standardising the image
93
+ img -= img.mean()
94
+ img /= img.std()
95
+
96
+ #converting the shape of image from 256,256,3 to 1,256,256,3
97
+ X[0,] = img
98
+
99
+ #make prediction of mask
100
+ predict = seg_model.predict(X)
101
+
102
+
103
+ pred = np.array(predict[0]).squeeze().round()
104
+
105
+
106
+ img_ = cv2.resize(org_img, (256,256))
107
+ img_ = cv2.cvtColor(img_, cv2.COLOR_BGR2RGB)
108
+ img_[pred==1] = (0,255,150)
109
+
110
+ plt.imshow(img_)
111
+ plt.axis("off")
112
+ image_path = "plot.png"
113
+ plt.savefig(image_path)
114
+
115
+ return gr.update(value=image_path, visible=True)
116
+
117
+
118
+ detect = gr.Interface(
119
+ fn=detection,
120
+ inputs=[
121
+ gr.Image(label="Image")
122
+ ],
123
+ outputs=[
124
+ gr.Image(label="Output")
125
+ ],
126
+ examples=examples2,
127
+ cache_examples=True
128
+ )
129
+
130
+ # ##########################################
131
+
132
+ # def data_viewer(label="Pituitary", count=10):
133
+ # results = []
134
+
135
+ # if(label == "Segmentation"):
136
+ # for i in range((count//2)+1):
137
+ # results.append(f"Images/{label}/original_image_{i}.png")
138
+ # results.append(f"Images/{label}/image_with_mask_{i}.png")
139
+
140
+ # else:
141
+
142
+ # for i in range(count):
143
+ # results.append(f"Images/{label}/{i}.jpg")
144
+
145
+ # return results
146
+
147
+
148
+ # view_data = gr.Interface(
149
+ # fn = data_viewer,
150
+ # inputs = [
151
+ # gr.Dropdown(
152
+ # ["Glioma", "Meningioma", "Pituitary", "Segmentation"], label="Category"
153
+ # ),
154
+ # gr.Slider(0, 12, value=4, step=2)
155
+ # ],
156
+ # outputs = [
157
+ # gr.Gallery(columns=2),
158
+ # ]
159
+ # )
160
+
161
+ # ##########################
162
+
163
+ from huggingface_hub import InferenceClient
164
+
165
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
166
+
167
+ def format_prompt(message, history):
168
+ prompt = "<s>"
169
+ for user_prompt, bot_response in history:
170
+ prompt += f"[INST] {user_prompt} [/INST]"
171
+ prompt += f" {bot_response}</s> "
172
+ prompt += f"[INST] {message} [/INST]"
173
+ return prompt
174
+
175
+ def generate(
176
+ prompt, history, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
177
+ ):
178
+ temperature = float(temperature)
179
+ if temperature < 1e-2:
180
+ temperature = 1e-2
181
+ top_p = float(top_p)
182
+
183
+ generate_kwargs = dict(
184
+ temperature=temperature,
185
+ max_new_tokens=max_new_tokens,
186
+ top_p=top_p,
187
+ repetition_penalty=repetition_penalty,
188
+ do_sample=True,
189
+ seed=42,
190
+ )
191
+
192
+ formatted_prompt = format_prompt(prompt, history)
193
+
194
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
195
+ output = ""
196
+
197
+ for response in stream:
198
+ output += response.token.text
199
+ yield output
200
+ return output
201
+
202
+
203
+ mychatbot = gr.Chatbot(
204
+ avatar_images=["Chatbot/user.png", "Chatbot/botm.png"], bubble_full_width=False, show_label=False, show_copy_button=True, likeable=True,)
205
+
206
+ chatbot = gr.ChatInterface(
207
+ fn=generate,
208
+ chatbot=mychatbot,
209
+ examples=[
210
+ "What is Brain Tumor and its types?",
211
+ "What is a tumor's grade? What does this mean?",
212
+ "What are some of the treatment options for Brain Tumor?",
213
+ "What causes brain tumors?",
214
+ "If I have a brain tumor, can I pass it on to my children?"
215
+ ],
216
+ )
217
+
218
+
219
+ demo = gr.TabbedInterface([classify, detect, chatbot], ["Classification", "Detection", "ChatBot"])
220
+
221
+ demo.launch()