# Hugging Face Space status: Sleeping
import os | |
import huggingface_hub | |
import streamlit as st | |
from config import config | |
from utils import get_assistant_message | |
from functioncall import ModelInference | |
from prompter import PromptManager | |
# Startup banner; flushed immediately so it shows up in the container log
# even before any buffered output is written.
print("Why, hello there!", end="\n", flush=True)
@st.cache_resource
def init_llm():
    """Log in to the Hugging Face Hub and build the ModelInference wrapper.

    Streamlit re-executes this whole script on every widget interaction, and
    the module-level ``llm = init_llm()`` therefore used to re-login and
    re-load the model on each click. ``st.cache_resource`` makes the first
    result be reused across reruns (per Streamlit's docs the cached object is
    shared by all sessions, so it must be safe to share).

    Returns:
        ModelInference: configured with ``config.chat_template``.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    return ModelInference(chat_template=config.chat_template)
def get_response(prompt):
    """Run the function-calling step of the module-level ``llm`` on ``prompt``.

    Calls ``llm.generate_function_call`` with the chat template, few-shot
    count and max depth taken from ``config``.

    Returns:
        Whatever ``generate_function_call`` returns, or the string
        ``"An error occurred: <message>"`` if it raises.
    """
    try:
        agent_output = llm.generate_function_call(
            prompt,
            config.chat_template,
            config.num_fewshot,
            config.max_depth,
        )
    except Exception as err:
        # Report failures as text so the UI / CLI can display them.
        return f"An error occurred: {err}"
    return agent_output
def get_output(context, user_input):
    """Ask the model for a final answer to ``user_input`` given ``context``.

    Builds a system prompt from ``prompt_assets/output_sys_prompt.yml``
    followed by ``"Information:\\n<context>"``, sends it with the user's
    message through ``llm.run_inference``, and extracts the assistant reply
    with ``get_assistant_message``.

    Args:
        context: supporting text for the answer (in this file, the result of
            ``get_response``).
        user_input: the user's original prompt.

    Returns:
        The assistant message, or ``"An error occurred: <message>"`` if any
        step raises.
    """
    try:
        # config.status is assigned only inside main()'s st.status block; in
        # headless mode it may be missing or None, which previously made this
        # function fail with an AttributeError. Update it only if present.
        status = getattr(config, "status", None)
        if status is not None:
            status.update(label=":bulb: Preparing answer..")
        prompt_schema = llm.prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        # NOTE(review): no separator is inserted before "Information:"; confirm
        # format_yaml_prompt's output ends with a newline.
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict()) + \
            f"Information:\n{context}"
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return get_assistant_message(response, config.chat_template, llm.tokenizer.eos_token)
    except Exception as e:
        return f"An error occurred: {str(e)}"
def main():
    """Render the Streamlit demo page and answer the prompt on "Generate".

    Shows a title and a text area. When the button is pressed with non-empty
    text, runs ``get_response`` and then ``get_output`` inside an
    ``st.status`` container and writes the answer to the page; with an empty
    text area it shows a warning instead.
    """
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)

    if not st.button("Generate"):
        return  # Nothing to do until the user clicks the button.

    if not input_text:
        st.warning("Please enter some text to generate a response.")
        return

    with st.status('Generating response...') as progress:
        # Publish the status container through config so get_output() (and
        # presumably other modules reading config.status) can update its label.
        config.status = progress
        agent_resp = get_response(input_text)
        answer = get_output(agent_resp, input_text)
        st.write(answer)
        config.status.update(label="Finished!", state="complete", expanded=True)
# Module-level model handle. get_response() and get_output() read `llm` as a
# global, so it must be created before main() / main_headless() run (both are
# only invoked from the __main__ guard at the end of the file).
llm = init_llm()
class _NullStatus:
    """No-op stand-in for Streamlit's status container when running headless."""

    def update(self, *args, **kwargs):
        """Ignore status updates; there is no UI to display them in."""


def main_headless():
    """Interactive terminal loop: read a prompt, print the answer in blue.

    Runs until end-of-input (Ctrl-D) or Ctrl-C at the prompt, which now exit
    cleanly instead of raising a traceback. Blank prompts are skipped, matching
    main(), which refuses to generate for empty input.
    """
    # main() sets config.status from st.status, but nothing does so in headless
    # mode while get_output() calls config.status.update(...). Install a no-op
    # status only if none is configured, so an existing default is kept.
    if getattr(config, "status", None) is None:
        config.status = _NullStatus()
    while True:
        try:
            input_text = input("Enter your text here: ")
        except (EOFError, KeyboardInterrupt):
            print()  # Finish the prompt line before exiting.
            break
        if not input_text.strip():
            continue
        agent_resp = get_response(input_text)
        # ANSI escape: print the answer in blue, then reset the colour.
        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')
if __name__ == "__main__":
    # Report only whether the secret is present: printing its value (as before)
    # leaks it into stdout and the Space's logs.
    print(
        "Test env vars: TEST_SECRET is",
        "set" if "TEST_SECRET" in os.environ else "not set",
    )
    if config.headless:
        main_headless()
    else:
        main()