WilliamGazeley committed
Commit 8b384d6 · Parent: fd872b2

Implement simple threading
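
This commit does two things visible in the diff below: the two pipelines are now constructed once at module load behind a st.spinner instead of inside the request handlers, and the per-request ProcessPoolExecutor is replaced with a ThreadPoolExecutor. Threads fit this workload: worker processes cannot cheaply share the models already loaded in the app process (each call would pay the loading cost again), while threads share them directly, and PyTorch largely releases the GIL inside its compute kernels, so the two generations can genuinely overlap. A minimal sketch of the fan-out/fan-in pattern the commit adopts, with slow_task as a hypothetical stand-in for a pipeline call:

import time
from concurrent.futures import ThreadPoolExecutor

def slow_task(name: str) -> str:
    # Stand-in for a model call; in app.py these are the two pipeline invocations.
    time.sleep(1)
    return f"result from {name}"

with ThreadPoolExecutor(max_workers=2) as executor:
    # Fan out: both tasks start immediately on separate threads.
    future_a = executor.submit(slow_task, "base model")
    future_b = executor.submit(slow_task, "IRAI model")
    # Fan in: result() blocks until each task finishes (~1s total here, not ~2s).
    print(future_a.result(), "|", future_b.result())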

Files changed (1): app.py (+35, -15)
app.py CHANGED
@@ -1,6 +1,19 @@
 import streamlit as st
 from transformers import pipeline
-from concurrent.futures import ProcessPoolExecutor
+from concurrent.futures import ThreadPoolExecutor
+
+# Load models at startup
+with st.spinner(text="Loading Models..."):
+    base_pipe = pipeline(
+        "text-generation",
+        model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
+        max_length=512,
+    )
+    irai_pipe = pipeline(
+        "text-generation",
+        model="InvestmentResearchAI/LLM-ADE_tiny-v0.001",
+        max_length=512,
+    )
 
 prompt_template = (
     "<|system|>\n"
@@ -10,16 +23,21 @@ prompt_template = (
     "<|assistant|>\n"
 )
 
+
 def generate_base_response(input_text):
-    base_pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", max_length=512)
     return base_pipe(input_text)[0]["generated_text"]
 
+
 def generate_irai_response(input_text):
-    irai_pipe = pipeline("text-generation", model="InvestmentResearchAI/LLM-ADE_tiny-v0.001", max_length=512)
-    return irai_pipe(prompt_template.format(input_text=input_text))[0]["generated_text"].split("<|assistant|>")[1].strip()
+    return (
+        irai_pipe(prompt_template.format(input_text=input_text))[0]["generated_text"]
+        .split("<|assistant|>")[1]
+        .strip()
+    )
+
 
 def generate_response(input_text):
-    with ProcessPoolExecutor() as executor:
+    with ThreadPoolExecutor() as executor:
         try:
             future_base = executor.submit(generate_base_response, input_text)
             future_irai = executor.submit(generate_irai_response, input_text)
@@ -30,18 +48,20 @@ def generate_response(input_text):
             return None, None
         return base_resp, irai_resp
 
-st.title("IRAI LLM-ADE Model vs Base Model")
+
+st.title("IRAI LLM-ADE vs Base Model")
 user_input = st.text_area("Enter a financial question:", "")
+
 if st.button("Generate"):
     if user_input:
-        base_response, irai_response = generate_response(user_input)
-        col1, col2 = st.columns(2)  # Updated to use `st.columns`
-        with col1:
-            st.header("Base Model Response")
-            st.text_area("", base_response, height=300)
-        with col2:
-            st.header("IRAI LLM-ADE Model Response")
-            st.text_area("", irai_response, height=300)
+        with st.spinner(text="Generating text..."):
+            base_response, irai_response = generate_response(user_input)
+        col1, col2 = st.columns(2)
+        with col1:
+            st.header("Base Model Response")
+            st.text_area("", base_response, height=300)
+        with col2:
+            st.header("IRAI LLM-ADE Model Response")
+            st.text_area("", irai_response, height=300)
     else:
         st.warning("Please enter some text to generate a response.")
-
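
The middle hunk elides the body of generate_response between the two submit calls and the return None, None fallback, so the diff does not show how the futures are consumed. A hedged sketch of the usual shape of that block, assuming the surrounding app.py context (generate_base_response, generate_irai_response, st); the timeout value and the st.error call are illustrative, not the repository's code:

def generate_response(input_text):
    with ThreadPoolExecutor() as executor:
        try:
            future_base = executor.submit(generate_base_response, input_text)
            future_irai = executor.submit(generate_irai_response, input_text)
            # result() blocks until each generation finishes.
            base_resp = future_base.result(timeout=300)  # hypothetical timeout
            irai_resp = future_irai.result(timeout=300)
        except Exception as exc:
            st.error(f"Generation failed: {exc}")  # illustrative error reporting
            return None, None
        return base_resp, irai_resp

One design note: leaving the with block calls executor.shutdown(wait=True), so even on the error path both worker threads are joined before the function returns.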