wodesq commited on
Commit
f39c6eb
·
1 Parent(s): 5614963
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch
3
+ from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
4
+ from transformers import CLIPProcessor, CLIPModel
5
+ import numpy as np
6
+ import time
7
+ import gradio as gr
8
+ import re
9
+
10
+ # 加载Taiyi 中文 word encoder
11
+ text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
12
+ text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese").eval()
13
+ # 加载CLIP的image encoder
14
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
15
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
+
17
+
18
+ def imgclassfiy(query_texts,img_url):
19
+ start_time = time.time()
20
+ query_texts =re.split(",|,",query_texts)
21
+ text = text_tokenizer(query_texts, return_tensors='pt', padding=True)['input_ids']
22
+ url = img_url
23
+
24
+ image = processor(images=Image.open(url), return_tensors="pt")
25
+
26
+ with torch.no_grad():
27
+ image_features = clip_model.get_image_features(**image)
28
+ text_features = text_encoder(text).logits
29
+
30
+ # 归一化
31
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
32
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
33
+
34
+ # 计算余弦相似度 logit_scale是尺度系数
35
+ logit_scale = clip_model.logit_scale.exp()
36
+ logits_per_image = logit_scale * image_features @ text_features.t()
37
+ logits_per_text = logits_per_image.t()
38
+ probs = logits_per_image.softmax(dim=-1).cpu().numpy()
39
+
40
+ #res = np.around(probs, 3)[0]
41
+ res = query_texts[np.argmax(probs)]
42
+
43
+ end_time = time.time()
44
+ print('用时:', end_time - start_time)
45
+ return res
46
+
47
+ if __name__ =="__main__":
48
+
49
+ with gr.Blocks(title="自定义类别的图像分类") as demo:
50
+ # 标题
51
+ gr.HTML('<br>')
52
+ gr.HTML(
53
+ f'<center><p style="color:#4377ec;font-size:42px;font-weight:bold;text-shadow: #FDEDB7 2px 0 0, #FDEDB7 0 2px 0, #FDEDB7 -2px 0 0, #FDEDB7 0 -2px 0;">自定义类别的图像分类</p></center>')
54
+ gr.HTML('<br>')
55
+ with gr.Row() as row:
56
+ with gr.Column():
57
+ img_input = gr.Image(type="filepath")
58
+ out_input = gr.Textbox(lable='自定义类别')
59
+ text_btn = gr.Button("提交")
60
+
61
+ with gr.Column(scale=5):
62
+ img_out = gr.Textbox(lable='输出类别')
63
+
64
+ text_btn.click(fn=imgclassfiy, inputs=[out_input,img_input], outputs=[img_out])
65
+
66
+ demo.launch(show_api=False,inbrowser=True)