andreped's picture
Ignore some files
f1f2177
raw
history blame
1.98 kB
import re
from functools import lru_cache

import gradio as gr
import requests
from torch import no_grad, topk
from torch.nn.functional import softmax
from transformers import ViTImageProcessor, ViTForImageClassification
def load_label_data():
    """Download the class-index-to-label file and return the label strings.

    The file is fetched over HTTP and scanned line by line. For every line
    that contains a quoted string, the text of the first quoted string is
    taken as a label; lines without one are skipped.

    Returns:
        list[str]: the labels in the order they appear in the file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeout.
    """
    file_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"

    # A timeout keeps start-up from hanging forever on a stalled connection,
    # and raise_for_status() stops an HTTP error page from being parsed as if
    # it were label data.
    response = requests.get(file_url, timeout=30)
    response.raise_for_status()

    # The backreference (\1) forces the closing quote to be the same character
    # as the opening one, so a double-quoted label that contains an apostrophe
    # (e.g. "jack-o'-lantern") is captured in full instead of being cut off at
    # the apostrophe.
    pattern = re.compile(r"""(["'])(.*?)\1""")

    labels = []
    for line in response.text.splitlines():
        match = pattern.search(line)
        if match is not None:
            labels.append(match.group(2))
    return labels
_VIT_CHECKPOINT = 'google/vit-base-patch16-224'


@lru_cache(maxsize=1)
def _load_vit():
    """Instantiate the ViT image processor and classifier once and reuse them."""
    processor = ViTImageProcessor.from_pretrained(_VIT_CHECKPOINT)
    model = ViTForImageClassification.from_pretrained(_VIT_CHECKPOINT)
    # Inference only: make sure training-time layers such as dropout are off.
    model.eval()
    return processor, model


def run_model(image, nb_classes):
    """Classify ``image`` and return the ``nb_classes`` most probable classes.

    Args:
        image: the image to classify, in any form the ViT image processor
            accepts.
        nb_classes: how many top-scoring classes to return.

    Returns:
        The result of ``torch.topk`` over the softmax of the model logits:
        a ``(values, indices)`` pair holding the probabilities and the class
        indices of the best ``nb_classes`` classes.
    """
    # The processor and model are cached so the checkpoint is not reloaded on
    # every single prediction.
    processor, model = _load_vit()
    inputs = processor(images=image, return_tensors="pt")

    # No gradients are needed for prediction; skipping autograd bookkeeping
    # saves memory and time.
    with no_grad():
        logits = model(**inputs).logits

    # Softmax over dim 1 -- assumes logits are (batch, num_classes); TODO confirm.
    probabilities = softmax(logits, dim=1)
    return topk(probabilities, k=nb_classes)
def classify_image(image, labels, nb_classes):
    """Map the top ``nb_classes`` predictions for ``image`` to their scores.

    Args:
        image: the image handed on to ``run_model``.
        labels: sequence of label strings, indexed by predicted class index.
        nb_classes: number of top predictions to include.

    Returns:
        dict: label string -> score as a Python float, in ranking order.
    """
    scores, class_ids = run_model(image, nb_classes=nb_classes)

    # Only the first row of the result is used (a single image is classified).
    ranked = {}
    for score, class_id in zip(scores[0], class_ids[0]):
        ranked[labels[class_id]] = float(score)
    return ranked
def main():
    """Build the Gradio image-classification demo and start serving it."""
    nb_classes = 10
    labels = load_label_data()

    sample_urls = (
        'https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/cat.jpg',
        'https://github.com/andreped/INF1600-ai-workshop/releases/download/Examples/dog.jpeg',
    )

    # define UI
    demo = gr.Interface(
        # NOTE(review): the lambda's parameter name is kept as-is; Gradio may
        # derive the input's display label from it -- confirm before renaming.
        fn=lambda x: classify_image(x, labels, nb_classes),
        inputs=gr.Image(height=512),
        outputs=gr.Label(num_top_classes=nb_classes),
        title='Vision Transformer Image Classifier',
        # One single-element row per example image.
        examples=[[url] for url in sample_urls],
    )

    # Setting share=True instead would serve the website for others to access.
    demo.launch(debug=True, share=False, height=600, width=1200)
# Start the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()