iamrobotbear committed
Commit c617ba2 · 1 Parent(s): bb1fcce

Create app.py

Files changed (1):
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
import gradio as gr
import torch
from pathlib import Path
from PIL import Image
import pandas as pd
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from transformers import AutoProcessor, AutoModelForCausalLM

# Load model and preprocessors for Image-Text Matching (LAVIS)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)

# Load model and processor for Image Captioning (GIT fine-tuned on TextCaps).
# GIT checkpoints are causal-LM captioners, so they are loaded with the Auto
# classes rather than the CLIP classes.
model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)
processor_caption = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")

# List of statements for Image-Text Matching
statements = [
    # (Add actual statements here)
]

txts = [text_processors["eval"](statement) for statement in statements]

# Function to compute Image-Text Matching (ITM) scores for all statements
def compute_itm_scores(image):
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
    results = []
    for i, statement in enumerate(statements):
        txt = txts[i]
        itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
        score = itm_scores[:, 1].item()
        result_text = f'The image and "{statement}" are matched with a probability of {score:.3%}'
        results.append(result_text)
    output = "\n".join(results)
    return output

# Function to generate image captions with the GIT TextCaps model
def generate_image_captions(image):
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    pixel_values = processor_caption(images=pil_image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model_caption.generate(pixel_values=pixel_values, max_length=50)  # generation length is tunable
    caption = processor_caption.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption

# Main function to perform image captioning and image-text matching
def process_images_and_statements(image):
    # Generate image captions using TextCaps
    captions = generate_image_captions(image)

    # Compute ITM scores for predefined statements using LAVIS
    itm_scores = compute_itm_scores(image)

    # Combine image captions and ITM scores into the output
    output = "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
    return output

# Gradio interface
image_input = gr.Image()
output = gr.Textbox(label="Results")

iface = gr.Interface(fn=process_images_and_statements, inputs=image_input, outputs=output, title="Image Captioning and Image-Text Matching")
iface.launch()