howard-hou committed on
Commit
794ada2
·
1 Parent(s): 898a24b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -7
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import gradio as gr
2
  import os, gc
 
 
3
  import torch
4
  import torch.nn.functional as F
5
  from transformers import CLIPImageProcessor
@@ -103,17 +105,35 @@ examples = [
103
  ]
104
  ]
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def chatbot(image, question):
107
  if image is None:
108
  yield "Please upload an image."
109
  return
110
- image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
111
- image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
112
- # apply layer norm to image feature, very important
113
- image_features = F.layer_norm(image_features,
114
- (image_features.shape[-1],),
115
- weight=model.w['blocks.0.ln0.weight'],
116
- bias=model.w['blocks.0.ln0.bias'])
117
  input_text = generate_prompt(question)
118
  for output in generate(input_text, image_features):
119
  yield output
 
1
  import gradio as gr
2
  import os, gc
3
+ import base64
4
+ from io import BytesIO
5
  import torch
6
  import torch.nn.functional as F
7
  from transformers import CLIPImageProcessor
 
105
  ]
106
  ]
107
 
108
+
109
def pil_image_to_base64(pil_image):
    """Return ``pil_image`` encoded as a base64 string of its JPEG bytes.

    Args:
        pil_image: A PIL image (anything exposing ``mode``, ``convert`` and
            ``save`` in the PIL style).

    Returns:
        str: The UTF-8 decoded base64 encoding of the JPEG-serialised image.
    """
    # JPEG has no alpha channel or palette support; Pillow raises OSError
    # ("cannot write mode RGBA as JPEG") for such images. Convert only the
    # modes JPEG cannot store, so images that already worked are encoded
    # exactly as before.
    jpeg_writable_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_writable_modes:
        pil_image = pil_image.convert("RGB")

    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    # b64encode returns bytes; decode so callers get a plain str.
    return base64.b64encode(buffered.getvalue()).decode('utf-8')
115
+
116
# Upper bound on the number of images whose features are kept in memory.
# Without a bound, every distinct upload would stay cached for the lifetime
# of the process.
IMAGE_CACHE_MAX_ENTRIES = 32

# Maps the base64 JPEG encoding of an image to its layer-normalised features.
# Plain dicts preserve insertion order, which is used here as recency order:
# the first key is always the least recently used entry.
image_cache = {}


def get_image_features(image):
    """Return the layer-normalised visual features for ``image``, cached.

    The cache key is the string returned by ``pil_image_to_base64(image)``.
    On a hit the stored features are returned and the entry is marked as most
    recently used. On a miss the image is preprocessed, encoded, normalised
    with the ``blocks.0.ln0`` weights of ``model``, and stored; the least
    recently used entries are evicted once ``IMAGE_CACHE_MAX_ENTRIES`` is
    reached.

    Args:
        image: A PIL image (``convert('RGB')`` is called on it).

    Returns:
        The image feature tensor produced by the encoder after layer norm.
    """
    cache_key = pil_image_to_base64(image)

    cached = image_cache.pop(cache_key, None)
    if cached is not None:
        # Re-insert so this entry moves to the "most recently used" end.
        image_cache[cache_key] = cached
        return cached

    pixel_values = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
    image_features = visual_encoder.encode_images(pixel_values.unsqueeze(0)).squeeze(0)  # [L, D]
    # apply layer norm to image feature, very important
    image_features = F.layer_norm(
        image_features,
        (image_features.shape[-1],),
        weight=model.w['blocks.0.ln0.weight'],
        bias=model.w['blocks.0.ln0.bias'],
    )

    # Evict oldest entries first so the cache never exceeds its bound.
    while len(image_cache) >= IMAGE_CACHE_MAX_ENTRIES:
        image_cache.pop(next(iter(image_cache)), None)
    image_cache[cache_key] = image_features
    return image_features
131
+
132
  def chatbot(image, question):
133
  if image is None:
134
  yield "Please upload an image."
135
  return
136
+ image_features = get_image_features(image)
 
 
 
 
 
 
137
  input_text = generate_prompt(question)
138
  for output in generate(input_text, image_features):
139
  yield output