chriskok3 commited on
Commit
d264abb
Β·
verified Β·
1 Parent(s): 7519712

Reverted app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -222
app.py CHANGED
@@ -1,266 +1,266 @@
1
- # load up the libraries
2
- import panel as pn
3
- import pandas as pd
4
- import altair as alt
5
- from vega_datasets import data
6
 
7
- # we want to use bootstrap/template, tell Panel to load up what we need
8
- pn.extension(design='bootstrap')
9
 
10
- # we want to use vega, tell Panel to load up what we need
11
- pn.extension('vega')
12
 
13
- # create a basic template using bootstrap
14
- template = pn.template.BootstrapTemplate(
15
- title='SI649 Walkthrough',
16
- )
17
 
18
- # the main column will hold our key content
19
- maincol = pn.Column()
20
 
21
- # add some markdown to the main column
22
- maincol.append("# Markdown Title")
23
- maincol.append("I can format in cool ways. Like **bold** or *italics* or ***both*** or ~~strikethrough~~ or `code` or [links](https://panel.holoviz.org)")
24
- maincol.append("I am writing a link [to the streamlit documentation page](https://docs.streamlit.io/en/stable/api.html)")
25
- maincol.append('![alt text](https://upload.wikimedia.org/wikipedia/commons/thumb/3/3e/Irises-Vincent_van_Gogh.jpg/314px-Irises-Vincent_van_Gogh.jpg)')
26
 
27
- # load up a dataframe and show it in the main column
28
- cars_url = "https://raw.githubusercontent.com/altair-viz/vega_datasets/master/vega_datasets/_data/cars.json"
29
- cars = pd.read_json(cars_url)
30
- temps = data.seattle_weather()
31
 
32
- maincol.append(temps.head(10))
33
 
34
- # create a basic chart
35
- hp_mpg = alt.Chart(cars).mark_circle(size=80).encode(
36
- x='Horsepower:Q',
37
- y='Miles_per_Gallon:Q',
38
- color='Origin:N'
39
- )
40
 
41
- # dispaly it in the main column
42
- # maincol.append(hp_mpg)
43
 
44
- # create a basic slider
45
- simpleslider = pn.widgets.IntSlider(name='Simple Slider', start=0, end=100, value=0)
46
 
47
- # generate text based on slider value
48
- def square(x):
49
- return f'{x} squared is {x**2}'
50
 
51
 
52
- # bind the slider to the function and hold the output in a row
53
- row = pn.Column(pn.bind(square,simpleslider))
54
 
55
- # add both slider and row
56
- maincol.append(simpleslider)
57
- maincol.append(row)
58
 
59
- # variable to track state of visualization
60
- flip = False
61
 
62
- # function to either return the vis or a message
63
- def makeChartVisible(val):
64
- global flip # grab the variable outside the function
65
- if (flip == True):
66
- flip = not flip # flip to False
67
- return pn.pane.Vega(hp_mpg) # return the vis
68
- else:
69
- flip = not flip # flip to true and return text
70
- return pn.panel("Click the button to see the chart")
71
 
72
- # add a button and then create the binding
73
- btn = pn.widgets.Button(name='Click me')
74
- row = pn.Row(pn.bind(makeChartVisible, btn))
75
-
76
- # add button and new row to main column
77
- maincol.append(btn)
78
- maincol.append(row)
79
-
80
- # create a base chart
81
- basechart = alt.Chart(cars).mark_circle(size=80,opacity=0.5).encode(
82
- x='Horsepower:Q',
83
- y='Acceleration:Q',
84
- color="Origin:N"
85
- )
86
 
87
- # create something to hold the base chart
88
- currentoption = pn.panel(basechart)
89
 
90
- # create a selection widget
91
- select = pn.widgets.Select(name='Select', options=['Horsepower','Acceleration','Miles_per_Gallon'])
92
 
93
- # create a function to modify the basechart that is being
94
- # held in currentoption
95
- def changeOption(val):
96
- # grab what's there now
97
- chrt = currentoption.object
98
- # change the encoding based on val
99
- chrt = chrt.encode(
100
- y=val+":Q"
101
- )
102
- # replace old chart in currentoption with new one
103
- currentoption.object = chrt
104
 
