Spaces:
Runtime error
Runtime error
# # load up the libraries | |
# import panel as pn | |
# import pandas as pd | |
# import altair as alt | |
# from vega_datasets import data | |
# # we want to use bootstrap/template, tell Panel to load up what we need | |
# pn.extension(design='bootstrap') | |
# # we want to use vega, tell Panel to load up what we need | |
# pn.extension('vega') | |
# # create a basic template using bootstrap | |
# template = pn.template.BootstrapTemplate( | |
# title='SI649 Walkthrough', | |
# ) | |
# # the main column will hold our key content | |
# maincol = pn.Column() | |
# # add some markdown to the main column | |
# maincol.append("# Markdown Title") | |
# maincol.append("I can format in cool ways. Like **bold** or *italics* or ***both*** or ~~strikethrough~~ or `code` or [links](https://panel.holoviz.org)") | |
# maincol.append("I am writing a link [to the streamlit documentation page](https://docs.streamlit.io/en/stable/api.html)") | |
# maincol.append('') | |
# # load up a dataframe and show it in the main column | |
# cars_url = "https://raw.githubusercontent.com/altair-viz/vega_datasets/master/vega_datasets/_data/cars.json" | |
# cars = pd.read_json(cars_url) | |
# temps = data.seattle_weather() | |
# maincol.append(temps.head(10)) | |
# # create a basic chart | |
# hp_mpg = alt.Chart(cars).mark_circle(size=80).encode( | |
# x='Horsepower:Q', | |
# y='Miles_per_Gallon:Q', | |
# color='Origin:N' | |
# ) | |
# # display it in the main column
# # maincol.append(hp_mpg) | |
# # create a basic slider | |
# simpleslider = pn.widgets.IntSlider(name='Simple Slider', start=0, end=100, value=0) | |
# # generate text based on slider value | |
# def square(x): | |
# return f'{x} squared is {x**2}' | |
# # bind the slider to the function and hold the output in a column
# row = pn.Column(pn.bind(square,simpleslider)) | |
# # add both slider and row | |
# maincol.append(simpleslider) | |
# maincol.append(row) | |
# # variable to track state of visualization | |
# flip = False | |
# # function to either return the vis or a message | |
# def makeChartVisible(val): | |
# global flip # grab the variable outside the function | |
# if (flip == True): | |
# flip = not flip # flip to False | |
# return pn.pane.Vega(hp_mpg) # return the vis | |
# else: | |
# flip = not flip # flip to true and return text | |
# return pn.panel("Click the button to see the chart") | |
# # add a button and then create the binding | |
# btn = pn.widgets.Button(name='Click me') | |
# row = pn.Row(pn.bind(makeChartVisible, btn)) | |
# # add button and new row to main column | |
# maincol.append(btn) | |
# maincol.append(row) | |
# # create a base chart | |
# basechart = alt.Chart(cars).mark_circle(size=80,opacity=0.5).encode( | |
# x='Horsepower:Q', | |
# y='Acceleration:Q', | |
# color="Origin:N" | |
# ) | |
# # create something to hold the base chart | |
# currentoption = pn.panel(basechart) | |
# # create a selection widget | |
# select = pn.widgets.Select(name='Select', options=['Horsepower','Acceleration','Miles_per_Gallon']) | |
# # create a function to modify the basechart that is being | |
# # held in currentoption | |
# def changeOption(val): | |
# # grab what's there now | |
# chrt = currentoption.object | |
# # change the encoding based on val | |
# chrt = chrt.encode( | |
# y=val+":Q" | |
# ) | |
# # replace old chart in currentoption with new one | |
# currentoption.object = chrt | |
# # append the selection | |
# maincol.append(select) | |
# # append the binding (in this case nothing is being returned by changeOption, so...)
# chartchange = pn.Row(pn.bind(changeOption, select)) | |
# # ... we need to also add the chart | |
# maincol.append(chartchange) | |
# maincol.append(currentoption) | |
# # add the main column to the template | |
# template.main.append(maincol) | |
# # Indicate that the template object is the "application" and serve it | |
# template.servable(title="SI649 Walkthrough") | |
import io | |
import random | |
from typing import List, Tuple | |
import aiohttp | |
import panel as pn | |
from PIL import Image | |
from transformers import CLIPModel, CLIPProcessor | |
# Configure Panel for this app: use the Bootstrap design and default every
# component to sizing_mode="stretch_width".
pn.extension(design="bootstrap", sizing_mode="stretch_width")

# Footer links: icon name (passed to pn.widgets.Button(icon=...)) -> URL that
# the corresponding footer button opens.
ICON_URLS = dict(
    [
        ("brand-github", "https://github.com/holoviz/panel"),
        ("brand-twitter", "https://twitter.com/Panel_Org"),
        ("brand-linkedin", "https://www.linkedin.com/company/panel-org"),
        ("message-circle", "https://discourse.holoviz.org/"),
        ("brand-discord", "https://discord.gg/AXRHnJU6sP"),
    ]
)
async def random_url(_):
    """Return the URL of a random cat or dog picture.

    The positional argument is ignored. In this app the coroutine is bound to
    the "Randomize URL" button, so it only serves as the dependency that makes
    the coroutine re-run.
    """
    animal = random.choice(["cat", "dog"])
    endpoint = f"https://api.the{animal}api.com/v1/images/search"
    async with aiohttp.ClientSession() as session, session.get(endpoint) as response:
        payload = await response.json()
    # The search endpoint returns a list of results; take the first one's URL.
    return payload[0]["url"]
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load a CLIP processor and model from pretrained checkpoints.

    Decorated with ``pn.cache`` so the expensive download/initialisation runs
    once per ``(processor_name, model_name)`` pair. The result is reused by
    every classification request instead of being reloaded on each call.

    Args:
        processor_name: Hub id or local path passed to
            ``CLIPProcessor.from_pretrained``.
        model_name: Hub id or local path passed to
            ``CLIPModel.from_pretrained``.

    Returns:
        ``(processor, model)``.
    """
    processor = CLIPProcessor.from_pretrained(processor_name)
    model = CLIPModel.from_pretrained(model_name)
    return processor, model
async def open_image_url(image_url: str) -> Image.Image:
    """Download ``image_url`` and decode the response body as a PIL image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status, instead of trying to decode the error page as an image.
        PIL.UnidentifiedImageError: if the body is not a recognisable image.
        (Network errors from aiohttp propagate unchanged.)
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            # Fail fast on 4xx/5xx so callers see a fetch error rather than an
            # opaque image-decoding error.
            resp.raise_for_status()
            payload = await resp.read()
    # The body is fully buffered in memory, so decoding after the session has
    # closed is safe; BytesIO keeps the bytes alive for PIL's lazy loading.
    return Image.open(io.BytesIO(payload))
