Spaces:
Runtime error
Runtime error
File size: 8,469 Bytes
d264abb 70cf329 d264abb 70cf329 d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb b2fb08d d264abb 70cf329 d264abb 70cf329 d264abb 70cf329 d264abb b2fb08d d264abb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 |
# # load up the libraries
# import panel as pn
# import pandas as pd
# import altair as alt
# from vega_datasets import data
# # we want to use bootstrap/template, tell Panel to load up what we need
# pn.extension(design='bootstrap')
# # we want to use vega, tell Panel to load up what we need
# pn.extension('vega')
# # create a basic template using bootstrap
# template = pn.template.BootstrapTemplate(
# title='SI649 Walkthrough',
# )
# # the main column will hold our key content
# maincol = pn.Column()
# # add some markdown to the main column
# maincol.append("# Markdown Title")
# maincol.append("I can format in cool ways. Like **bold** or *italics* or ***both*** or ~~strikethrough~~ or `code` or [links](https://panel.holoviz.org)")
# maincol.append("I am writing a link [to the streamlit documentation page](https://docs.streamlit.io/en/stable/api.html)")
# maincol.append('')
# # load up a dataframe and show it in the main column
# cars_url = "https://raw.githubusercontent.com/altair-viz/vega_datasets/master/vega_datasets/_data/cars.json"
# cars = pd.read_json(cars_url)
# temps = data.seattle_weather()
# maincol.append(temps.head(10))
# # create a basic chart
# hp_mpg = alt.Chart(cars).mark_circle(size=80).encode(
# x='Horsepower:Q',
# y='Miles_per_Gallon:Q',
# color='Origin:N'
# )
# # display it in the main column
# # maincol.append(hp_mpg)
# # create a basic slider
# simpleslider = pn.widgets.IntSlider(name='Simple Slider', start=0, end=100, value=0)
# # generate text based on slider value
# def square(x):
# return f'{x} squared is {x**2}'
# # bind the slider to the function and hold the output in a row
# row = pn.Column(pn.bind(square,simpleslider))
# # add both slider and row
# maincol.append(simpleslider)
# maincol.append(row)
# # variable to track state of visualization
# flip = False
# # function to either return the vis or a message
# def makeChartVisible(val):
# global flip # grab the variable outside the function
# if (flip == True):
# flip = not flip # flip to False
# return pn.pane.Vega(hp_mpg) # return the vis
# else:
# flip = not flip # flip to true and return text
# return pn.panel("Click the button to see the chart")
# # add a button and then create the binding
# btn = pn.widgets.Button(name='Click me')
# row = pn.Row(pn.bind(makeChartVisible, btn))
# # add button and new row to main column
# maincol.append(btn)
# maincol.append(row)
# # create a base chart
# basechart = alt.Chart(cars).mark_circle(size=80,opacity=0.5).encode(
# x='Horsepower:Q',
# y='Acceleration:Q',
# color="Origin:N"
# )
# # create something to hold the base chart
# currentoption = pn.panel(basechart)
# # create a selection widget
# select = pn.widgets.Select(name='Select', options=['Horsepower','Acceleration','Miles_per_Gallon'])
# # create a function to modify the basechart that is being
# # held in currentoption
# def changeOption(val):
# # grab what's there now
# chrt = currentoption.object
# # change the encoding based on val
# chrt = chrt.encode(
# y=val+":Q"
# )
# # replace old chart in currentoption with new one
# currentoption.object = chrt
# # append the selection
# maincol.append(select)
# # append the binding (in this case nothing is being returned by changeOption, so...)
# chartchange = pn.Row(pn.bind(changeOption, select))
# # ... we need to also add the chart
# maincol.append(chartchange)
# maincol.append(currentoption)
# # add the main column to the template
# template.main.append(maincol)
# # Indicate that the template object is the "application" and serve it
# template.servable(title="SI649 Walkthrough")
import io
import random
from typing import List, Tuple

import aiohttp
import panel as pn
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Configure Panel once for the whole app: use the Bootstrap design and make
# "stretch_width" the default sizing_mode for components.
pn.extension(design="bootstrap", sizing_mode="stretch_width")
# Footer links, in display order: the key is the icon name given to
# pn.widgets.Button(icon=...), the value is the page opened when it is clicked.
ICON_URLS = dict(
    [
        ("brand-github", "https://github.com/holoviz/panel"),
        ("brand-twitter", "https://twitter.com/Panel_Org"),
        ("brand-linkedin", "https://www.linkedin.com/company/panel-org"),
        ("message-circle", "https://discourse.holoviz.org/"),
        ("brand-discord", "https://discord.gg/AXRHnJU6sP"),
    ]
)
async def random_url(_):
    """Return the URL of a random cat or dog picture.

    The single argument is ignored. It exists so the function can be bound to
    the "Randomize URL" button, whose clicks re-run it.

    Raises:
        aiohttp.ClientResponseError: if the image API answers with an HTTP
            error status.
    """
    pet = random.choice(["cat", "dog"])
    api_url = f"https://api.the{pet}api.com/v1/images/search"
    async with aiohttp.ClientSession() as session:
        async with session.get(api_url) as resp:
            # Surface HTTP failures as a clear status error instead of a
            # confusing JSON-decode / KeyError / IndexError further down.
            resp.raise_for_status()
            payload = await resp.json()
    # The response is expected to be a JSON list whose first record carries
    # a "url" key.
    return payload[0]["url"]
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load a CLIP processor and model, memoized by ``pn.cache``.

    Args:
        processor_name: Name passed to ``CLIPProcessor.from_pretrained``.
        model_name: Name passed to ``CLIPModel.from_pretrained``.

    Returns:
        The ``(processor, model)`` pair, processor first.
    """
    # Tuple elements are evaluated left to right: processor, then model.
    return (
        CLIPProcessor.from_pretrained(processor_name),
        CLIPModel.from_pretrained(model_name),
    )
async def open_image_url(image_url: str) -> Image.Image:
    """Download ``image_url`` and open the response body as a PIL image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status.
        PIL.UnidentifiedImageError: if the body is not a recognizable image.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            # Fail fast on 4xx/5xx rather than handing an HTML error page to Pillow.
            resp.raise_for_status()
            payload = await resp.read()
    # Image.open decodes lazily; the BytesIO it holds keeps the bytes available.
    return Image.open(io.BytesIO(payload))
