Commit · 9a81d74
Parent(s): 1abd311
adding text streaming
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TextStreamer
 
 token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
 
@@ -51,15 +51,17 @@ def get_prompt_with_template(message: str) -> str:
 def generate_model_response(message: str) -> str:
     prompt = get_prompt_with_template(message)
     inputs = tokenizer(prompt, return_tensors='pt')
+    streamer = TextStreamer(tokenizer)
     if torch.cuda.is_available():
         inputs = inputs.to('cuda')
     # Include **generate_kwargs to include the user-defined options
     output = model.generate(**inputs,
                             max_new_tokens=4096,
                             do_sample=True,
-                            temperature=0.1
+                            temperature=0.1,
+                            streamer=streamer
                             )
-    return tokenizer.decode(output[0], skip_special_tokens=True)
+    # return tokenizer.decode(output[0], skip_special_tokens=True)
 
 def extract_response_content(full_response: str) -> str:
     response_start_index = full_response.find("### Assistant:")