import os
from time import time

import huggingface_hub
import streamlit as st

from config import config
from functioncall import ModelInference

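# NOTE: st.cache_resource caches the loaded model across Streamlit reruns and
# sessions, so the expensive model load below happens only once per process.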
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm


def function_agent(prompt):
    """Runs the function-calling agent on the user prompt."""
    try:
        return llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"


def output_agent(context, user_input):
    """Takes the output of the RAG and generates a final response."""
    try:
        config.status.update(label=":bulb: Preparing answer..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"


def query_agent(prompt):
    """Modifies the prompt and runs inference on it."""
    try:
        config.status.update(label=":brain: Starting inference..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict())
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"


def get_response(input_text: str):
    """Generates the final response: runs the function-calling agent, then passes its output to the output agent."""
    agent_resp = function_agent(input_text)
    output = output_agent(agent_resp, input_text)
    return output


def main():
    st.title("LLM-ADE 9B Demo")
    input_text = st.text_area("Enter your text here:", value="", height=200)
    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
                # Expose the status widget through the shared config so the
                # agent functions can update its label during generation.
                config.status = status
                st.write(get_response(input_text))
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


def main_headless(prompt: str):
    start = time()
    # Print the response in blue via ANSI escape codes for terminal readability.
    print("\033[94m" + get_response(prompt) + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
llm = init_llm()

if __name__ == "__main__":
    if config.headless:
        import fire

        fire.Fire(main_headless)
    else:
        main()
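
# Usage sketch (assuming this script is saved as app.py and that config.py
# defines hf_token, chat_template, num_fewshot, max_depth and headless):
#   streamlit run app.py           # interactive web demo
#   python app.py "your prompt"    # headless mode, if config.headless is set;
#                                  # python-fire maps the CLI argument to `prompt`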