iamrobotbear commited on
Commit
ea1e3a1
·
1 Parent(s): 56786fe

Bring up to date with working github copy

Browse files

https://github.com/brianjking/image-caption-textmatch/commit/4ec79f5cccf8269f713977583c4bda88a34c8ca3

Files changed (1) hide show
  1. app.py +41 -31
app.py CHANGED
@@ -4,15 +4,15 @@ from PIL import Image
4
  import pandas as pd
5
  from lavis.models import load_model_and_preprocess
6
  from lavis.processors import load_processor
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
 
9
  # Load model and preprocessors for Image-Text Matching (LAVIS)
10
  device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
11
  model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
12
 
13
  # Load tokenizer and model for Image Captioning (TextCaps)
14
- tokenizer_caption = AutoTokenizer.from_pretrained("microsoft/git-large-r-textcaps")
15
- model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
16
 
17
  # List of statements for Image-Text Matching
18
  statements = [
@@ -25,41 +25,51 @@ statements = [
25
  'promotes alcohol use as a "rite of passage" to adulthood',
26
  ]
27
 
28
- txts = [text_processors["eval"](statement) for statement in statements]
29
-
30
- # Function to compute Image-Text Matching (ITM) scores for all statements
31
- def compute_itm_scores(image):
32
  pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
33
  img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
34
- results = []
35
- for i, statement in enumerate(statements):
36
- txt = txts[i]
37
- itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
38
- itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
39
- score = itm_scores[:, 1].item()
40
- result_text = f'The image and "{statement}" are matched with a probability of {score:.3%}'
41
- results.append(result_text)
42
- output = "\n".join(results)
43
- return output
44
 
45
- # Function to generate image captions using TextCaps
46
- def generate_image_captions():
47
- prompt = "A photo of"
48
- inputs = tokenizer_caption(prompt, return_tensors="pt", padding=True, truncation=True)
49
- outputs = model_caption.generate(**inputs)
50
- caption = tokenizer_caption.decode(outputs[0], skip_special_tokens=True)
51
- return prompt + " " + caption
 
52
 
53
  # Main function to perform image captioning and image-text matching
54
  def process_images_and_statements(image):
55
- # Generate image captions using TextCaps
56
- captions = generate_image_captions()
 
 
 
57
 
58
- # Compute ITM scores for predefined statements using LAVIS
59
- itm_scores = compute_itm_scores(image)
 
 
 
 
 
 
 
 
 
60
 
61
- # Combine image captions and ITM scores into the output
62
- output = "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
63
  return output
64
 
65
  # Gradio interface
@@ -67,4 +77,4 @@ image_input = gr.inputs.Image()
67
  output = gr.outputs.Textbox(label="Results")
68
 
69
  iface = gr.Interface(fn=process_images_and_statements, inputs=image_input, outputs=output, title="Image Captioning and Image-Text Matching")
70
- iface.launch()
 
4
  import pandas as pd
5
  from lavis.models import load_model_and_preprocess
6
  from lavis.processors import load_processor
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
8
 
9
  # Load model and preprocessors for Image-Text Matching (LAVIS)
10
  device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
11
  model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
12
 
13
  # Load tokenizer and model for Image Captioning (TextCaps)
14
+ git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
15
+ git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
16
 
17
  # List of statements for Image-Text Matching
18
  statements = [
 
25
  'promotes alcohol use as a "rite of passage" to adulthood',
26
  ]
27
 
28
+ # Function to compute ITM scores for the combined text input (caption + statement)
29
+ def compute_itm_score(image, combined_text):
 
 
30
  pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
31
  img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
32
+ # Pass the combined_text string directly to model_itm
33
+ itm_output = model_itm({"image": img, "text_input": combined_text}, match_head="itm")
34
+ itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
35
+ score = itm_scores[:, 1].item()
36
+ return score
37
+
38
+
39
+
40
+
 
41
 
42
+
43
+
44
+
45
+ def generate_caption(processor, model, image):
46
+ inputs = processor(images=image, return_tensors="pt").to(device)
47
+ generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
48
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
49
+ return generated_caption
50
 
51
  # Main function to perform image captioning and image-text matching
52
  def process_images_and_statements(image):
53
+ # Generate image caption for the uploaded image using git-large-r-textcaps
54
+ caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)
55
+
56
+ # Initialize an empty list to store the results
57
+ results = []
58
 
59
+ # Loop through each predefined statement
60
+ for statement in statements:
61
+ # Concatenate the caption with the statement
62
+ combined_text = caption + " " + statement
63
+
64
+ # Compute ITM score for the combined text and the image
65
+ itm_score = compute_itm_score(image, combined_text)
66
+
67
+ # Store the result
68
+ result_text = f'The image and "{combined_text}" are matched with a probability of {itm_score:.3%}'
69
+ results.append(result_text)
70
 
71
+ # Combine the results and return them
72
+ output = "\n".join(results)
73
  return output
74
 
75
  # Gradio interface
 
77
  output = gr.outputs.Textbox(label="Results")
78
 
79
  iface = gr.Interface(fn=process_images_and_statements, inputs=image_input, outputs=output, title="Image Captioning and Image-Text Matching")
80
+ iface.launch()