105
- # append the selection
106
- maincol.append(select)
107
- # append the binding (in thise case nothing is being returned by changeOption, so...)
108
- chartchange = pn.Row(pn.bind(changeOption, select))
109
- # ... we need to also add the chart
110
- maincol.append(chartchange)
111
- maincol.append(currentoption)
112
 
113
- # add the main column to the template
114
- template.main.append(maincol)
115
 
116
- # Indicate that the template object is the "application" and serve it
117
- template.servable(title="SI649 Walkthrough")
118
 
119
 
120
- # import io
121
- # import random
122
- # from typing import List, Tuple
123
 
124
- # import aiohttp
125
- # import panel as pn
126
- # from PIL import Image
127
- # from transformers import CLIPModel, CLIPProcessor
128
 
129
- # pn.extension(design="bootstrap", sizing_mode="stretch_width")
130
 
131
- # ICON_URLS = {
132
- # "brand-github": "https://github.com/holoviz/panel",
133
- # "brand-twitter": "https://twitter.com/Panel_Org",
134
- # "brand-linkedin": "https://www.linkedin.com/company/panel-org",
135
- # "message-circle": "https://discourse.holoviz.org/",
136
- # "brand-discord": "https://discord.gg/AXRHnJU6sP",
137
- # }
138
 
139
 
140
- # async def random_url(_):
141
- # pet = random.choice(["cat", "dog"])
142
- # api_url = f"https://api.the{pet}api.com/v1/images/search"
143
- # async with aiohttp.ClientSession() as session:
144
- # async with session.get(api_url) as resp:
145
- # return (await resp.json())[0]["url"]
146
 
147
 
148
- # @pn.cache
149
- # def load_processor_model(
150
- # processor_name: str, model_name: str
151
- # ) -> Tuple[CLIPProcessor, CLIPModel]:
152
- # processor = CLIPProcessor.from_pretrained(processor_name)
153
- # model = CLIPModel.from_pretrained(model_name)
154
- # return processor, model
155
 
156
 
157
- # async def open_image_url(image_url: str) -> Image:
158
- # async with aiohttp.ClientSession() as session:
159
- # async with session.get(image_url) as resp:
160
- # return Image.open(io.BytesIO(await resp.read()))
161
 
162
 
163
- # def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
164
- # processor, model = load_processor_model(
165
- # "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
166
- # )
167
- # inputs = processor(
168
- # text=class_items,
169
- # images=[image],
170
- # return_tensors="pt", # pytorch tensors
171
- # )
172
- # outputs = model(**inputs)
173
- # logits_per_image = outputs.logits_per_image
174
- # class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
175
- # return class_likelihoods[0]
176
-
177
-
178
- # async def process_inputs(class_names: List[str], image_url: str):
179
- # """
180
- # High level function that takes in the user inputs and returns the
181
- # classification results as panel objects.
182
- # """
183
- # try:
184
- # main.disabled = True
185
- # if not image_url:
186
- # yield "##### ⚠️ Provide an image URL"
187
- # return
188
 
189
- # yield "##### βš™ Fetching image and running model..."
190
- # try:
191
- # pil_img = await open_image_url(image_url)
192
- # img = pn.pane.Image(pil_img, height=400, align="center")
193
- # except Exception as e:
194
- # yield f"##### πŸ˜” Something went wrong, please try a different URL!"
195
- # return
196
 
197
- # class_items = class_names.split(",")
198
- # class_likelihoods = get_similarity_scores(class_items, pil_img)
199
 
200
- # # build the results column
201
- # results = pn.Column("##### πŸŽ‰ Here are the results!", img)
202
 
