File size: 2,938 Bytes
41d24d2
 
bd53d8b
 
 
 
 
5e7cd8f
bd53d8b
 
5e7cd8f
 
 
 
 
bd53d8b
 
41d24d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from PIL import Image
import streamlit as st

try:
    from llmlib.runtime import filled_model_registry
except ImportError:
    # llmlib is not importable yet: install the local package (and flash-attn)
    # into the running environment, then retry the import.
    import importlib
    import os
    import subprocess
    import sys

    # Invoke pip through the running interpreter so packages land in the
    # environment executing this app, not whatever `pip` is first on PATH.
    _pip_install = [sys.executable, "-m", "pip", "install"]
    # llmlib is required; fail loudly with pip's error instead of a later ImportError.
    subprocess.run([*_pip_install, "-e", "./llmlib"], check=True)
    # flash-attn stays best-effort (its exit status is ignored, as before).
    # Extend the current environment rather than replacing it, so PATH, HOME,
    # CUDA settings, etc. remain visible to pip and the build.
    subprocess.run(
        [*_pip_install, "flash-attn", "--no-build-isolation"],
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        check=False,
    )
    # Let the import system notice packages installed after startup.
    # NOTE(review): editable installs registered via a .pth file are normally
    # processed only at interpreter start; if this retry still fails, a restart
    # may be required — confirm in the deployment environment.
    importlib.invalidate_caches()
    from llmlib.runtime import filled_model_registry

from llmlib.model_registry import ModelEntry, ModelRegistry
from llmlib.base_llm import Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from login_mask_simple import check_password

# Configure the page before any other Streamlit call. Streamlit has required
# set_page_config to be the first Streamlit command, and check_password()
# presumably renders a login form (TODO confirm), so it must come after this.
st.set_page_config(page_title="LLM App", layout="wide")

# Halt this script run for unauthenticated sessions, before the registry is built.
if not check_password():
    st.stop()

st.title("LLM App")


model_registry: ModelRegistry = filled_model_registry()


@st.cache_resource()
def create_model_bundler() -> Bundler:
    """Build the Bundler that serves every model in the module-level registry.

    Wrapped in ``st.cache_resource`` so one instance is reused across reruns
    instead of being reconstructed on every script execution.
    """
    bundler = Bundler(registry=model_registry)
    return bundler


def display_warnings(r: ModelRegistry, model_id: str) -> None:
    """Show the registry's warnings for ``model_id`` in one Streamlit warning box.

    Renders nothing when the entry carries no warnings.
    """
    entry: ModelEntry = r.get_entry(model_id)
    notes = entry.warnings
    if len(notes) == 0:
        return
    # "  \n" is a Markdown hard line break, so each warning gets its own line.
    st.warning("  \n".join(notes))


# Top row: model picker on the left, optional image upload on the right.
cs = st.columns(2)
model_col, upload_col = cs

with model_col:
    model1_id: str = st.selectbox("Select model", model_registry.all_model_ids())
    display_warnings(model_registry, model1_id)

with upload_col:
    # The uploader's widget key is kept in session state; the key is incremented
    # after an image is sent, which gives the uploader a fresh, empty state.
    if "img-key" not in st.session_state:
        st.session_state["img-key"] = 0
    image = st.file_uploader("Include an image", key=st.session_state["img-key"])


def _reset_chat_history() -> None:
    """Reset both conversation histories (each a list[Message]) to empty lists."""
    # NOTE(review): messages2 is never read in this file — possibly leftover state.
    st.session_state.messages1 = []
    st.session_state.messages2 = []


if "messages1" not in st.session_state:
    _reset_chat_history()

if st.button("Restart chat"):
    _reset_chat_history()


def render_messages(msgs: list[Message]) -> None:
    """Draw each message of a conversation, in list order, via ``render_message``."""
    for message in msgs:
        render_message(message)


def render_message(msg: Message) -> None:
    """Render one chat bubble for ``msg``: its attached image (if any), then its Markdown text."""
    has_image = msg.img_name is not None
    with st.chat_message(msg.role):
        if has_image:
            render_img(msg)
        st.markdown(msg.msg)


def render_img(msg: Message) -> None:
    """Display the message's image at 400 px wide, captioned with its file name."""
    picture, label = msg.img, msg.img_name
    st.image(picture, caption=label, width=400)


# Replay the stored conversation; Streamlit re-executes the script on every
# interaction, so the history must be redrawn each time.
render_messages(st.session_state.messages1)

prompt = st.chat_input("Type here")
if prompt is None:
    # No new user input on this rerun.
    st.stop()

pil_image = None
if image is not None:
    try:
        pil_image = Image.open(image)
        # Image.open is lazy; decode now so an invalid or truncated upload is
        # reported here rather than failing later during rendering or generation.
        pil_image.load()
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot identify, and OSError for truncated/corrupt image data.
        st.error(
            f"Could not read '{image.name}' as an image. "
            "Please upload a valid image file."
        )
        st.stop()
    # Bump the uploader's widget key so it starts empty on the next rerun and
    # the same image is not attached to the following message too.
    st.session_state["img-key"] += 1

msg = Message(
    role="user",
    msg=prompt,
    img_name=image.name if image is not None else None,
    img=pil_image,
)

st.session_state.messages1.append(msg)
render_message(msg)

model_bundler: Bundler = create_model_bundler()

with st.spinner("Initializing model..."):
    model_bundler.set_model_on_gpu(model_id=model1_id)

with st.spinner("Generating response..."):
    # The full history (including the new user message) is sent to the model.
    req = BundlerRequest(model_id=model1_id, msgs=st.session_state.messages1)
    response = model_bundler.get_response(req)
msg = Message(role="assistant", msg=response)
st.session_state.messages1.append(msg)
render_message(msg)