File size: 2,068 Bytes
109014c
a20dfac
ecd63b4
c124df1
5894c9b
c124df1
5894c9b
ecd63b4
a818c02
0583c4b
e87746b
5894c9b
ce65c0f
e87746b
ecd63b4
0583c4b
ecd63b4
c124df1
 
 
 
 
 
ecd63b4
 
5894c9b
 
 
 
 
 
eaab710
5894c9b
 
eaab710
5894c9b
 
ecd63b4
 
 
 
 
 
 
954e857
ecd63b4
 
 
5894c9b
 
ecd63b4
 
 
e87746b
 
c124df1
 
5894c9b
 
 
ecd63b4
c124df1
5894c9b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import huggingface_hub
import streamlit as st
from config import config
from utils import get_assistant_message
from functioncall import ModelInference
from prompter import PromptManager


@st.cache_resource(show_spinner="Loading model..")
def init_llm():
    """Authenticate with the Hugging Face Hub and build the model wrapper.

    Logs in with ``config.hf_token`` and returns a ``ModelInference``
    configured with ``config.chat_template``.  The function is wrapped in
    ``st.cache_resource``, with "Loading model.." as the spinner text.
    """
    # NOTE(review): new_session=False presumably reuses an already-saved
    # token instead of requesting a new one -- confirm against the installed
    # huggingface_hub version.
    huggingface_hub.login(token=config.hf_token, new_session=False)
    return ModelInference(chat_template=config.chat_template)

def get_response(prompt):
    """Run the function-calling agent on ``prompt`` and return its result.

    Delegates to ``llm.generate_function_call`` with the chat template,
    few-shot count and maximum depth read from ``config``.  Any exception
    raised along the way is not propagated; it is returned as an
    ``"An error occurred: ..."`` string instead.
    """
    try:
        generate = llm.generate_function_call
        agent_result = generate(
            prompt,
            config.chat_template,
            config.num_fewshot,
            config.max_depth,
        )
    except Exception as exc:
        return f"An error occurred: {exc!s}"
    return agent_result
    
def get_output(context, user_input):
    """Produce the final assistant answer for ``user_input`` given ``context``.

    Builds a system prompt from ``prompt_assets/output_sys_prompt.yml`` with
    ``context`` appended under an "Information:" heading, runs a two-turn
    (system + user) conversation through ``llm.run_inference`` and returns
    the message extracted by ``get_assistant_message``.  Any exception is
    returned as an ``"An error occurred: ..."`` string rather than raised.
    """
    try:
        prompter = llm.prompter
        schema = prompter.read_yaml_file("prompt_assets/output_sys_prompt.yml")
        base_prompt = prompter.format_yaml_prompt(schema, {})
        sys_prompt = base_prompt + f"Information:\n{context}"

        system_turn = {"role": "system", "content": sys_prompt}
        user_turn = {"role": "user", "content": user_input}
        raw_response = llm.run_inference([system_turn, user_turn])

        return get_assistant_message(
            raw_response,
            config.chat_template,
            llm.tokenizer.eos_token,
        )
    except Exception as exc:
        return f"An error occurred: {exc!s}"

def main():
    """Render the Streamlit page: a text box, a Generate button and the answer.

    When the button is pressed with non-empty text, the text is passed to
    ``get_response`` and the result, together with the original text, to
    ``get_output``; the final string is written to the page.  Pressing the
    button with an empty box shows a warning instead.
    """
    st.title("LLM-ADE 9B Demo")

    input_text = st.text_area("Enter your text here:", value="", height=200)

    # Nothing to do until the user presses the button.
    if not st.button("Generate"):
        return

    if not input_text:
        st.warning("Please enter some text to generate a response.")
        return

    with st.spinner('Generating response...'):
        agent_resp = get_response(input_text)
        final_answer = get_output(agent_resp, input_text)
        st.write(final_answer)

# Module-level model handle read by get_response() and get_output().
# Created when this module is loaded; init_llm() is wrapped in
# st.cache_resource.
llm = init_llm()

def main_headless():
    """Run an interactive console loop around the agent.

    Repeatedly prompts on stdin, passes the text to ``get_response`` and then
    ``get_output``, and prints the final answer in blue (ANSI escape 94).

    Empty input is not sent to the model; the same hint used by the
    Streamlit front-end is printed and the prompt is shown again.  The loop
    ends, and the function returns ``None``, when stdin reaches end-of-file
    or the user presses Ctrl-C at the prompt.
    """
    while True:
        try:
            input_text = input("Enter your text here: ")
        except (EOFError, KeyboardInterrupt):
            # EOF (Ctrl-D / closed pipe) or Ctrl-C at the prompt is the only
            # way out of this loop; exit cleanly instead of with a traceback.
            print()
            return

        if not input_text:
            # Mirror main(): do not spend an inference call on empty input.
            print("Please enter some text to generate a response.")
            continue

        agent_resp = get_response(input_text)
        print('\033[94m' + get_output(agent_resp, input_text) + '\033[0m')

if __name__ == "__main__":
    # Dispatch to the console loop or the Streamlit page based on
    # config.headless.
    (main_headless if config.headless else main)()