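"""app.py: Streamlit demo app for the LLM-ADE 9B model.

User text is first sent through a function-calling agent (get_response); the
agent's output is then used as grounding context for a final answer
(get_output). With config.headless set, the app runs as a plain terminal loop
instead of the Streamlit UI. (Module docstring summarizing the code below.)
"""
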
import os

import huggingface_hub
import streamlit as st

from config import config
from utils import get_assistant_message
from functioncall import ModelInference
from prompter import PromptManager


@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Log in to the Hugging Face Hub and load the model (cached across reruns)."""
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm


def get_response(prompt):
    """Run the function-calling agent on the raw user prompt."""
    try:
        return llm.generate_function_call(
            prompt,
            config.chat_template,
            config.num_fewshot,
            config.max_depth,
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"


def get_output(context, user_input):
    """Generate the final answer, using the agent's output as grounding context."""
    try:
        prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
    except Exception as e:
        return f"An error occurred: {str(e)}"


def main():
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        if input_text:
            with st.spinner('Generating response...'):
                agent_resp = get_response(input_text)
                st.write(get_output(agent_resp, input_text))
        else:
            st.warning("Please enter some text to generate a response.")

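# Load the model at import time; st.cache_resource keeps a single cached
# instance across Streamlit reruns, and the module-level name is shared by the
# UI callbacks above and the headless loop below.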
llm = init_llm()


def main_headless():
    """Plain terminal loop for running without the Streamlit UI."""
    while True:
        input_text = input("Enter your text here: ")
        agent_resp = get_response(input_text)
        # Print the final answer in blue using ANSI escape codes.
        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')


if __name__ == "__main__":
    if config.headless:
        main_headless()
    else:
        main()
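
# Typical invocation (assumption: config.headless is toggled through config.py
# or the environment; the exact mechanism is not shown in this file):
#   streamlit run app.py    # browser demo (config.headless falsy)
#   python app.py           # terminal loop (config.headless truthy)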