WilliamGazeley committed
Commit 5005937 · 1 Parent(s): 3a830ca

Added output caching

Files changed (1)
1. app.py +14 -14
app.py CHANGED
@@ -4,19 +4,18 @@ from concurrent.futures import ThreadPoolExecutor
 
 
 # Function to load models only once using Streamlit's cache mechanism
-@st.cache(allow_output_mutation=True)
+@st.cache_resource(show_spinner="Loading Models...")
 def load_models():
-    with st.spinner(text="Loading Models..."):
-        base_pipe = pipeline(
-            "text-generation",
-            model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
-            max_length=512,
-        )
-        irai_pipe = pipeline(
-            "text-generation",
-            model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
-            max_length=512,
-        )
+    base_pipe = pipeline(
+        "text-generation",
+        model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
+        max_length=512,
+    )
+    irai_pipe = pipeline(
+        "text-generation",
+        model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
+        max_length=512,
+    )
     return base_pipe, irai_pipe
 
 
@@ -43,6 +42,7 @@ def generate_irai_response(input_text):
     return result.split("<|assistant|>")[1].strip()
 
 
+@st.cache_data
 def generate_response(input_text):
     try:
         future_base = executor.submit(generate_base_response, input_text)
@@ -64,10 +64,10 @@ if st.button("Generate"):
         base_response, irai_response = generate_response(user_input)
         col1, col2 = st.columns(2)
         with col1:
-            st.header("Base Model Response")
+            st.header("Base Model")
             st.text_area(label="", value=base_response, height=300)
         with col2:
-            st.header("IRAI LLM-ADE Model Response")
+            st.header("LLM-ADE Enhanced")
             st.text_area(label="", value=irai_response, height=300)
     else:
         st.warning("Please enter some text to generate a response.")
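
Note on the caching change: `@st.cache(allow_output_mutation=True)` is deprecated in recent Streamlit releases, which split caching into `st.cache_resource` (for unserializable globals such as the two `pipeline` objects; its `show_spinner` string also replaces the old `st.spinner` block) and `st.cache_data` (which memoizes a function's return value per unique argument, so repeated identical prompts are served from cache; this is the output caching named in the commit message). A minimal sketch of the pattern, using a single hypothetical load_model/generate pair rather than the app's two pipelines:

import streamlit as st
from transformers import pipeline

# Cached once per process and shared across reruns; the show_spinner
# label is displayed while the model loads.
@st.cache_resource(show_spinner="Loading Models...")
def load_model():
    return pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        max_length=512,
    )

# Memoized by prompt: the same input returns the stored string
# without rerunning generation.
@st.cache_data
def generate(prompt: str) -> str:
    return load_model()(prompt)[0]["generated_text"]

One caveat: `st.cache_data` stores (pickles) return values, which works here because `generate_response` returns plain strings; the ThreadPoolExecutor futures it uses internally are never part of the cached value.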