File size: 2,235 Bytes
de1c7b8
fd872b2
8b384d6
 
3a830ca
 
5005937
3a830ca
5005937
 
 
 
 
 
 
 
 
 
3a830ca
 
 
 
de1c7b8
fd872b2
 
 
 
 
 
 
 
3a830ca
 
8b384d6
fd872b2
 
 
8b384d6
fd872b2
3a830ca
 
 
8b384d6
de1c7b8
379c3fe
de1c7b8
3a830ca
 
 
a0ab831
3a830ca
 
 
 
fd872b2
de1c7b8
8b384d6
abe7082
fd872b2
8b384d6
de1c7b8
 
77428fd
 
 
 
 
 
 
 
de1c7b8
379c3fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import streamlit as st
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor


# Build both text-generation pipelines; st.cache_resource keeps the returned
# objects so later script reruns reuse them instead of loading again.
@st.cache_resource(show_spinner="Loading Models...")
def load_models():
    """Create the two text-generation pipelines compared by this app.

    Returns:
        A ``(base_pipe, irai_pipe)`` tuple: the TinyLlama base checkpoint
        first, the InvestmentResearchAI LLM-ADE checkpoint second. Both are
        built with ``max_length=512``.
    """
    model_ids = (
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        "InvestmentResearchAI/LLM-ADE_tiny-v0.001",
    )
    # Unpacking the generator builds the pipelines in the listed order.
    base, irai = (
        pipeline("text-generation", model=model_id, max_length=512)
        for model_id in model_ids
    )
    return base, irai


# Module-level handles to the two pipelines returned by load_models().
base_pipe, irai_pipe = load_models()

# Chat-style prompt wrapped around the user's text before it is sent to the
# IRAI pipeline; ``{input_text}`` is filled in with str.format().
prompt_template = "".join(
    (
        "<|system|>\n",
        "You are a friendly chatbot who always gives helpful, detailed, and polite answers.</s>\n",
        "<|user|>\n",
        "{input_text}</s>\n",
        "<|assistant|>\n",
    )
)

# Two workers: one per model, so both generations can be submitted together.
executor = ThreadPoolExecutor(max_workers=2)


def generate_base_response(input_text):
    """Run the base pipeline on ``input_text`` and return its generated text.

    The text is passed to ``base_pipe`` unchanged (no prompt template), and
    the ``"generated_text"`` field of the first result is returned as-is.
    """
    outputs = base_pipe(input_text)
    first_result = outputs[0]
    return first_result["generated_text"]


def generate_irai_response(input_text):
    """Return the IRAI model's reply to ``input_text`` without the prompt scaffold.

    The question is inserted into ``prompt_template`` and passed to
    ``irai_pipe``; the ``"generated_text"`` of the first result is then
    reduced to the assistant's reply.

    Args:
        input_text: The user's question, inserted verbatim into the prompt.

    Returns:
        The reply with surrounding whitespace stripped. Anything from a
        further ``<|assistant|>`` marker onwards in the generated text is
        dropped. If no marker is found at all, the whole generated text is
        returned (stripped) instead of raising ``IndexError``.
    """
    marker = "<|assistant|>"
    formatted_input = prompt_template.format(input_text=input_text)
    result = irai_pipe(formatted_input)[0]["generated_text"]

    if result.startswith(formatted_input):
        # The output echoes the prompt: cut it off by length, so a marker the
        # user typed inside input_text cannot be mistaken for the start of
        # the reply.
        completion = result[len(formatted_input):]
    else:
        # Prompt not echoed verbatim: take what follows the first marker, or
        # the full text when the marker is missing.
        _, found, tail = result.partition(marker)
        completion = tail if found else result

    # Keep only the first assistant turn if the model kept generating.
    return completion.split(marker, 1)[0].strip()


@st.cache_data(show_spinner="Generating responses...")
def _generate_response_cached(input_text):
    """Run both models on ``input_text`` and return ``(base_resp, irai_resp)``.

    Exceptions are deliberately NOT caught here: st.cache_data stores only
    values that are returned, so letting a failure propagate keeps it out of
    the cache and the same input can be retried later.
    """
    # Submit both jobs before waiting on either so they can run side by side.
    future_base = executor.submit(generate_base_response, input_text)
    future_irai = executor.submit(generate_irai_response, input_text)
    # Remove the first occurrence of the prompt from the base model's output.
    base_resp = future_base.result().replace(input_text, "", 1).strip()
    irai_resp = future_irai.result()
    return base_resp, irai_resp


def generate_response(input_text):
    """Generate the base and IRAI responses for ``input_text``.

    Successful results are cached per ``input_text`` by the private cached
    helper. Failures are reported with ``st.error`` and are not cached.

    Args:
        input_text: The user's question.

    Returns:
        ``(base_resp, irai_resp)`` on success, or ``(None, None)`` if either
        generation raised.
    """
    try:
        return _generate_response_cached(input_text)
    except Exception as e:
        # Top-level UI boundary: show the problem instead of crashing the app.
        st.error(f"An error occurred: {e}")
        return None, None


# --- Page layout: one question box, one button, two side-by-side answers ---
st.title("Base Model vs IRAI LLM-ADE")
user_input = st.text_area("Enter a financial question:", "")

if st.button("Generate"):
    # Treat whitespace-only input the same as empty input.
    if user_input.strip():
        base_response, irai_response = generate_response(user_input)
        # generate_response reports failures itself and returns (None, None);
        # in that case there is nothing to render.
        if base_response is not None and irai_response is not None:
            col1, col2 = st.columns(2)
            # Distinct, hidden labels: an empty label is discouraged by
            # Streamlit, and two text areas with identical arguments would
            # collide as duplicate widgets when both responses are equal.
            with col1:
                st.header("Base Model")
                st.text_area(
                    label="Base model response",
                    value=base_response,
                    height=300,
                    label_visibility="collapsed",
                )
            with col2:
                st.header("LLM-ADE Enhanced")
                st.text_area(
                    label="LLM-ADE enhanced response",
                    value=irai_response,
                    height=300,
                    label_visibility="collapsed",
                )
    else:
        st.warning("Please enter some text")