Spaces:
Runtime error
Runtime error
Commit
·
794ada2
1
Parent(s):
898a24b
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import os, gc
|
|
|
|
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
from transformers import CLIPImageProcessor
|
@@ -103,17 +105,35 @@ examples = [
|
|
103 |
]
|
104 |
]
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
def chatbot(image, question):
    """Streaming gradio handler: answer *question* about *image*.

    Yields successively longer partial answers; emits a single
    prompt-the-user message when no image has been uploaded yet.
    """
    if image is None:
        yield "Please upload an image."
        return
    # Encode the image into a sequence of visual tokens: [L, D].
    feats = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0)
    # apply layer norm to image feature, very important — normalises the
    # visual tokens with the first block's ln0 parameters before mixing
    # them with the language model's embeddings.
    ln_weight = model.w['blocks.0.ln0.weight']
    ln_bias = model.w['blocks.0.ln0.bias']
    feats = F.layer_norm(feats, (feats.shape[-1],), weight=ln_weight, bias=ln_bias)
    prompt = generate_prompt(question)
    yield from generate(prompt, feats)
|
|
|
1 |
import gradio as gr
|
2 |
import os, gc
|
3 |
+
import base64
|
4 |
+
from io import BytesIO
|
5 |
import torch
|
6 |
import torch.nn.functional as F
|
7 |
from transformers import CLIPImageProcessor
|
|
|
105 |
]
|
106 |
]
|
107 |
|
108 |
+
|
109 |
+
def pil_image_to_base64(pil_image):
    """Serialise a PIL image to a base64-encoded JPEG string.

    The string is hashable and deterministic for identical pixel data,
    so it is used as the memoisation key for uploaded images.
    """
    buffered = BytesIO()
    # JPEG cannot store an alpha channel; gradio may hand us RGBA or
    # palette-mode images, so normalise to RGB first — otherwise
    # Image.save() raises OSError ("cannot write mode RGBA as JPEG").
    pil_image.convert("RGB").save(buffered, format="JPEG")
    # Encode the raw JPEG bytes as base64 and return them as text.
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
115 |
+
|
116 |
+
image_cache = {}  # base64(JPEG bytes) -> layer-normalised feature tensor
_IMAGE_CACHE_MAX = 64  # bound the cache: the original grew without limit,
                       # leaking one JPEG string + one feature tensor per
                       # distinct upload for the life of the process

def get_image_features(image):
    """Return layer-normalised visual features for *image* (a PIL image).

    Results are memoised on the image's base64-encoded JPEG bytes so that
    asking several questions about the same upload encodes it only once.
    """
    base64_image = pil_image_to_base64(image)
    if base64_image in image_cache:
        return image_cache[base64_image]
    # Preprocess to pixel values; CLIP-style processors return a batched tensor.
    pixels = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
    # NOTE(review): pixel_values already carries a batch dim ([1, C, H, W]);
    # the extra unsqueeze(0) reproduces the original call but looks like a
    # double batch dim — confirm encode_images' expected input rank.
    image_features = visual_encoder.encode_images(pixels.unsqueeze(0)).squeeze(0)  # [L, D]
    # apply layer norm to image feature, very important
    image_features = F.layer_norm(image_features,
                                  (image_features.shape[-1],),
                                  weight=model.w['blocks.0.ln0.weight'],
                                  bias=model.w['blocks.0.ln0.bias'])
    # Evict everything when full: crude but sufficient for an interactive demo,
    # and it keeps worst-case memory bounded.
    if len(image_cache) >= _IMAGE_CACHE_MAX:
        image_cache.clear()
    image_cache[base64_image] = image_features
    return image_features
|
131 |
+
|
132 |
def chatbot(image, question):
    """Streaming gradio handler: answer *question* about the uploaded *image*.

    Yields successively longer partial answers; emits a single
    prompt-the-user message when no image has been uploaded yet.
    """
    if image is None:
        yield "Please upload an image."
        return
    feats = get_image_features(image)   # memoised visual features
    prompt = generate_prompt(question)  # wrap the question in the chat template
    yield from generate(prompt, feats)
|