Spaces:
Running
Running
File size: 1,979 Bytes
254fbaf 1e99e66 f1f2177 254fbaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import re
import requests
import gradio as gr
from torch import topk
from torch.nn.functional import softmax
from transformers import ViTImageProcessor, ViTForImageClassification
def load_label_data():
file_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
response = requests.get(file_url)
labels = []
pattern = '["\'](.*?)["\']'
for line in response.text.split('\n'):
try:
tmp = re.findall(pattern, line)[0]
labels.append(tmp)
except IndexError:
pass
return labels
def run_model(image, nb_classes):
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
outputs = softmax(outputs.logits, dim=1)
outputs = topk(outputs, k=nb_classes)
return outputs
def classify_image(image, labels, nb_classes):
top10 = run_model(image, nb_classes=nb_classes)
return {labels[top10[1][0][i]]: float(top10[0][0][i]) for i in range(nb_classes)}
def main():
nb_classes = 10
labels = load_label_data()
examples=[
['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg'],
['https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg'],
]
# define UI
image = gr.Image(height=512)
label = gr.Label(num_top_classes=nb_classes)
interface = gr.Interface(
fn=lambda x: classify_image(x, labels, nb_classes), inputs=image, outputs=label, title='Vision Transformer Image Classifier', examples=examples,
)
interface.launch(debug=True, share=False, height=600, width=1200) # by setting share=True you can serve the website for others to access
if __name__ == "__main__":
main()
|