import streamlit as st
import os
# os.environ['HF_HOME'] = '/scratch/sroydip1/cache/hf/'
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]
# import torch
import pickle
import torch
from transformers import Conversation, pipeline, AutoTokenizer, AutoModelForCausalLM
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart
TOKEN = st.secrets["HF_TOKEN"]
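# Session-state keys that get serialized when a chat is shared (see share() below)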
share_keys = ["messages", "model_name"]
MODELS = [
"meta-llama/Llama-2-7b-chat-hf",
"mistralai/Mistral-7B-Instruct-v0.2",
# "google/flan-t5-small",
# "google/flan-t5-base",
# "google/flan-t5-large",
# "google/flan-t5-xl",
# "google/flan-t5-xxl",
]
default_model = MODELS[0]
# default_model = "meta-llama/Llama-2-7b-chat-hf"
st.set_page_config(
page_title="LLM",
page_icon="π",
)
if "model_name" not in st.session_state:
    st.session_state.model_name = default_model
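# Load the chat pipeline once and cache it across Streamlit reruns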
@st.cache_resource
def get_pipeline(model_name):
model_name = "gpt2-medium"
device = 0 if torch.cuda.is_available() else -1
# if True or model_name == "meta-llama/Llama-2-7b-chat-hf" or model_name == "mistralai/Mistral-7B-Instruct-v0.2":
# chatbot = pipeline(model=model_name, task="conversational", device=device)#, model_kwargs=model_kwargs)
# else:
# chatbot = pipeline(model=model_name, task="text-generation", device=device)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_name, token=TOKEN)
# chatbot = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
chatbot = pipeline("conversational", model=model, tokenizer=tokenizer)
return chatbot
chatbot = get_pipeline(st.session_state.model_name)
if "messages" not in st.session_state:
    st.session_state.messages = []
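# If the page is opened with ?id=..., restore the shared chat from the uploaded pickle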
if len(st.session_state.messages) == 0 and "id" in st.query_params:
with st.spinner("Loading chat..."):
id = st.query_params["id"]
data = get_file(id)
obj = pickle.loads(data)
for k, v in obj.items():
st.session_state[k] = v
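# Pickle the shareable session keys, upload the blob, and show a link that reloads this chat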
def share():
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={id}"
    st.markdown(f"[share](/?id={id})")
    st.success(f"Share URL: {url}")
with st.sidebar:
st.title(":blue[LLM Only]")
st.subheader("Model")
model_name = st.selectbox("Model", MODELS, key="model_name")
if st.button("Share", use_container_width=True):
share()
cols = st.columns(2)
with cols[0]:
if st.button("Restart", type="primary", use_container_width=True):
restart()
with cols[1]:
if st.button("Undo", use_container_width=True):
undo()
append = st.checkbox("Append to previous message", value=False)
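# Re-render the stored conversation on every Streamlit rerun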
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
def push_message(role, content):
message = {"role": role, "content": content}
st.session_state.messages.append(message)
return message
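# Handle a new user message: echo it, build a Conversation from the history, and generate a reply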
if prompt := st.chat_input("Type a message", key="chat_input"):
push_message("user", prompt)
with st.chat_message("user"):
st.markdown(prompt)
if not append:
with st.chat_message("assistant"):
chat = Conversation()
for m in st.session_state.messages:
chat.add_message(m)
print(chat)
with st.spinner("Generating response..."):
response = chatbot(chat)
response = response[-1]["content"]
st.write(response)
push_message("assistant", response)
clear_uploader() |