chendl committed
Commit 020e358 · 1 Parent(s): d263fbc

update chat
app.py CHANGED
@@ -237,30 +237,36 @@ def upload_img(gr_img, text_input, chat_state,chatbot):
        value="Start Chatting", interactive=False), chat_state, img_list,chatbot


-def gradio_ask(user_message, chatbot, chat_state):
+def gradio_ask(user_message, chatbot, chat_state,radio):
    if len(user_message) == 0:
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state


-    chat.ask(user_message, chat_state)
+    chat.ask(user_message, chat_state,radio)
    chatbot = chatbot + [[user_message, None]]
    return '', chatbot, chat_state


-def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+def gradio_answer(chatbot, chat_state, img_list, radio, text,num_beams, temperature,radio):
    llm_message,image = \
        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
-                    max_length=2000)
+                    max_length=2000,radio = radio,text_input = text)

    chatbot[-1][1] = llm_message
    if image==None:
        return chatbot, chat_state, img_list
    else:
        path = build_image(image)
-        chatbot = chatbot + [[(path,), None]]
+        chatbot = chatbot + [[None,(path,)]]
        return chatbot, chat_state, img_list

-
+task_template = {
+    "Cap": "Summarize the content of the photo <image>.",
+    "VQA": "For this image <image>, I want a simple and direct answer to my question: <question>",
+    "REC": "Can you point out <expr> in the image <image> and provide the coordinates of its location?",
+    "GC": "Can you give me a description of the region <boxes> in image <image>?",
+    "Advanced": "<question>",
+}

with gr.Blocks() as demo:
    gr.Markdown(title)
@@ -273,6 +279,9 @@ with gr.Blocks() as demo:
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            clear = gr.Button("Restart")
+            radio = gr.Radio(
+                ["Cap", "VQA", "REC", "Advanced"], label="Task Template", value='Cap',
+            )

            num_beams = gr.Slider(
                minimum=1,
@@ -296,13 +305,20 @@ with gr.Blocks() as demo:
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Compositional-VLM')
-            text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
+
+
+            # template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False,
+            #                       value='Provide a comprehensive description of the image <image> and specify the positions of any mentioned objects in square brackets.')
+            # text_input = gr.Textbox(label='<question>', show_label=True, placeholder="Please upload your image first, then input...", lines=3,
+            #                         value=None, visible=False, interactive=False)
+
+            text_input = gr.Textbox(label='User', placeholder='Please upload your image first, then input...', interactive=False)

    upload_button.click(upload_img, [image, text_input, chat_state,chatbot],
                        [image, text_input, upload_button, chat_state, img_list,chatbot])

-    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+    text_input.submit(gradio_ask, [text_input, chatbot, chat_state,radio], [text_input, chatbot, chat_state]).then(
+        gradio_answer, [chatbot, chat_state, img_list, radio, text_input,num_beams, temperature, radio], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
                queue=False)
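
Note on the chatbot update above: the generated box visualization now goes into the assistant slot of the history pair instead of the user slot. A minimal standalone sketch of the value format gr.Chatbot expects (plain Python; the history variable and file path below are hypothetical, not part of app.py):

# Each gr.Chatbot entry is a [user_message, assistant_message] pair, and a
# (filepath,) tuple in a slot is rendered by Gradio as an image/file rather than text.
# Before this commit the visualization was placed in the user slot ([[(path,), None]]);
# after it, it appears as an assistant message ([[None, (path,)]]).
history = []
history = history + [["where is the dog?", "The dog is on the left."]]  # ordinary text turn
history = history + [[None, ("/tmp/boxes.png",)]]  # assistant-side image (hypothetical path)
print(history)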
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -278,18 +278,34 @@ class Chat:
        # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
        # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

-    def ask(self, text, conv):
-        conv.append(({
-            "from": "human",
-            "value": text,
-        }))
+    def ask(self, text, conv,radio):
+        if radio in ["Cap"]:
+            conv.append({
+                "from": "human",
+                "value": "",
+            })
+        elif radio in ["VQA"]:
+            conv.append({
+                "from": "human",
+                "value": f"Answer the question using a single word or phrase.{text}",
+            })
+        elif radio in ["REC"]:
+            conv.append({
+                "from": "human",
+                "value": f"Please provide the bounding box coordinate of the region this sentence describes: {text}.",
+            })
+        else:
+            conv.append({
+                "from": "human",
+                "value": text,
+            })
        # if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
        # and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
        # conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
        # else:
        # conv.append_message(conv.roles[0], text)

-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
+    def answer(self, conv, img_list, radio, text_input, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
               repetition_penalty=1.0, length_penalty=1, temperature=1, max_length=2000):
        # conv.append_message(conv.roles[1], None)
        # embs = self.get_context_emb(conv, img_list)
@@ -315,7 +331,14 @@ class Chat:
        # output_text = output_text.split('###')[0] # remove the stop sign '###'
        # output_text = output_text.split('Assistant:')[-1].strip()
        # conv.messages[-1][1] = output_text
-
+        visual_token = "<|#visual#|>"
+        previsual_token = "<|#previsual#|>"
+        box_token = "<|#box#|>"
+        prebox_token = "<|#prebox#|>"
+        end_token = "<|#endofobject#|>"
+        object_token = "<|#object#|>"
+        end_of_attr_token = "<|#endofattr#|>"
+        preend_of_attr_token = "<|#preendofattr#|>"
        media_token_id = self.tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
        box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
@@ -336,10 +359,23 @@ class Chat:

        # conversation = []
        human_sentence = None
-        conv.append({
-            "from": "gpt",
-            "value": "",
-        })
+        if radio in ["Cap","VQA"]:
+            conv.append({
+                "from": "gpt",
+                "value": "",
+            })
+        elif radio in ["REC"]:
+            conv.append(
+                {
+                    "from": "gpt",
+                    "value": object_token + text_input + end_token + visual_token
+                }
+            )
+        else:
+            conv.append({
+                "from": "gpt",
+                "value": "",
+            })
        # while True:
        # human_sentence = input("### Human: ")
        # if human_sentence == "#end#":
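
Taken together, ask() and answer() now build a different prompt for each task template. Below is a standalone sketch (plain Python, no model; build_human_turn and seed_gpt_turn are illustrative helpers, not functions in conversation.py) of the conversation state a REC query produces: the human turn carries the referring expression, and the assistant turn is pre-seeded with <|#object#|>...<|#endofobject#|><|#visual#|> so generation starts at the point where the model emits box tokens.

# Token strings mirror the specials declared in answer().
object_token = "<|#object#|>"
end_token = "<|#endofobject#|>"
visual_token = "<|#visual#|>"

def build_human_turn(text, radio):
    # Mirrors ask(): pick the prompt wording from the selected task template.
    if radio == "Cap":
        value = ""  # captioning: empty human turn
    elif radio == "VQA":
        value = f"Answer the question using a single word or phrase.{text}"
    elif radio == "REC":
        value = f"Please provide the bounding box coordinate of the region this sentence describes: {text}."
    else:
        value = text  # "Advanced": pass the question through unchanged
    return {"from": "human", "value": value}

def seed_gpt_turn(radio, text_input):
    # Mirrors answer(): for REC the assistant turn is pre-filled so the model only
    # has to complete the box after <|#visual#|>; other tasks start from an empty turn.
    if radio == "REC":
        value = object_token + text_input + end_token + visual_token
    else:
        value = ""
    return {"from": "gpt", "value": value}

conv = []
conv.append(build_human_turn("the dog on the left", "REC"))
conv.append(seed_gpt_turn("REC", "the dog on the left"))
for turn in conv:
    print(f'{turn["from"]}: {turn["value"]}')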