import streamlit as st
import os
# os.environ['HF_HOME'] = '/scratch/sroydip1/cache/hf/'
os.environ["HUGGINGFACEHUB_API_TOKEN"] = st.secrets["HF_TOKEN"]
import pickle
import torch
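# NOTE: Conversation and the "conversational" pipeline task were deprecated and
# later removed from transformers, so this import assumes an older release
# (roughly transformers < 4.42) where they are still available.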
from transformers import Conversation, pipeline, AutoTokenizer, AutoModelForCausalLM
from upload import get_file, upload_file
from utils import clear_uploader, undo, restart


TOKEN = st.secrets["HF_TOKEN"]  # Hugging Face access token, needed for gated models such as Llama-2

# Session-state keys that get serialized when a conversation is shared.
share_keys = ["messages", "model_name"]
MODELS = [
    "meta-llama/Llama-2-7b-chat-hf",
    "mistralai/Mistral-7B-Instruct-v0.2",
    # "google/flan-t5-small",
    # "google/flan-t5-base",
    # "google/flan-t5-large",
    # "google/flan-t5-xl",
    # "google/flan-t5-xxl",
]
default_model = MODELS[0]

st.set_page_config(
    page_title="LLM",
    page_icon="πŸ“š",
)

if "model_name" not in st.session_state:
    st.session_state.model_name = default_model
    

@st.cache_resource
def get_pipeline(model_name):
    # Load the tokenizer and model for the selected checkpoint and wrap them in a
    # conversational pipeline; st.cache_resource keeps one pipeline per model.
    # model_name = "gpt2-medium"  # lightweight stand-in, useful for local testing
    device = 0 if torch.cuda.is_available() else -1
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
    model = AutoModelForCausalLM.from_pretrained(model_name, token=TOKEN)
    chatbot = pipeline("conversational", model=model, tokenizer=tokenizer, device=device)
    return chatbot

chatbot = get_pipeline(st.session_state.model_name)

if "messages" not in st.session_state:
    st.session_state.messages = []

# Restore a previously shared conversation when the page is opened via a share
# link (?id=<file id>); st.query_params requires a recent Streamlit release.
if len(st.session_state.messages) == 0 and "id" in st.query_params:
    with st.spinner("Loading chat..."):
        id = st.query_params["id"]
        data = get_file(id)
        obj = pickle.loads(data)
        for k, v in obj.items():
            st.session_state[k] = v


def share():
    # Pickle the shareable parts of the session (see share_keys), upload the blob,
    # and show a URL that restores this conversation via the ?id= query parameter.
    obj = {}
    for k in share_keys:
        if k in st.session_state:
            obj[k] = st.session_state[k]
    data = pickle.dumps(obj)
    id = upload_file(data)
    url = f"https://umbc-nlp-chat-llm.hf.space/?id={id}"
    st.markdown(f"[share](/?id={id})")
    st.success(f"Share URL: {url}")

with st.sidebar:
    st.title(":blue[LLM Only]")

    st.subheader("Model")
    model_name = st.selectbox("Model", MODELS, key="model_name")

    if st.button("Share", use_container_width=True):
        share()

    cols = st.columns(2)
    with cols[0]:
        if st.button("Restart", type="primary", use_container_width=True):
            restart()
    
    with cols[1]:
        if st.button("Undo", use_container_width=True):
            undo()

    append = st.checkbox("Append to previous message", value=False)


for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def push_message(role, content):
    message = {"role": role, "content": content}
    st.session_state.messages.append(message)
    return message

if prompt := st.chat_input("Type a message", key="chat_input"):
    push_message("user", prompt)
    with st.chat_message("user"):
        st.markdown(prompt)

    # With "Append to previous message" checked, the user turn is stored without
    # querying the model, so several inputs can be combined before generating.
    if not append:
        with st.chat_message("assistant"):
            # Rebuild the full conversation history for the pipeline.
            chat = Conversation()
            for m in st.session_state.messages:
                chat.add_message(m)
            with st.spinner("Generating response..."):
                response = chatbot(chat)
                response = response[-1]["content"]
                st.write(response)

        push_message("assistant", response)
    clear_uploader()