howard-hou committed on
Commit
b0d85ba
1 Parent(s): 21aea4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -35,11 +35,11 @@ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
35
  ##########################################################################
36
  def generate_prompt(instruction):
37
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
38
- return f"{instruction}\n\nAssistant:"
39
 
40
  def generate(
41
  ctx,
42
- image_features,
43
  token_count=128,
44
  temperature=0.2,
45
  top_p=0.3,
@@ -58,10 +58,8 @@ def generate(
58
  occurrence = {}
59
  for i in range(int(token_count)):
60
  if i == 0:
61
- input_ids = pipeline.encode(ctx)
62
- text_embs = model.w['emb.weight'][input_ids]
63
- input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
64
- out, state = model.forward(embs=input_embs, state=None)
65
  else:
66
  input_ids = [token]
67
  out, state = model.forward(tokens=input_ids, state=state)
@@ -113,11 +111,10 @@ def pil_image_to_base64(pil_image):
113
  return base64_image
114
 
115
  image_cache = {}
116
- def get_image_features(image):
117
  base64_image = pil_image_to_base64(image)
118
  if base64_image in image_cache:
119
- image_features = image_cache[base64_image]
120
- print(f"use cache {base64_image[:10]}")
121
  else:
122
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
123
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
@@ -126,16 +123,17 @@ def get_image_features(image):
126
  (image_features.shape[-1],),
127
  weight=model.w['blocks.0.ln0.weight'],
128
  bias=model.w['blocks.0.ln0.bias'])
129
- image_cache[base64_image] = image_features
130
- return image_features
 
131
 
132
  def chatbot(image, question):
133
  if image is None:
134
  yield "Please upload an image."
135
  return
136
- image_features = get_image_features(image)
137
  input_text = generate_prompt(question)
138
- for output in generate(input_text, image_features):
139
  yield output
140
 
141
  with gr.Blocks(title=title) as demo:
 
35
  ##########################################################################
36
def generate_prompt(instruction):
    """Normalize a user instruction and wrap it in the model's prompt template.

    Strips surrounding whitespace, converts Windows line endings (``\r\n``)
    to Unix newlines, and collapses any run of blank lines to a single
    newline, then appends the ``Assistant:`` cue the model completes.

    Parameters
    ----------
    instruction : str
        Raw user-supplied question/instruction text.

    Returns
    -------
    str
        The cleaned instruction embedded in the chat prompt template.
    """
    instruction = instruction.strip().replace('\r\n', '\n')
    # Bug fix: a single replace('\n\n', '\n') only halves a run of blank
    # lines (e.g. four consecutive newlines become two, not one).
    # Loop until no double-newline remains so all blank lines collapse.
    while '\n\n' in instruction:
        instruction = instruction.replace('\n\n', '\n')
    return f"\n{instruction}\n\nAssistant:"
39
 
40
  def generate(
41
  ctx,
42
+ image_state,
43
  token_count=128,
44
  temperature=0.2,
45
  top_p=0.3,
 
58
  occurrence = {}
59
  for i in range(int(token_count)):
60
  if i == 0:
61
+ input_ids = pipeline.encode(ctx)[-ctx_limit:]
62
+ out, state = model.forward(tokens=input_ids, state=image_state)
 
 
63
  else:
64
  input_ids = [token]
65
  out, state = model.forward(tokens=input_ids, state=state)
 
111
  return base64_image
112
 
113
  image_cache = {}
114
+ def compute_image_state(image):
115
  base64_image = pil_image_to_base64(image)
116
  if base64_image in image_cache:
117
+ image_state = image_cache[base64_image]
 
118
  else:
119
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
120
  image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
 
123
  (image_features.shape[-1],),
124
  weight=model.w['blocks.0.ln0.weight'],
125
  bias=model.w['blocks.0.ln0.bias'])
126
+ _, image_state = model.forward(embs=image_features, state=None)
127
+ image_cache[base64_image] = image_state
128
+ return image_state
129
 
130
def chatbot(image, question):
    """Stream answers about *image* for *question*.

    Yields a single error message when no image was supplied; otherwise
    encodes the image into model state (cached by ``compute_image_state``)
    and streams incremental outputs from ``generate``.
    """
    # Guard clause: nothing to answer without an image.
    if image is None:
        yield "Please upload an image."
        return
    state = compute_image_state(image)
    prompt = generate_prompt(question)
    # Delegate streaming directly to the token generator.
    yield from generate(prompt, state)
138
 
139
  with gr.Blocks(title=title) as demo: