furusu commited on
Commit
90fe5b0
·
verified ·
1 Parent(s): 0d874c0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import timm
3
+ import numpy as np
4
+ import faiss
5
+ import pandas as pd
6
+
7
+
8
+ TITLE = "wd-eva02-large-tagger-v3-vector"
9
+ DESCRIPTION = """
10
+ [model](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)
11
+ """
12
+
13
+ model = timm.create_model(f"hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=True)
14
+ head = model.head.weight.data.cpu().numpy()
15
+ del model
16
+ df = pd.read_csv(f"https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3/resolve/main/selected_tags.csv")
17
+ id2label = df["name"].to_dict()
18
+ label2id = {v:k for k,v in id2label.items()}
19
+
20
+ faiss.normalize_L2(head)
21
+ index = faiss.IndexFlatIP(head.shape[1])
22
+ index.add(head)
23
+
24
+ def predict(target_tag):
25
+ target_id = label2id[target_tag]
26
+ query = head[target_id:target_id+1]
27
+ k = 50
28
+ target_id = label2id[target_tag]
29
+ distances, indices = index.search(query, k)
30
+ return {id2label[indice]:distance for indice, distance in zip(indices[0], distances[0])}
31
+
32
+ demo = gr.Interface(
33
+ fn=predict,
34
+ inputs=[
35
+ gr.Dropdown(list(label2id.keys()), label="Target tag", value="otoko_no_ko"),
36
+ ],
37
+ outputs=gr.Label(num_top_classes=50),
38
+ title=TITLE,
39
+ description=DESCRIPTION
40
+ )
41
+
42
+ demo.launch()