|
import os |
|
import huggingface_hub |
|
import streamlit as st |
|
from config import config |
|
from utils import get_assistant_message |
|
from functioncall import ModelInference |
|
from prompter import PromptManager |
|
|
|
# Startup banner printed when this module executes.
# NOTE(review): under `streamlit run` the script is re-executed on each UI
# interaction, so this likely prints on every rerun — confirm that's wanted.
print("Why, hello there!", flush=True)
|
|
|
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Log in to the Hugging Face Hub and construct the model wrapper.

    Authenticates with ``config.hf_token`` (``new_session=False``), then
    builds a ``ModelInference`` configured with ``config.chat_template``.
    The ``st.cache_resource`` decorator caches the returned object, so
    repeated calls reuse it instead of reloading the model.

    Returns:
        The ``ModelInference`` instance.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    return ModelInference(chat_template=config.chat_template)
|
|
|
def get_response(prompt):
    """Run the function-calling agent over ``prompt``.

    Delegates to the module-level ``llm.generate_function_call``, passing
    the chat template, few-shot count and maximum depth from ``config``.

    Args:
        prompt: The user's text.

    Returns:
        Whatever ``generate_function_call`` returns. If anything raises,
        the string ``"An error occurred: <exception>"`` is returned instead,
        so errors are reported in-band rather than propagated.
    """
    try:
        generate = llm.generate_function_call
        return generate(prompt, config.chat_template, config.num_fewshot, config.max_depth)
    except Exception as exc:
        return f"An error occurred: {exc}"
|
|
|
def get_output(context, user_input):
    """Produce the final answer shown to the user.

    Loads the system prompt from ``prompt_assets/output_sys_prompt.yml``,
    appends ``context`` after an ``Information:`` heading, and runs one
    inference over a two-message conversation (system + user).

    Args:
        context: Text interpolated into the system prompt. Callers in this
            file pass the result of ``get_response``.
        user_input: The user's original text, sent as the user turn.

    Returns:
        The assistant message extracted by ``get_assistant_message``, or
        ``"An error occurred: <exception>"`` if any step raises.
    """
    try:
        schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        base_prompt = llm.prompter.format_yaml_prompt(schema, {})
        system_text = base_prompt + f"Information:\n{context}"

        conversation = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_input},
        ]
        raw_output = llm.run_inference(conversation)

        return get_assistant_message(raw_output, config.chat_template, llm.tokenizer.eos_token)
    except Exception as err:
        return f"An error occurred: {err}"
|
|
|
def main():
    """Render the Streamlit UI: a title, a text box and a Generate button.

    When the button is clicked with non-blank input, the text goes through
    ``get_response`` (agent step) and then ``get_output`` (answer step), and
    the result is written to the page. Empty or whitespace-only input shows
    a warning instead.
    """
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        # Treat whitespace-only text as empty: it carries no prompt and would
        # otherwise be sent to the model.
        if input_text and input_text.strip():
            with st.spinner('Generating response...'):
                agent_resp = get_response(input_text)
                st.write(get_output(agent_resp, input_text))
        else:
            st.warning("Please enter some text to generate a response.")
|
|
|
# Module-level model handle read by get_response() and get_output().
# This runs unconditionally when the module executes, i.e. in both the
# Streamlit and headless paths. init_llm() is wrapped in st.cache_resource,
# so later executions reuse the cached instance rather than rebuilding it.
llm = init_llm()
|
|
|
def main_headless():
    """Run a terminal read-eval-print loop instead of the Streamlit UI.

    Each line read from stdin goes through ``get_response`` and then
    ``get_output``, and the answer is printed in ANSI bright blue. Blank
    lines are skipped, matching ``main``'s refusal of empty input.
    End-of-input (Ctrl-D or closed stdin) and Ctrl-C at the prompt end the
    loop cleanly instead of escaping as a traceback.
    """
    while True:
        try:
            input_text = input("Enter your text here: ")
        except (EOFError, KeyboardInterrupt):
            print()  # terminate the dangling prompt line before exiting
            break

        if not input_text.strip():
            continue

        agent_resp = get_response(input_text)
        # \033[94m switches to bright blue; \033[0m resets terminal styling.
        # The f-string also tolerates a non-str result, where concatenation
        # would raise.
        print(f"\033[94m{get_output(agent_resp, input_text)}\033[0m")
|
|
|
if __name__ == "__main__":
    # Sanity-check that the environment is wired up, but report only whether
    # TEST_SECRET is present. Echoing its value would leak the secret into
    # stdout and any hosting logs that capture it.
    secret_state = "set" if "TEST_SECRET" in os.environ else "not set"
    print(f"Test env vars: TEST_SECRET is {secret_state}")

    # config.headless selects the terminal loop instead of the Streamlit UI.
    if config.headless:
        main_headless()
    else:
        main()
|
|