# A minimal Streamlit chat app that streams replies from a small Falcon chat model.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Generation settings and the Hugging Face model checkpoint.
TOKEN_LIMIT = 2048
TEMPERATURE = 0.7
REPETITION_PENALTY = 1.05
MAX_NEW_TOKENS = 500
MODEL_NAME = "ericzzz/falcon-rw-1b-chat"

# Keep the conversation in session state so it survives Streamlit reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Inference only; gradients are never needed.
torch.set_grad_enabled(False)
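# Load the tokenizer and model once per server process; st.cache_resource keeps
# them cached across Streamlit reruns.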
@st.cache_resource()
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16
    )
    return tokenizer, model
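# Build the prompt with the model's chat template and generate a streamed reply.
# Over-long inputs are rejected and the offending user turn is removed from history.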
def chat_func_stream(tokenizer, model, chat_history, streamer):
    input_ids = tokenizer.apply_chat_template(
        chat_history, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    if len(input_ids[0]) > TOKEN_LIMIT:
        st.warning(
            f"We have limited computation power. Please keep your input within {TOKEN_LIMIT} tokens."
        )
        # Drop the over-long user message so it is not resubmitted on the next rerun.
        st.session_state.chat_history = st.session_state.chat_history[:-1]
        return

    model.generate(
        input_ids,
        do_sample=True,
        temperature=TEMPERATURE,
        repetition_penalty=REPETITION_PENALTY,
        max_new_tokens=MAX_NEW_TOKENS,
        streamer=streamer,
    )
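# Render a single chat message inside the given Streamlit container.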
def show_chat_message(container, chat_message):
    with container:
        with st.chat_message(chat_message["role"]):
            st.write(chat_message["content"])
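# Custom streamer passed to model.generate(). It implements the put()/end()
# interface that transformers streamers expose and re-renders the assistant
# message in the Streamlit UI as tokens arrive.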
class ResponseStreamer:
    def __init__(self, tokenizer, container, chat_history):
        self.tokenizer = tokenizer
        self.container = container
        self.chat_history = chat_history

        self.first_call_to_put = True
        self.current_response = ""
        with self.container:
            self.placeholder = st.empty()
    def put(self, new_token):
        # The first call delivers the prompt tokens, not generated text; skip it.
        if self.first_call_to_put:
            self.first_call_to_put = False
            return

        # Decode the new token and append it to the partial response.
        decoded = self.tokenizer.decode(new_token[0], skip_special_tokens=True)
        self.current_response += decoded

        # Re-render the growing assistant message in place.
        show_chat_message(
            self.placeholder.container(),
            {"role": "assistant", "content": self.current_response},
        )
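    # generate() calls end() once generation finishes: store the reply and reset.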
    def end(self):
        self.chat_history.append(
            {"role": "assistant", "content": self.current_response}
        )

        self.first_call_to_put = True
        self.current_response = ""
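# Page flow: load the (cached) model, replay the stored conversation, then
# handle any new user input.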
tokenizer, model = load_model()
chat_messages_container = st.container()

for msg in st.session_state.chat_history:
    show_chat_message(chat_messages_container, msg)

user_input = st.chat_input()
if user_input:
    new_user_message = {"role": "user", "content": user_input}
    st.session_state.chat_history.append(new_user_message)
    show_chat_message(chat_messages_container, new_user_message)
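    # Stream the assistant's reply into the same container; ResponseStreamer.end()
    # appends the finished reply to the session history.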
    streamer = ResponseStreamer(
        tokenizer=tokenizer,
        container=chat_messages_container,
        chat_history=st.session_state.chat_history,
    )
    chat_func_stream(tokenizer, model, st.session_state.chat_history, streamer)
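# To try the app locally, run it through the Streamlit CLI, e.g.:
#   streamlit run app.py
# ("app.py" stands in for whatever filename this script is saved under.)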