TheBobBob committed
Commit cc4a478 · verified · 1 Parent(s): 0d36764

Update app.py

Files changed (1)
  1. app.py +350 -293
app.py CHANGED
@@ -10,327 +10,384 @@ import libsbml
 import networkx as nx
 from pyvis.network import Network
 
-# Constants
-GITHUB_OWNER = "TheBobBob"
-GITHUB_REPO_CACHE = "BiomodelsCache"
-BIOMODELS_JSON_DB_PATH = "src/cached_biomodels.json"
-LOCAL_DOWNLOAD_DIR = tempfile.mkdtemp()
-
-def fetch_github_json():
-    url = f"https://api.github.com/repos/{GITHUB_OWNER}/{GITHUB_REPO_CACHE}/contents/{BIOMODELS_JSON_DB_PATH}"
-    headers = {"Accept": "application/vnd.github+json"}
-    response = requests.get(url, headers=headers)
-
-    if response.status_code == 200:
-        data = response.json()
-        if "download_url" in data:
-            file_url = data["download_url"]
-            json_response = requests.get(file_url)
-            return json_response.json()
-        else:
-            raise ValueError(f"Unable to fetch model DB from GitHub repository: {GITHUB_OWNER} - {GITHUB_REPO_CACHE}")
-    else:
-        raise ValueError(f"Unable to fetch model DB from GitHub repository: {GITHUB_OWNER} - {GITHUB_REPO_CACHE}")
-
-def search_models(search_str, cached_data):
-    query_text = search_str.strip().lower()
-    models = {}
-
-    for model_id, model_data in cached_data.items():
-        if 'name' in model_data:
-            name = model_data['name'].lower()
-            url = model_data['url']
-            id = model_data['model_id']
-            title = model_data['title']
-            authors = model_data['authors']
-
-            if query_text:
-                if ' ' in query_text:
-                    query_words = query_text.split(" ")
-                    if all(word in ' '.join([str(v).lower() for v in model_data.values()]) for word in query_words):
-                        models[model_id] = {
-                            'ID': model_id,
-                            'name': name,
-                            'url': url,
-                            'id': id,
-                            'title': title,
-                            'authors': authors,
-                        }
-                else:
-                    if query_text in ' '.join([str(v).lower() for v in model_data.values()]):
-                        models[model_id] = {
-                            'ID': model_id,
-                            'name': name,
-                            'url': url,
-                            'id': id,
-                            'title': title,
-                            'authors': authors,
-                        }
-
-    return models
-
-def download_model_file(model_url, model_id):
-    model_url = f"https://raw.githubusercontent.com/sys-bio/BiomodelsStore/main/biomodels/{model_id}/{model_id}_url.xml"
-    response = requests.get(model_url)
-
-    if response.status_code == 200:
-        os.makedirs(LOCAL_DOWNLOAD_DIR, exist_ok=True)
-        file_path = os.path.join(LOCAL_DOWNLOAD_DIR, f"{model_id}.xml")
-
-        with open(file_path, 'wb') as file:
-            file.write(response.content)
-
-        print(f"Model {model_id} downloaded successfully: {file_path}")
-        return file_path
-    else:
-        raise ValueError(f"Failed to download the model from {model_url}")
-
-def convert_sbml_to_antimony(sbml_file_path, antimony_file_path):
-    try:
-        r = te.loadSBMLModel(sbml_file_path)
-        antimony_str = r.getCurrentAntimony()
-
-        with open(antimony_file_path, 'w') as file:
-            file.write(antimony_str)
-
-        print(f"Successfully converted SBML to Antimony: {antimony_file_path}")
-
-    except Exception as e:
-        print(f"Error converting SBML to Antimony: {e}")
-
-def split_biomodels(antimony_file_path, GROQ_API_KEY, models):
-    text_splitter = CharacterTextSplitter(
-        separator="\n\n",
-        chunk_size=1000,
-        chunk_overlap=200,
-        length_function=len,
-        is_separator_regex=False,
-    )
-
-    directory_path = os.path.dirname(os.path.abspath(antimony_file_path))
-    if not os.path.isdir(directory_path):
-        print(f"Directory not found: {directory_path}")
-        return final_items
-
-    files = os.listdir(directory_path)
-    for file in files:
-        final_items = []
-        file_path = os.path.join(directory_path, file)
         try:
-            with open(file_path, 'r') as f:
-                file_content = f.read()
-                items = text_splitter.create_documents([file_content])
-                final_items.extend(items)
-                db, client = create_vector_db(final_items, GROQ_API_KEY, models)
-                break
         except Exception as e:
-            print(f"Error reading file {file_path}: {e}")
 
-    return db, client
-
-def create_vector_db(final_items, GROQ_API_KEY, models):
-    client = chromadb.Client()
-    collection_name = "BioModelsRAG"
 
-    db = client.get_or_create_collection(name=collection_name)
-
-    client = Groq(
-        api_key=GROQ_API_KEY,
-    )
-    for model_id, _ in models.items():
-
-        results = db.get(where = {"document" : model_id})
-
-        if not results['results']:
-            counter = 0
-            for item in final_items:
-                counter += 1
-                counter += " " + model_id
-
-                prompt = f"""
-                Summarize the following segment of Antimony in a clear and concise manner:
-                1. Provide a detailed summary using a reasonable number of words.
-                2. Maintain all original values and include any mathematical expressions or values in full.
-                3. Ensure that all variable names and their values are clearly presented.
-                4. Write the summary in paragraph format, putting an emphasis on clarity and completeness.
-
-                Segment of Antimony: {item}
-                """
-
-                chat_completion = client.chat.completions.create(
-                    messages=[
-                        {
                             "role": "user",
                             "content": prompt,
-                        }
-                    ],
-                    model="llama3-8b-8192",
-                )
-
-                if chat_completion.choices[0].message.content:
-                    db.upsert(
-                        ids = [counter],
-                        metadatas = [{"document" : model_id}],
-                        documents = [chat_completion.choices[0].message.content],
                     )
-
-    return db, client
-
-def generate_response(db, query_text, client, models):
-    query_results_final = ""
-
-    for model_id in models:
-        query_results = db.query(
-            query_texts=query_text,
-            n_results=5,
-            where={"document": models[model_id]},
-        )
-        best_recommendation = query_results['documents']
-        query_results_final += best_recommendation + "\n\n"
-
-    prompt_template = f"""
-
-    Using the context provided below, answer the following question. If the information is insufficient to answer the question, please state that clearly:
-    Context:
-    {query_results_final}
-    Instructions:
-    1. Cross-Reference: Use all provided context to define variables and identify any unknown entities.
-    2. Mathematical Calculations: Perform any necessary calculations based on the context and available data.
-    3. Consistency: Remember and incorporate previous responses if the question is related to earlier information.
-
-    Question:
-    {query_text}
-
-    """
-    chat_completion = client.chat.completions.create(
-        messages=[
-            {
-                "role": "user",
-                "content": prompt_template,
-            }
-        ],
-        model="llama-3.1-8b-instant",
-    )
-    return chat_completion.choices[0].message.content
-
-def sbml_to_network(file_path):
-    """
-    Parse the SBML model, create a network of species and reactions, and return the pyvis.Network object.
-
-    Args:
-        file_path (str): Path to the SBML model file.
-
-    Returns:
-        pyvis.Network: Network object that can be visualized later.
-    """
-    reader = libsbml.SBMLReader()
-    document = reader.readSBML(file_path)
-    model = document.getModel()
-
-    G = nx.Graph()
-
-    for species in model.getListOfSpecies():
-        species_id = species.getId()
-        G.add_node(species_id, label=species_id, shape="dot", color="blue")
-
-    for reaction in model.getListOfReactions():
-        reaction_id = reaction.getId()
-        substrates = [s.getSpecies() for s in reaction.getListOfReactants()]
-        products = [p.getSpecies() for p in reaction.getListOfProducts()]
-
-        for substrate in substrates:
-            for product in products:
-                G.add_edge(substrate, product, label=reaction_id, color="gray")
-
-    net = Network(notebook=True)
-    net.from_nx(G)
-
-    net.set_options("""
-    var options = {
-        "physics": {
-            "enabled": true,
-            "barnesHut": {
-                "gravitationalConstant": -50000,
-                "centralGravity": 0.3,
-                "springLength": 95
             },
-            "maxVelocity": 50,
-            "minVelocity": 0.1
-        },
-        "nodes": {
-            "size": 20,
-            "font": {
-                "size": 18
-            }
-        },
-        "edges": {
-            "arrows": {
-                "to": {
-                    "enabled": true
                 }
             }
         }
-    }
-    """)
-
-    return net
-
-def streamlit_app():
-    st.title("BioModelsRAG")
-
-    if "db" not in st.session_state:
-        st.session_state.db = None
-
-    search_str = st.text_input("Enter search query:")
-
-    GROQ_API_KEY = st.text_input("Enter GROQ API Key (which is free to make!):")
-
-    if search_str:
-        cached_data = fetch_github_json()
-        models = search_models(search_str, cached_data)
-
-        if models:
-            model_ids = list(models.keys())
-            selected_models = st.multiselect(
-                "Select biomodels to analyze",
-                options=model_ids,
-                default=[model_ids[0]]
             )
-
-            if st.button("Visualize selected models"):
-                for model_id in selected_models:
-                    model_data = models[model_id]
-                    model_url = model_data['url']
-
-                    model_file_path = download_model_file(model_url, model_id)
-
-                    net = sbml_to_network(model_file_path)
-
-                    st.subheader(f"Model {model_data['title']}")
-                    net.show(f"sbml_network_{model_id}.html")
-
-                    HtmlFile = open(f"sbml_network_{model_id}.html", "r", encoding="utf-8")
-                    st.components.v1.html(HtmlFile.read(), height=600)
-
-            if st.button("Analyze Selected Models"):
-
-                for model_id in selected_models:
-                    model_data = models[model_id]
-
-                    st.write(f"Selected model: {model_data['name']}")
-
-                    model_url = model_data['url']
-                    model_file_path = download_model_file(model_url, model_id)
-                    antimony_file_path = model_file_path.replace(".xml", ".antimony")
-
-                    convert_sbml_to_antimony(model_file_path, antimony_file_path)
-                    db, client = split_biomodels(antimony_file_path, GROQ_API_KEY, selected_models)
-                    print(f"Model {model_id} {model_data['name']} has sucessfully been added to the database! :) ")
-
-        else:
-            st.error("No items found in the models. Check if the Antimony files were generated correctly.")
-
-#generate response and remembering previous chat here
-
 if __name__ == "__main__":
-    streamlit_app()
+
+CHROMA_DATA_PATH = tempfile.mkdtemp()
+EMBED_MODEL = "all-MiniLM-L6-v2"
+client = chromadb.PersistentClient(path=CHROMA_DATA_PATH)
+collection_name = "BioModelsRAG"
+
+global db
+db = client.get_or_create_collection(name=collection_name)
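+# Note: this module-level collection is shared by every class below; the
+# stored summaries persist only as long as the temp directory does.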
+
+# To-do list
+# 1. If a model cannot be downloaded (bare "MODEL" IDs), don't include it at all (TICK)
+# 2. Swap the model selection and the Groq API key prompt so visualization alone needs no key (TICK)
+
+
+class BioModelFetcher:
+    def __init__(self, github_owner="TheBobBob", github_repo_cache="BiomodelsCache", biomodels_json_db_path="src/cached_biomodels.json"):
+        self.github_owner = github_owner
+        self.github_repo_cache = github_repo_cache
+        self.biomodels_json_db_path = biomodels_json_db_path
+        self.local_download_dir = tempfile.mkdtemp()
+
+    def fetch_github_json(self):
+        url = f"https://api.github.com/repos/{self.github_owner}/{self.github_repo_cache}/contents/{self.biomodels_json_db_path}"
+        headers = {"Accept": "application/vnd.github+json"}
+        response = requests.get(url, headers=headers)
+
+        if response.status_code == 200:
+            data = response.json()
+
+            if "download_url" in data:
+                file_url = data["download_url"]
+                json_response = requests.get(file_url)
+                json_data = json_response.json()
+
+                return json_data
+            else:
+                raise ValueError(f"Unable to fetch model DB from GitHub repository: {self.github_owner} - {self.github_repo_cache}")
+        else:
+            raise ValueError(f"Unable to fetch model DB from GitHub repository: {self.github_owner} - {self.github_repo_cache}")
+
+
+class BioModelSearch:
+    @staticmethod
+    def search_models(search_str, cached_data):
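+        # Multi-word queries must match every word somewhere in the model's
+        # metadata; single-word queries use a plain substring match.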
+        query_text = search_str.strip().lower()
+        models = {}
+
+        for model_id, model_data in cached_data.items():
+            if 'name' in model_data:
+                name = model_data['name'].lower()
+                url = model_data['url']
+                title = model_data['title']
+                authors = model_data['authors']
+
+                if query_text:
+                    if ' ' in query_text:
+                        query_words = query_text.split(" ")
+                        if all(word in ' '.join([str(v).lower() for v in model_data.values()]) for word in query_words):
+                            models[model_id] = {
+                                'ID': model_id,
+                                'name': name,
+                                'url': url,
+                                'title': title,
+                                'authors': authors,
+                            }
+                    else:
+                        if query_text in ' '.join([str(v).lower() for v in model_data.values()]):
+                            models[model_id] = {
+                                'ID': model_id,
+                                'name': name,
+                                'url': url,
+                                'title': title,
+                                'authors': authors,
+                            }
 
+        return models
89
+
90
+
91
+ class ModelDownloader:
92
+ @staticmethod
93
+ def download_model_file(model_url, model_id, local_download_dir):
94
+ model_url = f"https://raw.githubusercontent.com/sys-bio/BiomodelsStore/main/biomodels/{model_id}/{model_id}_url.xml"
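+        # Note: the incoming model_url is overwritten; files are always
+        # fetched from the sys-bio BiomodelsStore mirror by model ID.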
+        response = requests.get(model_url)
+
+        if response.status_code == 200:
+            os.makedirs(local_download_dir, exist_ok=True)
+            file_path = os.path.join(local_download_dir, f"{model_id}.xml")
+
+            with open(file_path, 'wb') as file:
+                file.write(response.content)
+
+            return file_path
+        else:
+            raise ValueError(f"Failed to download the model from {model_url}")
+
+
+class AntimonyConverter:
+    @staticmethod
+    def convert_sbml_to_antimony(sbml_file_path, antimony_file_path):
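+        # te (tellurium) loads the SBML and emits the equivalent Antimony
+        # text; conversion failures are logged rather than raised.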
         try:
+            r = te.loadSBMLModel(sbml_file_path)
+            antimony_str = r.getCurrentAntimony()
+
+            with open(antimony_file_path, 'w') as file:
+                file.write(antimony_str)
         except Exception as e:
+            print(f"Error converting SBML to Antimony: {e}")
 
+class BioModelSplitter:
+    def __init__(self, groq_api_key):
+        self.groq_client = Groq(api_key=groq_api_key)
 
+    def split_biomodels(self, antimony_file_path, models):
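+        # Chunk the Antimony text on " // " separators, roughly 1000
+        # characters per chunk with 200 characters of overlap.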
+        text_splitter = CharacterTextSplitter(
+            separator=" // ",
+            chunk_size=1000,
+            chunk_overlap=200,
+            length_function=len,
+            is_separator_regex=False,
+        )
+
+        directory_path = os.path.dirname(os.path.abspath(antimony_file_path))
+
+        files = os.listdir(directory_path)
+        for file in files:
+            file_path = os.path.join(directory_path, file)
+            try:
+                with open(file_path, 'r') as f:
+                    file_content = f.read()
+                    items = text_splitter.create_documents([file_content])
+                    self.create_vector_db(items, models)
+                    break
+            except Exception as e:
+                print(f"Error reading file {file_path}: {e}")
+
+        return db
+
+    def create_vector_db(self, final_items, models):
+        counter = 0
+        for model_id in models:
+            try:
+                results = db.get(where={"document": {"$eq": model_id}})
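+                # Skip models whose chunk summaries are already stored in the collection.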
+
+                #might be a problem here?
+                if results['documents']:
+                    continue
+
+                #could also be a problem in how the IDs are created
+                for item in final_items:
+                    counter += 1  # Increment counter for each item
+                    item_id = f"{counter}_{model_id}"
+
+                    # Construct the prompt
+                    prompt = f"""
+                    Summarize the following segment of Antimony in a clear and concise manner:
+                    1. Provide a detailed summary using a reasonable number of words.
+                    2. Maintain all original values and include any mathematical expressions or values in full.
+                    3. Ensure that all variable names and their values are clearly presented.
+                    4. Write the summary in paragraph format, putting an emphasis on clarity and completeness.
+
+                    Segment of Antimony: {item}
+                    """
+
+                    chat_completion = self.groq_client.chat.completions.create(
+                        messages=[{
                             "role": "user",
                             "content": prompt,
+                        }],
+                        model="llama-3.1-8b-instant",
                     )
+
+                    if chat_completion.choices[0].message.content:
+                        db.upsert(
+                            ids=[item_id],
+                            metadatas=[{"document": model_id}],
+                            documents=[chat_completion.choices[0].message.content],
+                        )
+                    else:
+                        print(f"Error: No content returned from Groq for model {model_id}.")
+            except Exception as e:
+                print(f"Error processing model {model_id}: {e}")
+
+
+class SBMLNetworkVisualizer:
+    @staticmethod
+    def sbml_to_network(file_path):
+        reader = libsbml.SBMLReader()
+        document = reader.readSBML(file_path)
+        model = document.getModel()
+
+        G = nx.Graph()
+
+        # Add species as nodes
+        for species in model.getListOfSpecies():
+            species_id = species.getId()
+            G.add_node(species_id, label=species_id, shape="dot", color="blue")
+
+        # Add reactions as edges with reaction details as labels
+        for reaction in model.getListOfReactions():
+            reaction_id = reaction.getId()
+
+            substrates = [s.getSpecies() for s in reaction.getListOfReactants()]
+            products = [p.getSpecies() for p in reaction.getListOfProducts()]
+
+            substrate_str = ' + '.join(substrates)
+            product_str = ' + '.join(products)
+            reaction_equation = f"{substrate_str} -> {product_str}"
+
+            for substrate in substrates:
+                for product in products:
+                    G.add_edge(
+                        substrate,
+                        product,
+                        label=reaction_equation,
+                        color="gray"
+                    )
 
+        net = Network(notebook=True)
+        net.from_nx(G)
+        net.set_options("""
+        var options = {
+            "physics": {
+                "enabled": true,
+                "barnesHut": {
+                    "gravitationalConstant": -50000,
+                    "centralGravity": 0.3,
+                    "springLength": 95
+                },
+                "maxVelocity": 50,
+                "minVelocity": 0.1
             },
+            "nodes": {
+                "size": 20,
+                "font": {
+                    "size": 18
+                }
+            },
+            "edges": {
+                "arrows": {
+                    "to": {
+                        "enabled": true
+                    }
+                },
+                "label": {
+                    "enabled": true,
+                    "font": {
+                        "size": 10
+                    }
+                    }
                 }
             }
         }
+        """)
+        return net
+
+
+class StreamlitApp:
+    def __init__(self):
+        self.fetcher = BioModelFetcher()
+        self.searcher = BioModelSearch()
+        self.downloader = ModelDownloader()
+        self.splitter = None
+        self.visualizer = SBMLNetworkVisualizer()
+
+    def run(self):
+        st.title("BioModelsRAG")
+
+        if "messages" not in st.session_state:
+            st.session_state.messages = []
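+        # Chat history lives in st.session_state so it survives Streamlit reruns.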
+
+        search_str = st.text_input("Enter search query:", key="search_str")
+
+        if search_str:
+            cached_data = self.fetcher.fetch_github_json()
+            models = self.searcher.search_models(search_str, cached_data)
+
+            if models:
+                model_ids = list(models.keys())
+                model_ids = [model_id for model_id in model_ids if not str(model_id).startswith("MODEL")]
+                if model_ids:
+                    selected_models = st.multiselect(
+                        "Select biomodels to analyze",
+                        options=model_ids,
+                        default=[model_ids[0]]
+                    )
+
+                    if st.button("Visualize selected models"):
+                        for model_id in selected_models:
+                            model_data = models[model_id]
+                            model_url = model_data['url']
+
+                            model_file_path = self.downloader.download_model_file(model_url, model_id, self.fetcher.local_download_dir)
+
+                            net = self.visualizer.sbml_to_network(model_file_path)
+
+                            st.subheader(f"Model: {model_data['title']}")
+                            net.show(f"sbml_network_{model_id}.html")
+
+                            HtmlFile = open(f"sbml_network_{model_id}.html", "r", encoding="utf-8")
+                            st.components.v1.html(HtmlFile.read(), height=600)
+
+                    GROQ_API_KEY = st.text_input("Enter a GROQ API Key (which is free to make!):", key="api_keys")
+
+                    if GROQ_API_KEY:
+                        self.splitter = BioModelSplitter(GROQ_API_KEY)
+
+                        if st.button("Analyze Selected Models"):
+                            for model_id in selected_models:
+                                model_data = models[model_id]
+
+                                st.write(f"Selected model: {model_data['name']}")
+
+                                model_url = model_data['url']
+                                model_file_path = self.downloader.download_model_file(model_url, model_id, self.fetcher.local_download_dir)
+                                antimony_file_path = model_file_path.replace(".xml", ".txt")
+
+                                AntimonyConverter.convert_sbml_to_antimony(model_file_path, antimony_file_path)
+                                self.splitter.split_biomodels(antimony_file_path, selected_models)
+
+                                st.info(f"Model {model_id} {model_data['name']} has successfully been added to the database! :) ")
+
+                        prompt_fin = st.chat_input("Enter Q when you would like to quit! ", key="input_1")
+
+                        if prompt_fin:
+                            prompt = str(prompt_fin)
+                            st.session_state.messages.append({"role": "user", "content": prompt})
+
+                            history = st.session_state.messages[-6:]
+ response = self.generate_response(prompt, history, models)
342
+
343
+ st.session_state.messages.append({"role": "assistant", "content": response})
344
+
345
+ for message in st.session_state.messages:
346
+ with st.chat_message(message["role"]):
347
+ st.markdown(message["content"])
348
+
349
+ def generate_response(self, prompt, history, models):
350
+ query_results_final = ""
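+        # Pull the five most similar stored summaries for each selected model.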
+
+        for model_id in models:
+            query_results = db.query(
+                query_texts=prompt,
+                n_results=5,
+                where={"document": {"$eq": model_id}},
             )
+            best_recommendation = query_results['documents']
+            flat_recommendation = [item for sublist in best_recommendation for item in (sublist if isinstance(sublist, list) else [sublist])]
+            query_results_final += "\n\n".join(flat_recommendation) + "\n\n"
 
+        prompt_template = f"""
+        Using the context and previous conversation provided below, answer the following question. If the information is insufficient to answer the question, please state that clearly:
 
+        Context:
+        {query_results_final}
 
+        Previous Conversation:
+        {history}
+
+        Instructions:
+        1. Cross-Reference: Use all provided context to define variables and identify any unknown entities.
+        2. Mathematical Calculations: Perform any necessary calculations based on the context and available data.
+        3. Consistency: Remember and incorporate previous responses if the question is related to earlier information.
+
+        Question:
+        {prompt}
+        """
+        chat_completion = self.splitter.groq_client.chat.completions.create(
+            messages=[{
+                "role": "user",
+                "content": prompt_template,
+            }],
+            model="llama-3.1-8b-instant",
+        )
+
+        return chat_completion.choices[0].message.content
  if __name__ == "__main__":
+    app = StreamlitApp()
+    app.run()