howard-hou commited on
Commit
1d251b2
1 Parent(s): 45cab51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -34,7 +34,6 @@ image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
34
  ##########################################################################
35
  def generate_prompt(instruction):
36
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
37
- input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
38
  return f"\n{instruction}\n\nAssistant:"
39
 
40
  def generate(
@@ -104,7 +103,7 @@ examples = [
104
  def chatbot(image, question):
105
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
106
  image_features = visual_encoder.encode_images(image.unsqueeze(0))
107
- emb_mixer.set_image_embeddings(image_features)
108
  model.w["emb.weight"] = emb_mixer.get_input_embeddings()
109
  image_ids = [i for i in range(emb_mixer.image_start_index, emb_mixer.image_start_index + len(image_features))]
110
  input_text = generate_prompt(question)
@@ -116,13 +115,13 @@ with gr.Blocks(title=title) as demo:
116
  with gr.Column():
117
  image = gr.Image(type='pil', label="Image")
118
  with gr.Column():
119
- prompt = gr.Textbox(lines=3, label="Prompt",
120
- value="Assistant: Please upload an image and ask a question.")
121
  with gr.Row():
122
  submit = gr.Button("Submit", variant="primary")
123
  clear = gr.Button("Clear", variant="secondary")
124
  with gr.Column():
125
- output = gr.Textbox(label="Output", lines=5)
126
  data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
127
  submit.click(chatbot, [image, prompt], [output])
128
  clear.click(lambda: None, [], [output])
 
34
  ##########################################################################
35
  def generate_prompt(instruction):
36
  instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
 
37
  return f"\n{instruction}\n\nAssistant:"
38
 
39
  def generate(
 
103
  def chatbot(image, question):
104
  image = image_processor(images=image.convert('RGB'), return_tensors='pt')['pixel_values']
105
  image_features = visual_encoder.encode_images(image.unsqueeze(0))
106
+ emb_mixer.set_image_embeddings(image_features.squeeze(0))
107
  model.w["emb.weight"] = emb_mixer.get_input_embeddings()
108
  image_ids = [i for i in range(emb_mixer.image_start_index, emb_mixer.image_start_index + len(image_features))]
109
  input_text = generate_prompt(question)
 
115
  with gr.Column():
116
  image = gr.Image(type='pil', label="Image")
117
  with gr.Column():
118
+ prompt = gr.Textbox(lines=5, label="Prompt",
119
+ value="Please upload an image and ask a question.")
120
  with gr.Row():
121
  submit = gr.Button("Submit", variant="primary")
122
  clear = gr.Button("Clear", variant="secondary")
123
  with gr.Column():
124
+ output = gr.Textbox(label="Output", lines=7)
125
  data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
126
  submit.click(chatbot, [image, prompt], [output])
127
  clear.click(lambda: None, [], [output])