203
- # for class_item, class_likelihood in zip(class_items, class_likelihoods):
204
- # row_label = pn.widgets.StaticText(
205
- # name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
206
- # )
207
- # row_bar = pn.indicators.Progress(
208
- # value=int(class_likelihood * 100),
209
- # sizing_mode="stretch_width",
210
- # bar_color="secondary",
211
- # margin=(0, 10),
212
- # design=pn.theme.Material,
213
- # )
214
- # results.append(pn.Column(row_label, row_bar))
215
- # yield results
216
- # finally:
217
- # main.disabled = False
218
-
219
-
220
- # # create widgets
221
- # randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
222
-
223
- # image_url = pn.widgets.TextInput(
224
- # name="Image URL to classify",
225
- # value=pn.bind(random_url, randomize_url),
226
- # )
227
- # class_names = pn.widgets.TextInput(
228
- # name="Comma separated class names",
229
- # placeholder="Enter possible class names, e.g. cat, dog",
230
- # value="cat, dog, parrot",
231
- # )
232
 
233
- # input_widgets = pn.Column(
234
- # "##### 😊 Click randomize or paste a URL to start classifying!",
235
- # pn.Row(image_url, randomize_url),
236
- # class_names,
237
- # )
238
 
239
- # # add interactivity
240
- # interactive_result = pn.panel(
241
- # pn.bind(process_inputs, image_url=image_url, class_names=class_names),
242
- # height=600,
243
- # )
244
 
245
- # # add footer
246
- # footer_row = pn.Row(pn.Spacer(), align="center")
247
- # for icon, url in ICON_URLS.items():
248
- # href_button = pn.widgets.Button(icon=icon, width=35, height=35)
249
- # href_button.js_on_click(code=f"window.open('{url}')")
250
- # footer_row.append(href_button)
251
- # footer_row.append(pn.Spacer())
252
-
253
- # # create dashboard
254
- # main = pn.WidgetBox(
255
- # input_widgets,
256
- # interactive_result,
257
- # footer_row,
258
- # )
259
 
260
- # title = "Panel Demo - Image Classification"
261
- # pn.template.BootstrapTemplate(
262
- # title=title,
263
- # main=main,
264
- # main_max_width="min(50%, 698px)",
265
- # header_background="#F08080",
266
- # ).servable(title=title)
 
1
+ # # load up the libraries
2
+ # import panel as pn
3
+ # import pandas as pd
4
+ # import altair as alt
5
+ # from vega_datasets import data
6
 
7
+ # # we want to use bootstrap/template, tell Panel to load up what we need
8
+ # pn.extension(design='bootstrap')
9
 
10
+ # # we want to use vega, tell Panel to load up what we need
11
+ # pn.extension('vega')
12
 
13
+ # # create a basic template using bootstrap
14
+ # template = pn.template.BootstrapTemplate(
15
+ # title='SI649 Walkthrough',
16
+ # )
17
 
18
+ # # the main column will hold our key content
19
+ # maincol = pn.Column()
20
 
21
+ # # add some markdown to the main column
22
+ # maincol.append("# Markdown Title")
23
+ # maincol.append("I can format in cool ways. Like **bold** or *italics* or ***both*** or ~~strikethrough~~ or `code` or [links](https://panel.holoviz.org)")
24
+ # maincol.append("I am writing a link [to the streamlit documentation page](https://docs.streamlit.io/en/stable/api.html)")
25
+ # maincol.append('![alt text](https://upload.wikimedia.org/wikipedia/commons/thumb/3/3e/Irises-Vincent_van_Gogh.jpg/314px-Irises-Vincent_van_Gogh.jpg)')
26
 
27
+ # # load up a dataframe and show it in the main column
28
+ # cars_url = "https://raw.githubusercontent.com/altair-viz/vega_datasets/master/vega_datasets/_data/cars.json"
29
+ # cars = pd.read_json(cars_url)
30
+ # temps = data.seattle_weather()
31
 
32
+ # maincol.append(temps.head(10))
33
 
34
+ # # create a basic chart
35
+ # hp_mpg = alt.Chart(cars).mark_circle(size=80).encode(
36
+ # x='Horsepower:Q',
37
+ # y='Miles_per_Gallon:Q',
38
+ # color='Origin:N'
39
+ # )
40
 
41
+ # # dispaly it in the main column
42
+ # # maincol.append(hp_mpg)
43
 
44
+ # # create a basic slider
45
+ # simpleslider = pn.widgets.IntSlider(name='Simple Slider', start=0, end=100, value=0)
46
 
