qnguyen3 committed
Commit 3533245
Parent(s): c8dd2ad

Update app.py

Files changed (1): app.py (+5, -6)
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
 from threading import Thread
 import re
 import time
@@ -50,15 +50,14 @@ def bot_streaming(message, history):
 
     # if image is None:
     #     gr.Error("You need to upload an image for LLaVA to work.")
-    prompt=f"[INST] <image>\n{message['text']} [/INST]"
     image = Image.open(image).convert("RGB")
     text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True)
+        messages,
+        tokenize=False,
+        add_generation_prompt=True)
     text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
     input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
-    streamer = TextIteratorStreamer(tokenizer, **{"skip_special_tokens": True})
+    streamer = TextStreamer(tokenizer, **{"skip_special_tokens": True})
     image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
     generation_kwargs = dict(inputs=input_ids, images=image_tensor, streamer=streamer, max_new_tokens=100)
     generated_text = ""