romadanskiy committed on
Commit
ee2df8c
·
verified ·
1 Parent(s): 69e1902

Create app.py with Zero-Shot Image Classification

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import CLIPModel, CLIPProcessor
2
+ from PIL import Image
3
+ import time
4
+ import gradio as gr
5
+
6
+
7
# Hugging Face Hub identifiers of the two CLIP checkpoints this app offers.
openai_model_name = "openai/clip-vit-large-patch14"
patrickjohncyh_model_name = "patrickjohncyh/fashion-clip"

# Each checkpoint is loaded eagerly at module import, model first and then
# its matching processor.
openai_model = CLIPModel.from_pretrained(openai_model_name)
openai_processor = CLIPProcessor.from_pretrained(openai_model_name)

patrickjohncyh_model = CLIPModel.from_pretrained(patrickjohncyh_model_name)
patrickjohncyh_processor = CLIPProcessor.from_pretrained(patrickjohncyh_model_name)

# Lookup table: checkpoint name -> (model, processor) pair.
model_map = {}
model_map[openai_model_name] = (openai_model, openai_processor)
model_map[patrickjohncyh_model_name] = (patrickjohncyh_model, patrickjohncyh_processor)
19
+
20
+
21
def gradio_process(model_name, image, text):
    """Score an image against comma-separated text labels with a CLIP model.

    Args:
        model_name: Key into the module-level ``model_map`` selecting the
            ``(model, processor)`` pair to run.
        image: Image to classify; forwarded to the processor's ``images``
            argument.
        text: Comma-separated candidate labels, e.g. ``"cat, dog,bird"``.
            Whitespace around each label is ignored and empty entries are
            dropped.

    Returns:
        A two-element list ``[result, time_spent]``. ``result`` holds one
        ``"<label> - <probability>"`` line per label (four decimal places);
        ``time_spent`` is the elapsed time in seconds for preprocessing plus
        the forward pass.

    Raises:
        gr.Error: If ``model_name`` is not in ``model_map``, ``image`` is
            ``None``, or ``text`` contains no usable label.
    """
    # Function-scope import: torch is not among the file's top-level imports.
    import torch

    # Validate up front so the user sees a readable message in the UI
    # instead of a KeyError / processor failure.
    if model_name not in model_map:
        raise gr.Error("Please select a model.")
    if image is None:
        raise gr.Error("Please upload an image.")

    # Split on bare commas and trim, so "a,b" and "a, b" behave the same.
    labels = [part.strip() for part in (text or "").split(",")]
    labels = [label for label in labels if label]
    if not labels:
        raise gr.Error("Please enter at least one label.")
    print(labels)

    model, processor = model_map[model_name]

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start = time.perf_counter()
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax across the label axis; [0] selects the single input image.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    time_spent = time.perf_counter() - start

    result = "\n".join(
        f"{label} - {prob:.4f}" for label, prob in zip(labels, probs.tolist())
    )

    return [result, time_spent]
38
+
39
+
40
# Zero-shot classification tab: inputs and the trigger button, two output
# boxes, and the click handler that ties them to gradio_process.
# NOTE(review): the nesting below (button inside the input column, click
# wiring at Blocks level) is reconstructed; confirm against the intended layout.
with gr.Blocks() as zero_shot_image_classification_tab:
    gr.Markdown("# Zero-Shot Image Classification")

    with gr.Row():
        # Inputs and the button that starts classification.
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            input_text = gr.Textbox(label="Labels (comma separated)")
            model_selector = gr.Dropdown(
                choices=[openai_model_name, patrickjohncyh_model_name],
                label="Select Model",
            )
            process_btn = gr.Button("Classificate")

        # Outputs: elapsed time first, then the per-label scores.
        with gr.Column():
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_text = gr.Textbox(label="Classification")

    # Input order matches gradio_process(model_name, image, text); output
    # order matches its [result, time_spent] return value.
    process_btn.click(
        fn=gradio_process,
        inputs=[model_selector, input_image, input_text],
        outputs=[output_text, elapsed_result],
    )
69
+
70
+
71
# Top-level application: a Blocks container holding a tabbed interface with
# the zero-shot classification tab as its only entry.
with gr.Blocks() as app:
    gr.TabbedInterface(
        interface_list=[zero_shot_image_classification_tab],
        tab_names=["Zero-Shot Classification"],
    )


# Start the Gradio server when this module is executed.
app.launch()