def get_similarity_scores(class_items: List[str], image: Image.Image) -> List[float]:
    """Score how well each label in ``class_items`` matches ``image`` with CLIP.

    Args:
        class_items: Candidate text labels; one score is returned per label.
        image: The PIL image to classify.

    Returns:
        Softmax probabilities over ``class_items``, in the same order. An
        empty list is returned for an empty ``class_items``.
    """
    if not class_items:
        return []
    processor, model = load_processor_model(
        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
    )
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
        # Labels tokenize to different lengths; without padding they cannot be
        # stacked into one tensor and the tokenizer raises a ValueError.
        padding=True,
    )
    # Inference only: skip autograd bookkeeping (no .detach() needed afterwards).
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has one row per input image; exactly one was passed.
    class_likelihoods = outputs.logits_per_image.softmax(dim=1)
    return class_likelihoods[0].tolist()
def _results_column(img, class_items: List[str], class_likelihoods) -> pn.Column:
    """Lay out the image followed by one label + progress bar per class."""
    results = pn.Column("##### π Here are the results!", img)
    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        row_label = pn.widgets.StaticText(
            name=class_item.strip(), value=f"{class_likelihood:.2%}", align="center"
        )
        row_bar = pn.indicators.Progress(
            # Round (not truncate) so the bar agrees with the percentage label.
            value=int(round(class_likelihood * 100)),
            sizing_mode="stretch_width",
            bar_color="secondary",
            margin=(0, 10),
            design=pn.theme.Material,
        )
        results.append(pn.Column(row_label, row_bar))
    return results


async def process_inputs(class_names: str, image_url: str):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    This is an async generator. It yields markdown status strings while
    working, then a ``pn.Column`` with the image and one label/progress bar
    per class.

    Args:
        class_names: Comma separated labels, e.g. ``"cat, dog, parrot"``.
        image_url: URL of the image to classify.
    """
    try:
        # Disable the dashboard (the module-level ``main`` WidgetBox) while working.
        main.disabled = True
        if not image_url:
            yield "##### β οΈ Provide an image URL"
            return

        # Strip whitespace and drop empty entries so "cat,, dog " yields two
        # clean labels and the model scores exactly the labels shown.
        class_items = [
            name.strip() for name in class_names.split(",") if name.strip()
        ]
        if not class_items:
            yield "##### ⚠️ Provide at least one class name"
            return

        yield "##### β Fetching image and running model..."
        try:
            pil_img = await open_image_url(image_url)
            img = pn.pane.Image(pil_img, height=400, align="center")
        except Exception:
            # Any fetch/decode failure is reported to the user instead of raised.
            yield "##### π Something went wrong, please try a different URL!"
            return

        class_likelihoods = get_similarity_scores(class_items, pil_img)
        yield _results_column(img, class_items, class_likelihoods)
    finally:
        main.disabled = False
# --- input widgets ---------------------------------------------------------
# Each click of this button re-runs random_url through the binding below.
randomize_url = pn.widgets.Button(align="end", name="Randomize URL")
# The URL box's value is bound to random_url, so it starts with a random
# cat/dog image URL and is replaced on every button click.
image_url = pn.widgets.TextInput(
    value=pn.bind(random_url, randomize_url),
    name="Image URL to classify",
)
class_names = pn.widgets.TextInput(
    value="cat, dog, parrot",
    placeholder="Enter possible class names, e.g. cat, dog",
    name="Comma separated class names",
)
url_row = pn.Row(image_url, randomize_url)
input_widgets = pn.Column(
    "##### π Click randomize or paste a URL to start classifying!",
    url_row,
    class_names,
)
# --- interactivity -----------------------------------------------------------
# Re-run process_inputs whenever the URL or class names change; whatever it
# yields is shown in a fixed-height (600px) pane.
_classification_binding = pn.bind(
    process_inputs, class_names=class_names, image_url=image_url
)
interactive_result = pn.panel(_classification_binding, height=600)
# --- footer ------------------------------------------------------------------
# One 35x35 icon button per ICON_URLS entry, centred between two spacers.
footer_buttons = []
for icon_name, link in ICON_URLS.items():
    button = pn.widgets.Button(icon=icon_name, width=35, height=35)
    # Client-side callback: opens the link in the browser without a server round trip.
    button.js_on_click(code=f"window.open('{link}')")
    footer_buttons.append(button)
footer_row = pn.Row(pn.Spacer(), *footer_buttons, pn.Spacer(), align="center")
# create dashboard
# ``main`` is also the object process_inputs disables while it is working.
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
    footer_row,
)

title = "Panel Demo - Image Classification"
# Wrap the dashboard in a Bootstrap template and mark it as the servable app.
pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
).servable(title=title)