File size: 2,223 Bytes
de1c7b8
fd872b2
8b384d6
 
3a830ca
 
5005937
3a830ca
5005937
 
 
 
 
 
 
 
 
 
3a830ca
 
 
 
de1c7b8
fd872b2
 
 
 
 
 
 
 
3a830ca
 
8b384d6
fd872b2
 
 
8b384d6
fd872b2
3a830ca
 
 
8b384d6
de1c7b8
77428fd
de1c7b8
3a830ca
 
 
 
 
 
 
 
fd872b2
de1c7b8
8b384d6
 
fd872b2
8b384d6
de1c7b8
 
77428fd
 
 
 
 
 
 
 
de1c7b8
fd872b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor


# Build both text-generation pipelines; Streamlit's resource cache keeps the
# returned objects so later script reruns reuse them instead of reloading.
@st.cache_resource(show_spinner="Loading Models...")
def load_models():
    """Create the two text-generation pipelines compared by this app.

    Returns:
        A ``(base_pipe, irai_pipe)`` tuple: the TinyLlama base checkpoint
        first, the InvestmentResearchAI LLM-ADE checkpoint second. Both are
        configured with ``max_length=512``.
    """
    checkpoints = (
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "InvestmentResearchAI/LLM-ADE_tiny-v0.001",
    )
    # Same task and length limit for both models; only the checkpoint differs.
    loaded = [
        pipeline("text-generation", model=checkpoint, max_length=512)
        for checkpoint in checkpoints
    ]
    return loaded[0], loaded[1]


# Pipelines shared by the generation helpers below.
base_pipe, irai_pipe = load_models()

# Chat-style prompt used for the LLM-ADE model; ``{input_text}`` is filled in
# with the user's question via ``str.format``.
prompt_template = "".join(
    [
        "<|system|>\n",
        "You are a friendly chatbot who always gives helpful, detailed, and polite answers.</s>\n",
        "<|user|>\n",
        "{input_text}</s>\n",
        "<|assistant|>\n",
    ]
)

# Two workers: one per model, so both generations can be submitted at once.
executor = ThreadPoolExecutor(max_workers=2)


def generate_base_response(input_text):
    """Run the base pipeline on the raw input text.

    Args:
        input_text: Text passed to ``base_pipe`` as-is (no prompt template).

    Returns:
        The ``"generated_text"`` field of the first result from ``base_pipe``.
    """
    outputs = base_pipe(input_text)
    first_output = outputs[0]
    return first_output["generated_text"]


def generate_irai_response(input_text):
    """Return the LLM-ADE model's reply to ``input_text``.

    The text is inserted into ``prompt_template``, sent to ``irai_pipe``, and
    only the assistant's reply is returned.

    Args:
        input_text: The user's question, inserted verbatim into the prompt.

    Returns:
        The first assistant turn of the generated text, with surrounding
        whitespace stripped. If the output contains no assistant marker and
        does not echo the prompt, the whole output is returned (stripped)
        instead of raising ``IndexError``.
    """
    marker = "<|assistant|>"
    formatted_input = prompt_template.format(input_text=input_text)
    result = irai_pipe(formatted_input)[0]["generated_text"]

    if result.startswith(formatted_input):
        # Remove the echoed prompt by length rather than by searching for the
        # marker, so a marker typed by the user inside their question cannot
        # be mistaken for the start of the model's reply.
        reply = result[len(formatted_input):]
    else:
        # Prompt was not echoed verbatim: take what follows the first marker,
        # or the whole output when the marker is absent.
        _, found, tail = result.partition(marker)
        reply = tail if found else result

    # Keep only the first assistant turn if the model kept generating more.
    return reply.split(marker, 1)[0].strip()


@st.cache_data(show_spinner="Generating responses...")
def _generate_cached_responses(input_text):
    """Run both models on ``input_text`` in parallel and return their outputs.

    Exceptions are deliberately allowed to propagate out of this cached
    function so that only successful results are stored in the cache.
    """
    future_base = executor.submit(generate_base_response, input_text)
    future_irai = executor.submit(generate_irai_response, input_text)
    return future_base.result(), future_irai.result()


def generate_response(input_text):
    """Generate the base and LLM-ADE responses for ``input_text``.

    Both generations are submitted to the shared executor and results are
    cached per input text. On failure the error is shown with ``st.error``
    and nothing is cached, so retrying the same input runs the models again.

    Args:
        input_text: The user's question.

    Returns:
        ``(base_response, irai_response)`` on success, or ``(None, None)``
        if either generation raised.
    """
    try:
        return _generate_cached_responses(input_text)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"An error occurred: {e}")
        return None, None


# Page layout: a question box, a Generate button, and the two models'
# answers shown side by side.
st.title("IRAI LLM-ADE vs Base Model")
user_input = st.text_area("Enter a financial question:", "")

if st.button("Generate"):
    # Whitespace-only input is treated the same as no input.
    if user_input.strip():
        base_response, irai_response = generate_response(user_input)
        # generate_response returns (None, None) after reporting an error;
        # in that case there is nothing to display.
        if base_response is not None and irai_response is not None:
            col1, col2 = st.columns(2)
            # The two output boxes need distinct labels: with identical
            # parameters (e.g. both models returning the same text) Streamlit
            # would treat them as duplicate widgets. The labels are hidden
            # because each column already has a header.
            with col1:
                st.header("Base Model")
                st.text_area(
                    label="Base model response",
                    value=base_response,
                    height=300,
                    label_visibility="collapsed",
                )
            with col2:
                st.header("LLM-ADE Enhanced")
                st.text_area(
                    label="LLM-ADE enhanced response",
                    value=irai_response,
                    height=300,
                    label_visibility="collapsed",
                )
    else:
        st.warning("Please enter some text to generate a response.")