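"""Streamlit demo app serving the LLM-ADE model through vLLM."""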
import os
import huggingface_hub
import streamlit as st
from vllm import LLM, SamplingParams


@st.cache_data(show_spinner=False)
def get_system_message():
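    """Return the default system prompt; cached with st.cache_data so it is built once."""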
    return """#Context:
You are an AI-based automated expert financial advisor named IRAI. You have a comprehensive understanding of finance and investing because you have been trained on an extensive dataset of financial news, analyst reports, books, company filings, earnings call transcripts, and finance websites.
#Objective:
Answer questions accurately and truthfully given the data you have been trained on. You do not have access to current market data; this will be available in the future.
#Style and tone:
Please answer in a friendly and engaging manner, representing a top female investment professional working at a leading investment bank.
#Audience:
The questions will be asked by top technology executives and CFOs of large fintech companies and successful startups.
#Response:
Answer concisely yet insightfully."""


@st.cache_resource(show_spinner=False)
def init_llm():
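    """Authenticate to Hugging Face (HF_TOKEN env var) and load the model in vLLM; st.cache_resource keeps a single instance."""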
    huggingface_hub.login(token=os.getenv("HF_TOKEN"))
    llm = LLM(model="InvestmentResearchAI/LLM-ADE-dev")
    tok = llm.get_tokenizer()
    tok.eos_token = '<|im_end|>'  # override EOS so generation stops at the end-of-turn marker
    return llm


def get_response(prompt, custom_sys_msg):
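    """Generate a single reply for `prompt` using the module-level vLLM instance."""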
    try:
        convo = [
            {"role": "system", "content": custom_sys_msg},
            {"role": "user", "content": prompt},
        ]
        # Render the conversation with the model's chat template; add_generation_prompt
        # appends the assistant header so the model begins a fresh turn.
        prompts = [llm.get_tokenizer().apply_chat_template(convo, tokenize=False, add_generation_prompt=True)]
        # 128009 is <|eot_id|> in Llama-3 tokenizers.
        sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=2000, stop_token_ids=[128009])
        outputs = llm.generate(prompts, sampling_params)
        # One prompt in, one RequestOutput back; return its first completion.
        return outputs[0].outputs[0].text
    except Exception as e:
        return f"An error occurred: {str(e)}"

def main():
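    """Render the UI: editable preprompt, input box, and a Generate button."""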
    st.title("LLM-ADE 9B Demo")
    
    # Retrieve the default system message
    sys_msg = get_system_message()
    
    # UI for editable preprompt
    user_modified_sys_msg = st.text_area("Preprompt:", value=sys_msg, height=200)

    input_text = st.text_area("Enter your text here:", value="", height=200)
    
    if st.button("Generate"):
        if input_text:
            with st.spinner('Generating response...'):
                response_text = get_response(input_text, user_modified_sys_msg)
                st.write(response_text)
        else:
            st.warning("Please enter some text to generate a response.")

# Load the model once at import time; st.cache_resource reuses the instance across Streamlit reruns.
llm = init_llm()

if __name__ == "__main__":
    main()