import os
from time import time
import huggingface_hub
import streamlit as st
from config import config
from functioncall import ModelInference


# Cache the loaded model across Streamlit reruns so weights are only loaded once per session.
@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    huggingface_hub.login(token=config.hf_token, new_session=False)
    llm = ModelInference(chat_template=config.chat_template)
    return llm


def function_agent(prompt):
    """Runs the function-calling (RAG) step on the prompt and returns its output."""
    try:
        return llm.generate_function_call(
            prompt, config.chat_template, config.num_fewshot, config.max_depth
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"


def output_agent(context, user_input):
    """Takes the output of the RAG and generates a final response."""
    try:
        config.status.update(label=":bulb: Preparing answer..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
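        # Build the system prompt from the YAML schema, then append the retrieved context.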
        sys_prompt = (
            llm.prompter.format_yaml_prompt(prompt_schema, dict())
            + f"Information:\n{context}"
        )
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_input},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"


def query_agent(prompt):
    """Modifies the prompt and runs inference on it."""
    try:
        config.status.update(label=":brain: Starting inference..")
        script_dir = os.path.dirname(os.path.abspath(__file__))
        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict())
        convo = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": prompt},
        ]
        response = llm.run_inference(convo)
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"


def get_response(input_text: str):
    """Generates the final response by chaining the function-calling and output agents."""
    agent_resp = function_agent(input_text)
    output = output_agent(agent_resp, input_text)
    return output


def main():
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    if st.button("Generate"):
        if input_text:
            with st.status("Generating response...") as status:
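                # Share the status widget via config so the agent functions can update its label.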
                config.status = status
                st.write(get_response(input_text))
                config.status.update(label="Finished!", state="complete", expanded=True)
        else:
            st.warning("Please enter some text to generate a response.")


def main_headless(prompt: str):
    start = time()
    # Print the response in blue (ANSI escape codes), followed by the elapsed time.
    print("\033[94m" + get_response(prompt) + "\033[0m")
    print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)


# Module-level model instance shared by all agent functions.
llm = init_llm()


if __name__ == "__main__":
    if config.headless:
        # Headless/CLI mode: python-fire exposes main_headless as a command-line entry point.
        import fire

        fire.Fire(main_headless)
    else:
        main()