ariG23498 HF Staff committed on
Commit
a852975
·
1 Parent(s): 0553ee9

run on cpu

Browse files
Files changed (1) hide show
  1. app.py +14 -44
app.py CHANGED
@@ -1,63 +1,33 @@
1
  """This space is taken and modified from https://huggingface.co/spaces/merve/compare_clip_siglip"""
2
- import torch
3
- from transformers import (
4
- AutoModel,
5
- AutoProcessor
6
- )
7
  import gradio as gr
8
 
9
  ################################################################################
10
  # Load the models
11
  ################################################################################
12
  sg1_ckpt = "google/siglip-so400m-patch14-384"
13
- siglip1_model = AutoModel.from_pretrained(sg1_ckpt, device_map="auto").eval()
14
- siglip1_processor = AutoProcessor.from_pretrained(sg1_ckpt)
15
 
16
  sg2_ckpt = "google/siglip2-so400m-patch14-384"
17
- siglip2_model = AutoModel.from_pretrained(sg2_ckpt, device_map="auto").eval()
18
- siglip2_processor = AutoProcessor.from_pretrained(sg2_ckpt)
19
 
20
  ################################################################################
21
- # Utilities
22
  ################################################################################
23
- def postprocess(output):
24
- return {out["label"]: float(out["score"]) for out in output}
25
-
26
-
27
- def postprocess_siglip(sg1_probs, sg2_probs, labels):
28
- sg1_output = {labels[i]: float(sg1_probs[0].cpu().numpy()[i]) for i in range(len(labels))}
29
- sg2_output = {labels[i]: float(sg2_probs[0].cpu().numpy()[i]) for i in range(len(labels))}
30
- return sg1_output, sg2_output
31
-
32
- def siglip_detector(image, texts):
33
- sg1_inputs = siglip1_processor(
34
- text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
35
- ).to(siglip1_model.device)
36
-
37
- sg2_inputs = siglip2_processor(
38
- text=texts, images=image, return_tensors="pt", padding="max_length", max_length=64
39
- ).to(siglip2_model.device)
40
-
41
- with torch.no_grad():
42
- sg1_outputs = siglip1_model(**sg1_inputs)
43
- sg2_outputs = siglip2_model(**sg2_inputs)
44
-
45
- sg1_logits_per_image = sg1_outputs.logits_per_image
46
- sg2_logits_per_image = sg2_outputs.logits_per_image
47
-
48
- sg1_probs = torch.sigmoid(sg1_logits_per_image)
49
- sg2_probs = torch.sigmoid(sg2_logits_per_image)
50
- return sg1_probs, sg2_probs
51
-
52
-
53
  def infer(image, candidate_labels):
54
  candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
55
- sg1_probs, sg2_probs = siglip_detector(image, candidate_labels)
56
- return postprocess_siglip(
57
- sg1_probs, sg2_probs, labels=candidate_labels
58
- )
59
 
 
 
 
 
 
60
 
 
 
 
 
 
61
  with gr.Blocks() as demo:
62
  gr.Markdown("# Compare SigLIP 1 and SigLIP 2")
63
  gr.Markdown(
 
1
  """This space is taken and modified from https://huggingface.co/spaces/merve/compare_clip_siglip"""
2
+ from transformers import pipeline
 
 
 
 
3
  import gradio as gr
4
 
5
  ################################################################################
6
  # Load the models
7
  ################################################################################
8
  sg1_ckpt = "google/siglip-so400m-patch14-384"
9
+ sg1_pipe = pipeline(task="zero-shot-image-classification", model=sg1_ckpt, device="cpu")
 
10
 
11
  sg2_ckpt = "google/siglip2-so400m-patch14-384"
12
+ sg2_pipe = pipeline(task="zero-shot-image-classification", model=sg2_ckpt, device="cpu")
 
13
 
14
  ################################################################################
15
+ # Run inference
16
  ################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def infer(image, candidate_labels):
18
  candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
 
 
 
 
19
 
20
+ sg1_socres = sg1_pipe(image, candidate_labels=candidate_labels)
21
+ sg2_socres = sg2_pipe(image, candidate_labels=candidate_labels)
22
+
23
+ sg1_outputs = {element["label"]:element["score"] for element in sg1_socres}
24
+ sg2_outputs = {element["label"]:element["score"] for element in sg2_socres}
25
 
26
+ return sg1_outputs, sg2_outputs
27
+
28
+ ################################################################################
29
+ # Gradio App
30
+ ################################################################################
31
  with gr.Blocks() as demo:
32
  gr.Markdown("# Compare SigLIP 1 and SigLIP 2")
33
  gr.Markdown(