TobyYang7 committed on
Commit 5b853cd
1 Parent(s): 1da21d4

Update app.py

Files changed (1)
  app.py +36 -29
app.py CHANGED
@@ -60,43 +60,50 @@ def bot_streaming(message, history):
 
     # Generate the prompt for the model
     prompt = message['text']
-
-    # Use a streamer to generate the output in a streaming fashion
-    streamer = []
-
-    # Define a function to call chat_llava in a separate thread
-    def generate_output():
-        output = chat_llava(
-            args=args,
-            image_file=image_path,
-            text=prompt,
-            tokenizer=tokenizer,
-            model=llava_model,
-            image_processor=image_processor,
-            context_len=context_len
-        )
-        for new_text in output:
-            streamer.append(new_text)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Set up the generation arguments, including the streamer
+    generation_kwargs = dict(
+        args=args,
+        image_file=image_path,
+        text=prompt,
+        tokenizer=tokenizer,
+        model=llava_model,
+        streamer=streamer,
+        image_processor=image_processor,  # todo: input model name or path
+        context_len=context_len)
+
+    # Define the function to call `chat_llava` with the given arguments
+    def generate_output(generation_kwargs):
+        chat_llava(**generation_kwargs)
 
     # Start the generation in a separate thread
-    thread = Thread(target=generate_output)
+    thread = Thread(target=generate_output, args=(generation_kwargs,))
     thread.start()
 
-    # Stream the output
+    # Initialize a buffer to accumulate the generated text
     buffer = ""
-    while thread.is_alive() or streamer:
-        while streamer:
-            new_text = streamer.pop(0)
-            buffer += new_text
-            yield buffer
-        time.sleep(0.1)
-
-    # Ensure any remaining text is yielded after the thread completes
-    while streamer:
-        new_text = streamer.pop(0)
-        buffer += new_text
-        yield buffer
+
+    # Allow the generation to start before consuming the stream
+    time.sleep(0.5)
+
+    # Iterate over the streamer to handle the incoming text in chunks
+    for new_text in streamer:
+        # Look for the end-of-text token and remove it
+        if "<|eot_id|>" in new_text:
+            new_text = new_text.split("<|eot_id|>")[0]
+
+        # Add the new text to the buffer
+        buffer += new_text
+
+        # Remove the prompt from the generated text (if necessary)
+        generated_text_without_prompt = buffer[len(prompt):]
+
+        # Simulate processing time (optional)
+        time.sleep(0.06)
+
+        # Yield the current generated text for further processing or display
+        yield generated_text_without_prompt
 
 chatbot = gr.Chatbot(scale=1)
 chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
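
For reference, a minimal self-contained sketch of the producer/consumer pattern the new code adopts, using a plain Hugging Face causal LM in place of chat_llava (the model name, prompt, and generation settings below are illustrative placeholders, not taken from app.py):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Hello, world"
inputs = tokenizer(prompt, return_tensors="pt")

# skip_prompt=True drops the echoed prompt; skip_special_tokens strips
# <|eot_id|>-style markers at decode time
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until decoding finishes, so it runs in a worker thread
# while the main thread consumes the stream
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=50))
thread.start()

buffer = ""
for new_text in streamer:  # blocks per chunk; iteration ends when generation finishes
    buffer += new_text
    print(buffer)

thread.join()

Unlike the old hand-rolled list plus time.sleep polling loop, the iterator blocks until the next chunk is ready, so no chunk is dropped and no polling interval has to be tuned; and because skip_prompt=True already omits the prompt from the stream, no extra slicing of the buffer is needed in this sketch.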
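
The trailing context lines only construct the UI components; presumably elsewhere in app.py they are attached to bot_streaming via gr.ChatInterface. A sketch under that assumption (the wiring itself is not shown in this hunk):

import gradio as gr

chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"],
                                  placeholder="Enter message or upload file...",
                                  show_label=False)

demo = gr.ChatInterface(
    fn=bot_streaming,   # each yielded string replaces the in-progress bot reply
    multimodal=True,    # the fn receives {'text': ..., 'files': [...]} messages
    chatbot=chatbot,
    textbox=chat_input,
)
demo.launch()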