mnh committed
Commit 08a87c6 · 1 Parent(s): e226f1a

Add application file

Files changed (2)
  1. app.py +31 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,31 @@
+ import gradio as gr
+ import torch
+ import clip
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load("ViT-B/32", device=device)
+
+
+ def predict(image):
+     labels = "Japanese, Chinese, Roman, Greek, Etruscan, Scandinavian, Celtic, Medieval, Victorian, Neoclassic, Romanticism, Art Nouveau, Art deco, Cyberpunk"
+     labels = [label.strip() for label in labels.split(",")]
+
+     image = preprocess(image).unsqueeze(0).to(device)
+     text = clip.tokenize([f"a character of origin {c}" for c in labels]).to(device)
+
+     with torch.inference_mode():
+         logits_per_image, logits_per_text = model(image, text)
+         probs = logits_per_image.softmax(dim=-1).cpu().numpy()
+
+     return {k: float(v) for k, v in zip(labels, probs[0])}
+
+ # probs = predict(Image.open("../CLIP/CLIP.png"))
+ # print(probs)
+
+
+ gr.Interface(fn=predict,
+              inputs=[
+                  gr.inputs.Image(label="Image to classify", type="pil")],
+              theme="grass",
+              outputs="label",
+              description="Zero-shot image classification.").launch()
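For a quick local check of predict outside the Gradio UI, a minimal sketch (assuming app.py's trailing .launch() call is temporarily commented out so the import does not start the server; "example.jpg" is a placeholder path, not part of the commit):

    from PIL import Image
    from app import predict  # assumes the .launch() line in app.py is commented out

    # Rank the candidate style labels for one image (placeholder path).
    probs = predict(Image.open("example.jpg"))
    # Print the three highest-probability (label, probability) pairs.
    print(sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:3])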
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ git+https://github.com/openai/CLIP
+ torch
+ Jinja2