Spaces:
Sleeping
Sleeping
File size: 2,235 Bytes
de1c7b8 fd872b2 8b384d6 3a830ca 5005937 3a830ca 5005937 3a830ca de1c7b8 fd872b2 3a830ca 8b384d6 fd872b2 8b384d6 fd872b2 3a830ca 8b384d6 de1c7b8 379c3fe de1c7b8 3a830ca a0ab831 3a830ca fd872b2 de1c7b8 8b384d6 abe7082 fd872b2 8b384d6 de1c7b8 77428fd de1c7b8 379c3fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import streamlit as st
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor
@st.cache_resource(show_spinner="Loading Models...")
def load_models():
    """Build the two text-generation pipelines, cached by Streamlit.

    ``st.cache_resource`` makes the models load only once instead of on
    every script rerun.

    Returns:
        A ``(base_pipe, irai_pipe)`` tuple: the TinyLlama base checkpoint
        first, then the LLM-ADE fine-tune. Both pipelines are created with
        ``max_length=512``.
    """
    model_ids = (
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",  # base model
        "InvestmentResearchAI/LLM-ADE_tiny-v0.001",  # LLM-ADE model
    )
    # tuple() consumes the generator eagerly, so both pipelines are built
    # here, in the order listed above.
    return tuple(
        pipeline("text-generation", model=model_id, max_length=512)
        for model_id in model_ids
    )
# Fetch both pipelines (load_models is cached, so reruns reuse them).
base_pipe, irai_pipe = load_models()

# Chat-style prompt wrapped around the user's question for the LLM-ADE model.
# ``{input_text}`` is filled in with str.format; the text ends right after the
# assistant tag so the model's completion follows it.
prompt_template = "".join(
    (
        "<|system|>\n",
        "You are a friendly chatbot who always gives helpful, detailed, and polite answers.</s>\n",
        "<|user|>\n",
        "{input_text}</s>\n",
        "<|assistant|>\n",
    )
)

# Two workers so the base and LLM-ADE generations can run side by side.
executor = ThreadPoolExecutor(2)
def generate_base_response(input_text):
    """Feed *input_text* to the base model as-is and return its output text.

    The value is the ``generated_text`` field of the pipeline's first result.
    Any echo of the question is removed by the caller (generate_response).
    """
    outputs = base_pipe(input_text)
    first_output = outputs[0]
    return first_output["generated_text"]
def generate_irai_response(input_text):
    """Return the LLM-ADE model's answer to *input_text*, without the prompt.

    The question is wrapped in ``prompt_template`` before generation. The
    pipeline's ``generated_text`` is expected to start with that formatted
    prompt, so the prompt is removed by prefix. Locating the answer by
    splitting on the first ``<|assistant|>`` tag would return a piece of the
    prompt whenever the user's question itself contains that tag.

    Returns:
        The completion, cut at any further ``<|assistant|>`` tag and stripped
        of surrounding whitespace.
    """
    assistant_tag = "<|assistant|>"
    formatted_input = prompt_template.format(input_text=input_text)
    generated = irai_pipe(formatted_input)[0]["generated_text"]

    if generated.startswith(formatted_input):
        completion = generated[len(formatted_input):]
    else:
        # Defensive fallback in case the prompt was not echoed verbatim:
        # take what follows the first assistant tag, or the whole text when
        # no tag is present (rather than raising IndexError).
        _, tag, rest = generated.partition(assistant_tag)
        completion = rest if tag else generated

    # Stop at any additional assistant turn the model starts on its own.
    return completion.split(assistant_tag, 1)[0].strip()
@st.cache_data(show_spinner="Generating responses...")
def _generate_response_cached(input_text):
    """Run both models concurrently and return ``(base_resp, irai_resp)``.

    Exceptions are deliberately not caught here: a cached function that
    raises produces no cache entry, so a failed generation is retried next
    time instead of its error result being served from the cache.
    """
    future_base = executor.submit(generate_base_response, input_text)
    future_irai = executor.submit(generate_irai_response, input_text)
    # The base model output contains the raw question; drop its first
    # occurrence so only the continuation is shown.
    base_resp = future_base.result().replace(input_text, "", 1).strip()
    irai_resp = future_irai.result()
    return base_resp, irai_resp


def generate_response(input_text):
    """Generate both models' answers to *input_text* for side-by-side display.

    Successful results are cached per input text by _generate_response_cached.

    Returns:
        ``(base_resp, irai_resp)`` on success. If either generation fails, an
        ``st.error`` message is shown and ``(None, None)`` is returned; the
        failure is not cached.
    """
    try:
        return _generate_response_cached(input_text)
    except Exception as e:  # UI boundary: report the error instead of crashing
        st.error(f"An error occurred: {e}")
        return None, None
st.title("Base Model vs IRAI LLM-ADE")
user_input = st.text_area("Enter a financial question:", "")

if st.button("Generate"):
    # Whitespace-only input counts as empty.
    if user_input.strip():
        base_response, irai_response = generate_response(user_input)
        # generate_response returns (None, None) after reporting an error
        # itself; there is nothing to render in that case.
        if base_response is not None and irai_response is not None:
            col1, col2 = st.columns(2)
            with col1:
                st.header("Base Model")
                # Distinct, visually hidden labels: empty labels trigger a
                # Streamlit accessibility warning, and two text areas with
                # identical parameters would get the same widget ID.
                st.text_area(
                    label="Base Model response",
                    value=base_response,
                    height=300,
                    label_visibility="collapsed",
                )
            with col2:
                st.header("LLM-ADE Enhanced")
                st.text_area(
                    label="LLM-ADE Enhanced response",
                    value=irai_response,
                    height=300,
                    label_visibility="collapsed",
                )
    else:
        st.warning("Please enter some text")
|