File size: 2,328 Bytes
109014c
9e2a95f
a20dfac
ecd63b4
c124df1
5894c9b
c124df1
ecd63b4
a818c02
0583c4b
e87746b
5894c9b
ce65c0f
e87746b
ecd63b4
9e2a95f
0583c4b
ecd63b4
c124df1
9e2a95f
c124df1
ecd63b4
 
9e2a95f
 
5894c9b
 
47c54d0
9e2a95f
 
 
 
 
 
 
eaab710
5894c9b
 
eaab710
5894c9b
9e2a95f
ecd63b4
 
 
9e2a95f
ecd63b4
 
9e2a95f
ecd63b4
9e2a95f
ecd63b4
 
9e2a95f
7aab4a8
5894c9b
 
7aab4a8
ecd63b4
 
 
9e2a95f
e87746b
 
9e2a95f
 
 
 
 
 
 
ecd63b4
c124df1
5894c9b
9e2a95f
 
5894c9b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from time import time
import huggingface_hub
import streamlit as st
from config import config
from utils import get_assistant_message
from functioncall import ModelInference


@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Authenticate with Hugging Face and build the model wrapper.

    Wrapped in ``st.cache_resource`` so the (expensive) model load happens
    only once per Streamlit process and is shared across sessions.
    """
    huggingface_hub.login(token=config.hf_token, new_session=False)
    return ModelInference(chat_template=config.chat_template)


def get_response(prompt):
    """Run the function-calling agent on *prompt*.

    Returns the agent's result, or a human-readable error string when
    inference raises — keeps the UI alive instead of crashing.
    """
    try:
        result = llm.generate_function_call(
            prompt,
            config.chat_template,
            config.num_fewshot,
            config.max_depth,
        )
    except Exception as e:
        return f"An error occurred: {str(e)}"
    return result


def get_output(context, user_input):
    """Compose the final answer from the agent output and the user's question.

    *context* is the agent response injected into a YAML-defined system
    prompt; *user_input* is the original question. Returns the model's
    response, or an error string if anything in the pipeline raises.
    """
    try:
        config.status.update(label=":bulb: Preparing answer..")
        base_dir = os.path.dirname(os.path.abspath(__file__))
        schema_path = os.path.join(base_dir, 'prompt_assets', 'output_sys_prompt.yml')
        schema = llm.prompter.read_yaml_file(schema_path)
        # System prompt = formatted YAML template followed by the agent context.
        system_prompt = (
            llm.prompter.format_yaml_prompt(schema, dict())
            + f"Information:\n{context}"
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ]
        return llm.run_inference(messages)
    except Exception as e:
        return f"An error occurred: {str(e)}"


def main():
    """Streamlit entry point: text area -> agent call -> rendered answer."""
    st.title("LLM-ADE 9B Demo")

    user_text = st.text_area("Enter your text here:", value="", height=200)

    if not st.button("Generate"):
        return
    if not user_text:
        st.warning("Please enter some text to generate a response.")
        return

    with st.status("Generating response...") as status:
        # Share the live status widget with helpers through the config object.
        config.status = status
        agent_resp = get_response(user_text)
        st.write(get_output(agent_resp, user_text))
        config.status.update(label="Finished!", state="complete", expanded=True)


llm = init_llm()


def main_headless(prompt: str):
    """CLI path: run the full agent pipeline on *prompt* and print the result.

    Output is wrapped in ANSI blue escape codes, followed by the elapsed time.
    """
    started_at = time()
    answer = get_output(get_response(prompt), prompt)
    print("\033[94m" + answer + "\033[0m")
    print(f"Time taken: {time() - started_at:.2f}s\n" + "-" * 20)


if __name__ == "__main__":
    if config.headless:
        # Headless/CLI mode: Fire exposes main_headless as a command-line tool
        # (prompt passed as an argv argument). Imported lazily so the Streamlit
        # path never needs the dependency.
        import fire
        fire.Fire(main_headless)
    else:
        # Default: run the Streamlit UI.
        main()