def get_similarity_scores(class_items: List[str], image: Image.Image) -> List[float]:
    """Score how well each class name describes ``image`` using CLIP.

    Args:
        class_items: Candidate class names (free text); one score per item.
        image: The PIL image to classify.

    Returns:
        One probability per entry of ``class_items``, in the same order. The
        values are a softmax over the image's logits, so they sum to 1.
    """
    import torch  # local import: only needed to disable autograd below

    processor, model = load_processor_model(
        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
    )
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
        # Class names tokenize to different lengths ("cat" vs "golden
        # retriever"); without padding they cannot be stacked into one tensor.
        padding=True,
    )
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has one row per input image; exactly one is passed in.
    class_likelihoods = outputs.logits_per_image.softmax(dim=1)
    return class_likelihoods[0].tolist()
async def process_inputs(class_names: str, image_url: str):
    """Classify the image at ``image_url`` against comma-separated class names.

    This is an async generator bound to the two text inputs. Each ``yield``
    is either a Markdown status message or, at the end, a ``pn.Column``
    holding the image and one labelled progress bar per class. The dashboard
    (``main``) is disabled while this runs and re-enabled in ``finally``.

    Args:
        class_names: Comma-separated candidate class names, e.g. "cat, dog".
        image_url: URL of the image to classify.
    """
    try:
        main.disabled = True
        if not image_url:
            yield "##### β οΈ Provide an image URL"
            return

        # Trim whitespace and drop empty entries so inputs such as "" or
        # "cat, , dog," don't send blank prompts to the model.
        class_items = [name.strip() for name in class_names.split(",") if name.strip()]
        if not class_items:
            yield "##### Provide at least one class name"
            return

        yield "##### β Fetching image and running model..."
        try:
            pil_img = await open_image_url(image_url)
            img = pn.pane.Image(pil_img, height=400, align="center")
        except Exception:
            # Any fetch/decode failure is reported to the user, not raised.
            yield "##### π Something went wrong, please try a different URL!"
            return

        class_likelihoods = get_similarity_scores(class_items, pil_img)

        # Build the results column: the image, then one label + bar per class.
        results = pn.Column("##### π Here are the results!", img)
        for class_item, class_likelihood in zip(class_items, class_likelihoods):
            row_label = pn.widgets.StaticText(
                name=class_item, value=f"{class_likelihood:.2%}", align="center"
            )
            row_bar = pn.indicators.Progress(
                value=int(class_likelihood * 100),
                sizing_mode="stretch_width",
                bar_color="secondary",
                margin=(0, 10),
                design=pn.theme.Material,
            )
            results.append(pn.Column(row_label, row_bar))
        yield results
    finally:
        main.disabled = False
# --- Input widgets ------------------------------------------------------------
# Clicking this button re-runs random_url (bound to it just below), which
# supplies a new value for the URL text box.
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")

_random_url_ref = pn.bind(random_url, randomize_url)
image_url = pn.widgets.TextInput(name="Image URL to classify", value=_random_url_ref)

class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)

# URL box and its randomize button side by side, class names underneath.
_url_row = pn.Row(image_url, randomize_url)
input_widgets = pn.Column(
    "##### π Click randomize or paste a URL to start classifying!",
    _url_row,
    class_names,
)
# --- Reactive output ----------------------------------------------------------
# Bind the classifier to both text inputs so it re-runs when either changes;
# the panel displays what process_inputs yields (status text, then results).
_classify = pn.bind(process_inputs, image_url=image_url, class_names=class_names)
interactive_result = pn.panel(_classify, height=600)
# --- Footer -------------------------------------------------------------------
def _make_link_button(icon_name, target):
    """Return a small icon button that runs window.open on ``target`` when clicked."""
    button = pn.widgets.Button(icon=icon_name, width=35, height=35)
    button.js_on_click(code=f"window.open('{target}')")
    return button


# The buttons are centered between two spacers, one button per ICON_URLS entry.
footer_row = pn.Row(
    pn.Spacer(),
    *[_make_link_button(icon_name, target) for icon_name, target in ICON_URLS.items()],
    pn.Spacer(),
    align="center",
)
# --- Page assembly ------------------------------------------------------------
# NOTE: process_inputs toggles `main.disabled`, so this name must stay `main`.
main = pn.WidgetBox(input_widgets, interactive_result, footer_row)

title = "Panel Demo - Image Classification"
_page = pn.template.BootstrapTemplate(
    title=title,
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
)
# Mark the template as the served application (e.g. for `panel serve`).
_page.servable(title=title)