File size: 1,383 Bytes
109014c
ecd63b4
 
 
e88bc69
ecd63b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import streamlit as st
from vllm import LLM, SamplingParams

# Load the model once at module import (Streamlit reruns the script per
# interaction; vLLM's LLM object is heavyweight, so it lives at top level).
# NOTE(review): vLLM's LLM constructor is not documented to take a `token`
# kwarg — HF auth is normally picked up from the HF_TOKEN env var; confirm
# this argument is accepted by the installed vLLM version.
llm = LLM(model="InvestmentResearchAI/LLM-ADE-small-v0.1.0", token=os.getenv("HF_TOKEN"))
tok = llm.get_tokenizer()
# Use the Llama-3 end-of-turn marker as EOS so generation stops at the end
# of the assistant turn instead of running until the default EOS token.
tok.eos_token = '<|eot_id|>' # Override to use turns


# Llama-3 chat prompt template: system turn, then the user's message, then
# an open assistant header for the model to complete.
template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful financial assistant that answers the user as accurately, truthfully, and concisely as possible.<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""


def get_response(prompt: str) -> str:
    """Generate a model completion for *prompt*.

    The raw user text is wrapped in the module-level Llama-3 chat
    ``template`` and submitted to the module-level ``llm`` engine.

    Args:
        prompt: Raw user text (inserted as the user turn of the template).

    Returns:
        The generated assistant text, or a human-readable
        ``"An error occurred: ..."`` string if generation failed.
    """
    try:
        prompts = [template.format(user_message=prompt)]
        # Fix: vLLM's SamplingParams defaults max_tokens to 16, which
        # silently truncated answers; allow a full-length reply.
        sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=512)
        outputs = llm.generate(prompts, sampling_params)
        # Fix: the original for-loop returned None when `outputs` was
        # empty; report that explicitly instead of rendering nothing.
        if not outputs or not outputs[0].outputs:
            return "An error occurred: model returned no output."
        # Exactly one prompt is submitted, so take its first completion.
        return outputs[0].outputs[0].text
    except Exception as e:  # UI boundary: surface the error as display text.
        return f"An error occurred: {str(e)}"

def main():
    """Render the Streamlit UI: a title, a text box, and a Generate button."""
    st.title("LLM-ADE 9B Demo")

    user_text = st.text_area("Enter your text here:", value="", height=200)
    if st.button("Generate"):
        # Guard clause: nothing to send if the box is empty.
        if not user_text:
            st.warning("Please enter some text to generate a response.")
        else:
            with st.spinner('Generating response...'):
                st.write(get_response(user_text))

if __name__ == "__main__":
    main()