iamrobotbear committed on
Commit c68c7bb · 1 Parent(s): ea1e3a1

Update app.py

Files changed (1): app.py (+30 -40)
app.py CHANGED
@@ -4,15 +4,15 @@ from PIL import Image
 import pandas as pd
 from lavis.models import load_model_and_preprocess
 from lavis.processors import load_processor
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Load model and preprocessors for Image-Text Matching (LAVIS)
 device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
 model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
 
 # Load tokenizer and model for Image Captioning (TextCaps)
-git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
-git_model_large_textcaps = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
+tokenizer_caption = AutoTokenizer.from_pretrained("microsoft/git-large-r-textcaps")
+model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps")
 
 # List of statements for Image-Text Matching
 statements = [
@@ -25,51 +25,41 @@ statements = [
     'promotes alcohol use as a "rite of passage" to adulthood',
 ]
 
-# Function to compute ITM scores for the combined text input (caption + statement)
-def compute_itm_score(image, combined_text):
+txts = [text_processors["eval"](statement) for statement in statements]
+
+# Function to compute Image-Text Matching (ITM) scores for all statements
+def compute_itm_scores(image):
     pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
     img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
-    # Pass the combined_text string directly to model_itm
-    itm_output = model_itm({"image": img, "text_input": combined_text}, match_head="itm")
-    itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
-    score = itm_scores[:, 1].item()
-    return score
-
-
-
-
-
-
-
+    results = []
+    for i, statement in enumerate(statements):
+        txt = txts[i]
+        itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
+        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
+        score = itm_scores[:, 1].item()
+        result_text = f'The image and "{statement}" are matched with a probability of {score:.3%}'
+        results.append(result_text)
+    output = "\n".join(results)
+    return output
 
-def generate_caption(processor, model, image):
-    inputs = processor(images=image, return_tensors="pt").to(device)
-    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
-    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return generated_caption
+# Function to generate image captions using TextCaps
+def generate_image_captions():
+    prompt = "A photo of"
+    inputs = tokenizer_caption(prompt, return_tensors="pt", padding=True, truncation=True)
+    outputs = model_caption.generate(**inputs)
+    caption = tokenizer_caption.decode(outputs[0], skip_special_tokens=True)
+    return prompt + " " + caption
 
 # Main function to perform image captioning and image-text matching
 def process_images_and_statements(image):
-    # Generate image caption for the uploaded image using git-large-r-textcaps
-    caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)
-
-    # Initialize an empty list to store the results
-    results = []
-
-    # Loop through each predefined statement
-    for statement in statements:
-        # Concatenate the caption with the statement
-        combined_text = caption + " " + statement
+    # Generate image captions using TextCaps
+    captions = generate_image_captions()
 
-        # Compute ITM score for the combined text and the image
-        itm_score = compute_itm_score(image, combined_text)
-
-        # Store the result
-        result_text = f'The image and "{combined_text}" are matched with a probability of {itm_score:.3%}'
-        results.append(result_text)
+    # Compute ITM scores for predefined statements using LAVIS
+    itm_scores = compute_itm_scores(image)
 
-    # Combine the results and return them
-    output = "\n".join(results)
+    # Combine image captions and ITM scores into the output
+    output = "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
     return output
 
 # Gradio interface
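
Review note on the captioning change: the new generate_image_captions() encodes only the fixed string "A photo of" and never passes the uploaded image to the model, so GIT generates text unconditioned on the input image (the image argument of process_images_and_statements now feeds only the ITM path). A minimal sketch of image-conditioned captioning, reusing the AutoProcessor-based calls this commit removes; the helper name generate_caption is kept purely for illustration:

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)

def generate_caption(image):
    # GIT conditions generation on pixel_values; without them, generate()
    # simply continues the text prompt and never looks at the image.
    inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]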
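
A smaller observation: the ITM loop runs with autograd enabled. A sketch of compute_itm_scores() under torch.no_grad(), assuming the same LAVIS call signature and the module-level names defined in app.py above (device, model_itm, vis_processors, statements, txts); the scores are unchanged, only per-call memory use drops:

import torch
from PIL import Image

def compute_itm_scores(image):
    pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
    img = vis_processors["eval"](pil_image.convert("RGB")).unsqueeze(0).to(device)
    results = []
    with torch.no_grad():  # inference only: skip building the autograd graph
        for statement, txt in zip(statements, txts):
            itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
            score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1].item()
            results.append(f'The image and "{statement}" are matched with a probability of {score:.3%}')
    return "\n".join(results)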