LLM-ADE-dev / src/app.py
import os
from time import time

import huggingface_hub
import streamlit as st

from config import config
from functioncall import ModelInference
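
# Cached by Streamlit: the Hub login and model load run once per process
# rather than on every rerun of the script.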
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    # Authenticate with the Hugging Face Hub, then build the inference wrapper.
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm


def get_response(prompt):
    # Run the agent's function-calling pass over the user prompt.
    try:
        return llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"


def get_output(context, user_input):
    # Combine the agent's intermediate answer with the output system prompt
    # and run a final inference pass to produce the user-facing response.
    try:
        config.status.update(label=":bulb: Preparing answer..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, 'prompt_assets', 'output_sys_prompt.yml')
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"


def main():
    # Streamlit UI: take a block of text and display the model's answer.
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)
    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                config.status = status
                agent_resp = get_response(input_text)
                st.write(get_output(agent_resp, input_text))
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


# Module-level model handle, shared by get_response/get_output as a global.
llm = init_llm()


def main_headless(prompt: str):
    # CLI mode: answer a single prompt, print it in blue, and report timing.
    start = time()
    agent_resp = get_response(prompt)
    print("\033[94m" + get_output(agent_resp, prompt) + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)


if __name__ == "__main__":
    if config.headless:
        import fire
        fire.Fire(main_headless)
    else:
        main()
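
# Usage sketch (assumed invocations, not part of the original file):
#   UI mode:       streamlit run src/app.py
#   Headless mode: with config.headless set, `python src/app.py "your prompt"`;
#                  fire forwards the CLI argument to main_headless(prompt).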