Spaces:
Sleeping
Sleeping
howard-hou
committed on
Commit
•
b0d85ba
1
Parent(s):
21aea4b
Update app.py
Browse files
app.py
CHANGED
@@ -35,11 +35,11 @@ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
|
|
35 |
##########################################################################
|
36 |
def generate_prompt(instruction):
|
37 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
38 |
-
return f"{instruction}\n\nAssistant:"
|
39 |
|
40 |
def generate(
|
41 |
ctx,
|
42 |
-
|
43 |
token_count=128,
|
44 |
temperature=0.2,
|
45 |
top_p=0.3,
|
@@ -58,10 +58,8 @@ def generate(
|
|
58 |
occurrence = {}
|
59 |
for i in range(int(token_count)):
|
60 |
if i == 0:
|
61 |
-
input_ids = pipeline.encode(ctx)
|
62 |
-
|
63 |
-
input_embs = torch.cat((image_features, text_embs), dim=0)[-ctx_limit:]
|
64 |
-
out, state = model.forward(embs=input_embs, state=None)
|
65 |
else:
|
66 |
input_ids = [token]
|
67 |
out, state = model.forward(tokens=input_ids, state=state)
|
@@ -113,11 +111,10 @@ def pil_image_to_base64(pil_image):
|
|
113 |
return base64_image
|
114 |
|
115 |
image_cache = {}
|
116 |
-
def
|
117 |
base64_image = pil_image_to_base64(image)
|
118 |
if base64_image in image_cache:
|
119 |
-
|
120 |
-
print(f"use cache {base64_image[:10]}")
|
121 |
else:
|
122 |
image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
|
123 |
image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
|
@@ -126,16 +123,17 @@ def get_image_features(image):
|
|
126 |
(image_features.shape[-1],),
|
127 |
weight=model.w['blocks.0.ln0.weight'],
|
128 |
bias=model.w['blocks.0.ln0.bias'])
|
129 |
-
|
130 |
-
|
|
|
131 |
|
132 |
def chatbot(image, question):
|
133 |
if image is None:
|
134 |
yield "Please upload an image."
|
135 |
return
|
136 |
-
|
137 |
input_text = generate_prompt(question)
|
138 |
-
for output in generate(input_text,
|
139 |
yield output
|
140 |
|
141 |
with gr.Blocks(title=title) as demo:
|
|
|
35 |
##########################################################################
|
36 |
def generate_prompt(instruction):
|
37 |
instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
|
38 |
+
return f"\n{instruction}\n\nAssistant:"
|
39 |
|
40 |
def generate(
|
41 |
ctx,
|
42 |
+
image_state,
|
43 |
token_count=128,
|
44 |
temperature=0.2,
|
45 |
top_p=0.3,
|
|
|
58 |
occurrence = {}
|
59 |
for i in range(int(token_count)):
|
60 |
if i == 0:
|
61 |
+
input_ids = pipeline.encode(ctx)[-ctx_limit:]
|
62 |
+
out, state = model.forward(tokens=input_ids, state=image_state)
|
|
|
|
|
63 |
else:
|
64 |
input_ids = [token]
|
65 |
out, state = model.forward(tokens=input_ids, state=state)
|
|
|
111 |
return base64_image
|
112 |
|
113 |
image_cache = {}
|
114 |
+
def compute_image_state(image):
|
115 |
base64_image = pil_image_to_base64(image)
|
116 |
if base64_image in image_cache:
|
117 |
+
image_state = image_cache[base64_image]
|
|
|
118 |
else:
|
119 |
image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
|
120 |
image_features = visual_encoder.encode_images(image.unsqueeze(0)).squeeze(0) # [L, D]
|
|
|
123 |
(image_features.shape[-1],),
|
124 |
weight=model.w['blocks.0.ln0.weight'],
|
125 |
bias=model.w['blocks.0.ln0.bias'])
|
126 |
+
_, image_state = model.forward(embs=image_features, state=None)
|
127 |
+
image_cache[base64_image] = image_state
|
128 |
+
return image_state
|
129 |
|
130 |
def chatbot(image, question):
|
131 |
if image is None:
|
132 |
yield "Please upload an image."
|
133 |
return
|
134 |
+
image_state = compute_image_state(image)
|
135 |
input_text = generate_prompt(question)
|
136 |
+
for output in generate(input_text, image_state):
|
137 |
yield output
|
138 |
|
139 |
with gr.Blocks(title=title) as demo:
|