Spaces:
Runtime error
Runtime error
File size: 3,645 Bytes
307197c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import aiohttp
import io
import random
import panel as pn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from typing import List, Tuple
# App-wide Panel options: Bootstrap design, and components stretch to the
# available width unless they override sizing_mode themselves.
pn.extension(
    design='bootstrap',
    sizing_mode="stretch_width",
)
async def random_url(_):
    """Return the URL of a random cat or dog picture.

    The single positional argument is the value of the widget this function is
    bound to via ``pn.bind``; it only acts as a trigger and is ignored.

    Raises:
        aiohttp.ClientResponseError: if the API answers with an HTTP error
            status (e.g. rate limiting).
    """
    api_url = random.choice([
        "https://api.thecatapi.com/v1/images/search",
        "https://api.thedogapi.com/v1/images/search",
    ])
    async with aiohttp.ClientSession() as session:
        async with session.get(api_url) as resp:
            # Fail with a clear HTTP error rather than trying to index into an
            # error payload below.
            resp.raise_for_status()
            payload = await resp.json()
    # Presumably a list of image records; the first record's "url" is used.
    return payload[0]["url"]
@pn.cache
def load_processor_model(
    processor_name: str, model_name: str
) -> Tuple[CLIPProcessor, CLIPModel]:
    """Load a pretrained CLIP processor and model by checkpoint name.

    Wrapped in ``pn.cache`` so that the expensive ``from_pretrained`` loads are
    not repeated for every call with the same names.

    Returns:
        A ``(processor, model)`` pair.
    """
    loaded = (
        CLIPProcessor.from_pretrained(processor_name),
        CLIPModel.from_pretrained(model_name),
    )
    return loaded
async def open_image_url(image_url: str) -> Image.Image:
    """Download ``image_url`` and open the response body as a PIL image.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP error
            status.
        PIL.UnidentifiedImageError: if the downloaded bytes are not an image
            format PIL recognises.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as resp:
            # Without this, e.g. a 404 HTML page would surface only as a
            # confusing "cannot identify image file" error from PIL.
            resp.raise_for_status()
            data = await resp.read()
    # The bytes are fully in memory, so the image does not depend on the
    # (now closed) HTTP session.
    return Image.open(io.BytesIO(data))
def get_similarity_scores(class_items: List[str], image: Image.Image) -> List[float]:
    """Score how well each candidate label describes ``image`` using CLIP.

    Args:
        class_items: Candidate label strings, one per class.
        image: The PIL image to classify.

    Returns:
        One probability per entry of ``class_items`` (same order), from a
        softmax over CLIP's image-to-text logits, so the values sum to 1.
    """
    # Local import: return_tensors="pt" below already requires torch.
    import torch

    processor, model = load_processor_model(
        "openai/clip-vit-base-patch32", "openai/clip-vit-base-patch32"
    )
    inputs = processor(
        text=class_items,
        images=[image],
        return_tensors="pt",  # pytorch tensors
        # Labels tokenize to different lengths ("cat" vs "golden retriever");
        # without padding they cannot be stacked into a single tensor.
        padding=True,
    )
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    # One row per image, one column per label; exactly one image was passed.
    probabilities = outputs.logits_per_image.softmax(dim=1)
    return probabilities[0].tolist()
async def process_inputs(class_names: str, image_url: str):
    """
    High level function that takes in the user inputs and returns the
    classification results as panel objects.

    This is an async generator. It first yields a status (or validation/error)
    message. On success it finally yields a ``pn.Column`` holding the image and
    one labelled progress bar per class.

    Args:
        class_names: Comma-separated candidate labels, e.g. "cat, dog".
        image_url: URL of the image to classify.
    """
    image_url = (image_url or "").strip()
    if not image_url:
        yield '## Provide an image URL'
        return

    # Strip labels and drop blank entries so "cat, , dog," never sends empty
    # strings to the model.
    class_items = [item.strip() for item in (class_names or "").split(",")]
    class_items = [item for item in class_items if item]
    if not class_items:
        yield '## Provide at least one class name'
        return

    yield '## Fetching image and running model β'
    try:
        pil_img = await open_image_url(image_url)
    except (aiohttp.ClientError, ValueError, OSError) as err:
        # Malformed URL, network/HTTP failure, or bytes PIL cannot decode
        # (UnidentifiedImageError is an OSError): show it in the UI instead of
        # raising out of the bound function.
        yield f'## Could not load image: {err}'
        return
    img = pn.pane.Image(pil_img, height=400, align='center')

    class_likelihoods = get_similarity_scores(class_items, pil_img)

    # build the results column
    results = pn.Column("## π Here are the results!", img)
    for class_item, class_likelihood in zip(class_items, class_likelihoods):
        row_label = pn.widgets.StaticText(
            name=class_item, value=f"{class_likelihood:.2%}", align='center'
        )
        row_bar = pn.indicators.Progress(
            value=int(class_likelihood * 100),
            sizing_mode="stretch_width",
            bar_color="secondary",
            margin=(0, 10),
            design=pn.theme.Material
        )
        results.append(pn.Column(row_label, row_bar))
    yield results
# create widgets
randomize_url = pn.widgets.Button(name="Randomize URL", align="end")
# The text input's value is bound to random_url(randomize_url), so pressing the
# button re-runs random_url and fills in a fresh cat/dog image URL.
image_url = pn.widgets.TextInput(
    name="Image URL to classify",
    value=pn.bind(random_url, randomize_url),
)
class_names = pn.widgets.TextInput(
    name="Comma separated class names",
    placeholder="Enter possible class names, e.g. cat, dog",
    value="cat, dog, parrot",
)
input_widgets = pn.Column(
    "## π Click randomize or paste a URL to start classifying!",
    pn.Row(image_url, randomize_url),
    class_names,
)

# add interactivity
# process_inputs is re-invoked with the current widget values whenever the URL
# or the class names change.
interactive_result = pn.bind(
    process_inputs, image_url=image_url, class_names=class_names
)

# create dashboard
main = pn.WidgetBox(
    input_widgets,
    interactive_result,
)

pn.template.BootstrapTemplate(
    title="Panel Image Classification Demo",
    main=main,
    main_max_width="min(50%, 698px)",
    header_background="#F08080",
).servable(title="Panel Image Classification Demo")