47
+ # # generate text based on slider value
48
+ # def square(x):
49
+ # return f'{x} squared is {x**2}'
50
 
51
 
52
+ # # bind the slider to the function and hold the output in a row
53
+ # row = pn.Column(pn.bind(square,simpleslider))
54
 
55
+ # # add both slider and row
56
+ # maincol.append(simpleslider)
57
+ # maincol.append(row)
58
 
59
+ # # variable to track state of visualization
60
+ # flip = False
61
 
62
+ # # function to either return the vis or a message
63
+ # def makeChartVisible(val):
64
+ # global flip # grab the variable outside the function
65
+ # if (flip == True):
66
+ # flip = not flip # flip to False
67
+ # return pn.pane.Vega(hp_mpg) # return the vis
68
+ # else:
69
+ # flip = not flip # flip to true and return text
70
+ # return pn.panel("Click the button to see the chart")
71
 
72
+ # # add a button and then create the binding
73
+ # btn = pn.widgets.Button(name='Click me')
74
+ # row = pn.Row(pn.bind(makeChartVisible, btn))
75
+
76
+ # # add button and new row to main column
77
+ # maincol.append(btn)
78
+ # maincol.append(row)
79
+
80
+ # # create a base chart
81
+ # basechart = alt.Chart(cars).mark_circle(size=80,opacity=0.5).encode(
82
+ # x='Horsepower:Q',
83
+ # y='Acceleration:Q',
84
+ # color="Origin:N"
85
+ # )
86
 
87
+ # # create something to hold the base chart
88
+ # currentoption = pn.panel(basechart)
89
 
90
+ # # create a selection widget
91
+ # select = pn.widgets.Select(name='Select', options=['Horsepower','Acceleration','Miles_per_Gallon'])
92
 
93
+ # # create a function to modify the basechart that is being
94
+ # # held in currentoption
95
+ # def changeOption(val):
96
+ # # grab what's there now
97
+ # chrt = currentoption.object
98
+ # # change the encoding based on val
99
+ # chrt = chrt.encode(
100
+ # y=val+":Q"
101
+ # )
102
+ # # replace old chart in currentoption with new one
103
+ # currentoption.object = chrt
104
 
105
+ # # append the selection
106
+ # maincol.append(select)
107
+ # # append the binding (in thise case nothing is being returned by changeOption, so...)
108
+ # chartchange = pn.Row(pn.bind(changeOption, select))
109
+ # # ... we need to also add the chart
110
+ # maincol.append(chartchange)
111
+ # maincol.append(currentoption)
112
 
113
+ # # add the main column to the template
114
+ # template.main.append(maincol)
115
 
116
+ # # Indicate that the template object is the "application" and serve it
117
+ # template.servable(title="SI649 Walkthrough")
118
 
119
 
120
+ import io
121
+ import random
122
+ from typing import List, Tuple
123
 
124
+ import aiohttp
125
+ import panel as pn
126
+ from PIL import Image
127
+ from transformers import CLIPModel, CLIPProcessor
128
 
129
+ pn.extension(design="bootstrap", sizing_mode="stretch_width")
130
 
131
+ ICON_URLS = {
132
+ "brand-github": "https://github.com/holoviz/panel",
133
+ "brand-twitter": "https://twitter.com/Panel_Org",
134
+ "brand-linkedin": "https://www.linkedin.com/company/panel-org",
135
+ "message-circle": "https://discourse.holoviz.org/",
136
+ "brand-discord": "https://discord.gg/AXRHnJU6sP",
137
+ }
138
 
139
 
140
+ async def random_url(_):
141
+ pet = random.choice(["cat", "dog"])
142
+ api_url = f"https://api.the{pet}api.com/v1/images/search"
143
+ async with aiohttp.ClientSession() as session:
144
+ async with session.get(api_url) as resp:
145
+ return (await resp.json())[0]["url"]
146
 
147
 
148
+ @pn.cache
149
+ def load_processor_model(
150
+ processor_name: str, model_name: str
151
+ ) -> Tuple[CLIPProcessor, CLIPModel]:
152
+ processor = CLIPProcessor.from_pretrained(processor_name)
153
+ model = CLIPModel.from_pretrained(model_name)
154
+ return processor, model
155
 
