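"""Streamlit demo app for the LLM-ADE 9B model, served locally with vLLM."""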
import os
import huggingface_hub
import streamlit as st
from vllm import LLM, SamplingParams

sys_msg = "You are a super intelligent automated financial advisor created by IRAI. Your purpose is to make use of your deep and broad understanding of finance by helping answer user questions about finance accurately, truthfully, and concisely."

@st.cache_resource(show_spinner=False)
def init_llm():
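    """Log in to Hugging Face and load the model; st.cache_resource keeps one instance across reruns."""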
    huggingface_hub.login(token=os.getenv("HF_TOKEN"))
    llm = LLM(model="InvestmentResearchAI/LLM-ADE-dev")
    tok = llm.get_tokenizer()
    tok.eos_token = '<|im_end|>'  # Override so generation stops at the ChatML end-of-turn marker
    return llm

def get_response(prompt):
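    """Wrap the user prompt in a system/user conversation and return the model's reply."""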
    try:
        convo = [
            {"role": "system", "content": sys_msg},
            {"role": "user", "content": prompt},
        ]
        prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False, add_generation_prompt=True)]
        # 128009 is Llama-3's <|eot_id|>, included here as an extra stop token.
        sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=500, stop_token_ids=[128009])
        outputs = llm.generate(prompts, sampling_params)
        return outputs[0].outputs[0].text  # Single prompt submitted, so take the first (only) completion
    except Exception as e:
        return f"An error occurred: {str(e)}"

def main():
    st.title("LLM-ADE 9B Demo")
    
    input_text = st.text_area("Enter your text here:", value="", height=200)
    if st.button("Generate"):
        if input_text:
            with st.spinner('Generating response...'):
                response_text = get_response(input_text)
                st.write(response_text)
        else:
            st.warning("Please enter some text to generate a response.")

llm = init_llm()  # Loaded once at import time; cached by Streamlit across reruns

if __name__ == "__main__":
    main()