File size: 2,356 Bytes
ee2df8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import time
import gradio as gr


# Hugging Face Hub repo ids for the two selectable CLIP checkpoints.
openai_model_name = "openai/clip-vit-large-patch14"
patrickjohncyh_model_name = "patrickjohncyh/fashion-clip"


def _load_clip(repo_id):
    """Load a CLIP model and its matching processor from one checkpoint.

    The model is loaded first, then the processor, both from ``repo_id``.
    """
    clip_model = CLIPModel.from_pretrained(repo_id)
    clip_processor = CLIPProcessor.from_pretrained(repo_id)
    return clip_model, clip_processor


# Both checkpoints are loaded eagerly when the module executes.
openai_model, openai_processor = _load_clip(openai_model_name)
patrickjohncyh_model, patrickjohncyh_processor = _load_clip(patrickjohncyh_model_name)

# Lookup table used by the request handler: repo id -> (model, processor).
model_map = {
    openai_model_name: (openai_model, openai_processor),
    patrickjohncyh_model_name: (patrickjohncyh_model, patrickjohncyh_processor),
}


def gradio_process(model_name, image, text):
    """Classify ``image`` zero-shot against a user-supplied list of labels.

    Args:
        model_name: Key into ``model_map`` selecting the CLIP model and
            processor pair to use.
        image: The uploaded image (the UI component uses ``type="pil"``).
        text: Comma-separated candidate labels, e.g. ``"shirt, dress, shoes"``.

    Returns:
        A two-element list ``[result, time_spent]``. ``result`` has one
        ``"<label> - <probability>"`` line per label, where probabilities come
        from a softmax over the labels. ``time_spent`` is the elapsed seconds
        for preprocessing, the forward pass and the softmax.

    Raises:
        gr.Error: if no known model is selected, no image is given, or no
            non-empty label is entered. Gradio shows these as readable
            messages instead of a raw traceback.
    """
    # return_tensors="pt" below already requires torch to be installed.
    import torch

    # The dropdown can deliver None when nothing is selected.
    if model_name not in model_map:
        raise gr.Error("Please select a model.")
    if image is None:
        raise gr.Error("Please upload an image.")

    # Split on bare commas and strip whitespace so "a,b", "a, b" and
    # "a ,  b" all behave the same. Empty entries from stray or trailing
    # commas are dropped.
    labels = [label.strip() for label in (text or "").split(",") if label.strip()]
    if not labels:
        raise gr.Error("Please enter at least one label.")
    print(labels)

    model, processor = model_map[model_name]
    start = time.perf_counter()
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    # Inference only: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # Row 0 of logits_per_image corresponds to the single image passed in;
    # the softmax runs across the label axis.
    probs = outputs.logits_per_image.softmax(dim=1)[0].tolist()
    time_spent = time.perf_counter() - start

    result = "\n".join(f"{label} - {prob:.4f}" for label, prob in zip(labels, probs))
    return [result, time_spent]


with gr.Blocks() as zero_shot_image_classification_tab:
    gr.Markdown("# Zero-Shot Image Classification")

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            input_text = gr.Textbox(
                label="Labels (comma separated)",
                placeholder="e.g. shirt, dress, shoes",
            )
            # Preselect the first model so clicking the button without
            # touching the dropdown does not pass None to gradio_process.
            model_selector = gr.Dropdown(
                [openai_model_name, patrickjohncyh_model_name],
                value=openai_model_name,
                label="Select Model",
            )

            # Process button
            process_btn = gr.Button("Classify")

        with gr.Column():
            # Output components
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_text = gr.Textbox(label="Classification")

    # Inputs are passed positionally as (model_name, image, text). The
    # returned [result, time_spent] list fills the outputs in order.
    process_btn.click(
        fn=gradio_process,
        inputs=[
            model_selector,
            input_image,
            input_text,
        ],
        outputs=[output_text, elapsed_result],
    )


# Tabs shown in the top-level app, paired index-by-index with their titles.
_tab_blocks = [zero_shot_image_classification_tab]
_tab_titles = ["Zero-Shot Classification"]

with gr.Blocks() as app:
    gr.TabbedInterface(_tab_blocks, _tab_titles)

# Start the Gradio web server; this runs as soon as the module executes.
app.launch()