tomasruiz's picture
Install flash_attn
5e7cd8f
raw
history blame
2.94 kB
from PIL import Image
import streamlit as st
try:
from llmlib.runtime import filled_model_registry
except ImportError:
import os
import subprocess
os.system("pip install -e ./llmlib")
subprocess.run(
"pip install flash-attn --no-build-isolation",
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
shell=True,
)
from llmlib.runtime import filled_model_registry
from llmlib.model_registry import ModelEntry, ModelRegistry
from llmlib.base_llm import Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from login_mask_simple import check_password
if not check_password():
st.stop()
st.set_page_config(page_title="LLM App", layout="wide")
st.title("LLM App")
model_registry: ModelRegistry = filled_model_registry()
@st.cache_resource()
def create_model_bundler() -> Bundler:
return Bundler(registry=model_registry)
def display_warnings(r: ModelRegistry, model_id: str) -> None:
e1: ModelEntry = r.get_entry(model_id)
if len(e1.warnings) > 0:
st.warning(" \n".join(e1.warnings))
cs = st.columns(2)
with cs[0]:
model1_id: str = st.selectbox("Select model", model_registry.all_model_ids())
display_warnings(model_registry, model1_id)
with cs[1]:
if "img-key" not in st.session_state:
st.session_state["img-key"] = 0
image = st.file_uploader("Include an image", key=st.session_state["img-key"])
if "messages1" not in st.session_state:
st.session_state.messages1 = [] # list[Message]
st.session_state.messages2 = [] # list[Message]
if st.button("Restart chat"):
st.session_state.messages1 = [] # list[Message]
st.session_state.messages2 = [] # list[Message]
def render_messages(msgs: list[Message]) -> None:
for msg in msgs:
render_message(msg)
def render_message(msg: Message):
with st.chat_message(msg.role):
if msg.img_name is not None:
render_img(msg)
st.markdown(msg.msg)
def render_img(msg: Message):
st.image(msg.img, caption=msg.img_name, width=400)
n_cols = 1
cs = st.columns(n_cols)
render_messages(st.session_state.messages1)
prompt = st.chat_input("Type here")
if prompt is None:
st.stop()
msg = Message(
role="user",
msg=prompt,
img_name=image.name if image is not None else None,
img=Image.open(image) if image is not None else None,
)
if image is not None:
st.session_state["img-key"] += 1
st.session_state.messages1.append(msg)
render_message(msg)
model_bundler: Bundler = create_model_bundler()
with st.spinner("Initializing model..."):
model_bundler.set_model_on_gpu(model_id=model1_id)
with st.spinner("Generating response..."):
req = BundlerRequest(model_id=model1_id, msgs=st.session_state.messages1)
response = model_bundler.get_response(req)
msg = Message(role="assistant", msg=response)
st.session_state.messages1.append(msg)
render_message(msg)