import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
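
# Generation settings and the Hugging Face model to load.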
TOKEN_LIMIT = 2048
TEMPERATURE = 0.7
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
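
# Inference only: disable gradient tracking globally.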
torch.set_grad_enabled(False)
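

# Cache the tokenizer and model so they are loaded only once across Streamlit reruns.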
@st.cache_resource()
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model


# def chat_func(tokenizer, model, chat_history):
#     input_ids = tokenizer.apply_chat_template(
#         chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
#     ).to(model.device)
#     output_tokens = model.generate(
#         input_ids,
#         do_sample=True,
#         temperature=TEMPERATURE,
#         repetition_penalty=REPETITION_PENALTY,
#         max_new_tokens=MAX_NEW_TOKENS,
#     )
#     output_text = tokenizer.decode(
#         output_tokens[0][len(input_ids[0]) :], skip_special_tokens=True
#     )
#     return output_text
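

# Generate a streamed response; new tokens are pushed to `streamer` as they are produced.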
def chat_func_stream(tokenizer, model, chat_history, streamer):
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # check input length; drop the last user message if the prompt is too long
    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return
    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )
    return
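

# Render a single chat message (role and content) inside the given container.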
def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])
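

# Minimal streamer for model.generate(): transformers calls put() with new token ids
# as they are generated and end() when generation finishes.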
class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history
        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()  # placeholder to hold the streamed message

    def put(self, new_token):
        # the first call to put() receives the prompt; do not write input tokens
        if self.first_call_to_put:
            self.first_call_to_put = False
            return
        # decode the current token and accumulate current_response
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded
        # display the streamed message so far
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )

    def end(self):
        # save the completed assistant message
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )
        # clean up state (not strictly needed, as the instance is recreated each turn)
        self.first_call_to_put = True
        self.current_response = ""
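

# Replay the stored chat history on every Streamlit rerun, then handle new user input.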
tokenizer, model = load_model()
chat_messages_container = st.container()
for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)
user_input = st.chat_input()
if user_input:
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)
    # assistant_message = chat_func(tokenizer, model, st.session_state.chat_history)
    # assistant_message = {"role": "assistant", "content": assistant_message}
    # st.session_state.chat_history.append(assistant_message)
    # show_chat_message(chat_messages_container, assistant_message)
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)