156
 
157
+ async def open_image_url(image_url: str) -> Image:
158
+ async with aiohttp.ClientSession() as session:
159
+ async with session.get(image_url) as resp:
160
+ return Image.open(io.BytesIO(await resp.read()))
161
 
162
 
163
+ def get_similarity_scores(class_items: List[str], image: Image) -> List[float]:
164
+ processor, model = load_processor_model(
165
+ "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
166
+ )
167
+ inputs = processor(
168
+ text=class_items,
169
+ images=[image],
170
+ return_tensors="pt", # pytorch tensors
171
+ )
172
+ outputs = model(**inputs)
173
+ logits_per_image = outputs.logits_per_image
174
+ class_likelihoods = logits_per_image.softmax(dim=1).detach().numpy()
175
+ return class_likelihoods[0]
176
+
177
+
178
+ async def process_inputs(class_names: List[str], image_url: str):
179
+ """
180
+ High level function that takes in the user inputs and returns the
181
+ classification results as panel objects.
182
+ """
183
+ try:
184
+ main.disabled = True
185
+ if not image_url:
186
+ yield "##### ⚠️ Provide an image URL"
187
+ return
188
 
189
+ yield "##### βš™ Fetching image and running model..."
190
+ try:
191
+ pil_img = await open_image_url(image_url)
192
+ img = pn.pane.Image(pil_img, height=400, align="center")
193
+ except Exception as e:
194
+ yield f"##### πŸ˜” Something went wrong, please try a different URL!"
195
+ return
196
 
197
+ class_items = class_names.split(",")
198
+ class_likelihoods = get_similarity_scores(class_items, pil_img)
199
 
200
+ # build the results column
201
+ results = pn.Column("##### πŸŽ‰ Here are the results!", img)
202
 
203
+ for class_item, class_likelihood in zip(class_items, class_likelihoods):
204
+ row_label = pn.widgets.StaticText(
205
+ name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
206
+ )
207
+ row_bar = pn.indicators.Progress(
208
+ value=int(class_likelihood * 100),
209
+ sizing_mode="stretch_width",
210
+ bar_color="secondary",
211
+ margin=(0, 10),
212
+ design=pn.theme.Material,
213
+ )
214
+ results.append(pn.Column(row_label, row_bar))
215
+ yield results
216
+ finally:
217
+ main.disabled = False
218
+
219
+
220
+ # create widgets
221
+ randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
222
+
223
+ image_url = pn.widgets.TextInput(
224
+ name="Image URL to classify",
225
+ value=pn.bind(random_url, randomize_url),
226
+ )
227
+ class_names = pn.widgets.TextInput(
228
+ name="Comma separated class names",
229
+ placeholder="Enter possible class names, e.g. cat, dog",
230
+ value="cat, dog, parrot",
231
+ )
232
 
233
+ input_widgets = pn.Column(
234
+ "##### 😊 Click randomize or paste a URL to start classifying!",
235
+ pn.Row(image_url, randomize_url),
236
+ class_names,
237
+ )
238
 
239
+ # add interactivity
240
+ interactive_result = pn.panel(
241
+ pn.bind(process_inputs, image_url=image_url, class_names=class_names),
242
+ height=600,
243
+ )
244
 
245
+ # add footer
246
+ footer_row = pn.Row(pn.Spacer(), align="center")
247
+ for icon, url in ICON_URLS.items():
248
+ href_button = pn.widgets.Button(icon=icon, width=35, height=35)
249
+ href_button.js_on_click(code=f"window.open('{url}')")
250
+ footer_row.append(href_button)
251
+ footer_row.append(pn.Spacer())
252
+
253
+ # create dashboard
254
+ main = pn.WidgetBox(
255
+ input_widgets,
256
+ interactive_result,
257
+ footer_row,
258
+ )
259
 
260
+ title = "Panel Demo - Image Classification"
261
+ pn.template.BootstrapTemplate(
262
+ title=title,
263
+ main=main,
264
+ main_max_width="min(50%, 698px)",
265
+ header_background="#F08080",
266
+ ).servable(title=title)