Ramji committed on
Commit a5c4244 · verified · 1 Parent(s): 874d6ff

created app

Files changed (1)
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import torch
+ import torchvision.transforms as T
+ from PIL import Image
+ from torchvision.transforms.functional import InterpolationMode
+ from transformers import AutoModel, AutoTokenizer
+ import streamlit as st
+
+
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
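+ # Sanity-check sketch (illustrative, not part of the app's control flow):
+ # any PIL image passed through the transform comes out as a normalized
+ # float tensor of shape (3, input_size, input_size), e.g.
+ #   build_transform(448)(Image.new('RGB', (640, 480))).shape == torch.Size([3, 448, 448])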
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate the candidate tiling grids (i columns x j rows) within [min_num, max_num] tiles
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         # split the image
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
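+ # Worked example (illustrative numbers, not from the original commit): a
+ # 1000x750 image has aspect ratio ~1.33, so the closest grid is (4, 3); the
+ # image is resized to 1792x1344 and cropped into 12 tiles of 448x448, plus a
+ # 448x448 thumbnail when use_thumbnail=True, for 13 tiles in total.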
+
+ def load_image(image_file, input_size=448, max_num=12):
+     image = Image.open(image_file).convert('RGB')
+
+     transform = build_transform(input_size=input_size)
+     images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
+     pixel_values = [transform(image) for image in images]
+     pixel_values = torch.stack(pixel_values)
+     return pixel_values
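+ # load_image therefore returns a tensor of shape (N, 3, input_size, input_size),
+ # where N is the number of tiles produced by dynamic_preprocess (12 grid tiles
+ # plus 1 thumbnail in the example above).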
+
+ def prediction(model, image_file, question):
+     question = f"<image>\n{question}"
+     # set the max number of tiles in `max_num`
+     pixel_values = load_image(image_file, max_num=12).to(torch.bfloat16).cuda()
+     generation_config = dict(max_new_tokens=1024, do_sample=False)
+
+     # `tokenizer` is the module-level tokenizer loaded below
+     response = model.chat(tokenizer, pixel_values, question, generation_config)
+
+     return response
+
+ # If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
+ path = 'Ramji/slake_vqa_internvl_demo'
+ intern_model = AutoModel.from_pretrained(
+     path,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     use_flash_attn=False,
+     trust_remote_code=True).eval().cuda()
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
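+ # Note: the `.cuda()` calls here and in `prediction` assume a CUDA GPU is
+ # available. A hedged CPU fallback sketch (untested with this checkpoint):
+ #   device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ #   intern_model = intern_model.to(device)
+ #   # ...and move `pixel_values` with .to(device) instead of .cuda()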
+
+ # Title of the Streamlit app
+ st.title("Image VQA")
+
+ # Step 1: Upload an image
+ st.header("Upload an Image")
+ uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
+
+ # Step 2: Input a question
+ st.header("Ask a Question")
+ question = st.text_input("Type your question here:")
+
+ # Step 3: Handle the uploaded image by saving it and reading its path
+ if uploaded_image is not None:
+     # Save the uploaded image to a file
+     image_path = os.path.join("uploaded_images", uploaded_image.name)
+
+     # Make sure the directory exists
+     os.makedirs("uploaded_images", exist_ok=True)
+
+     # Write the image to a file
+     with open(image_path, "wb") as f:
+         f.write(uploaded_image.getbuffer())
+
+     # Read the image from the saved file path
+     image = Image.open(image_path)
+
+     # Display the uploaded image
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     st.write(f"Image saved at: {image_path}")
+
+ # Step 4: Display the typed question
+ if question:
+     st.write(f"Your question: **{question}**")
+
+ # Step 5: Run the fine-tuned InternVL model on the image and question
+ if uploaded_image and question:
+     st.write("Processing the image and question...")
+
+     output = prediction(intern_model, image_path, question)
+
+     st.write(f"Model output: {output}")
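
To try the commit locally (assuming `torch`, `torchvision`, `transformers`, and `streamlit` are installed and a CUDA GPU is available; these requirements are inferred from the imports above, not stated in the commit), the app can be launched with Streamlit's CLI:

streamlit run app.py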