zolicsaki commited on
Commit
4b4a88a
·
verified ·
1 Parent(s): 282ed7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +393 -73
app.py CHANGED
@@ -1,98 +1,418 @@
 
1
  import os
2
- import gradio as gr
3
- import tempfile
4
- from pathlib import Path
5
- from utils import get_file_name, run_pdf2text, run_paper2slides
 
 
 
 
 
 
6
 
7
- # Create a temporary directory for outputs
8
- OUTPUT_FOLDER = tempfile.mkdtemp(prefix="pdf_slides_")
9
 
10
- def paper2slides(pdf_path: str, logo_path: str='logo.npg'):
11
- # Ensure the output directory exists
12
- os.makedirs(OUTPUT_FOLDER, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- output_path = os.path.join(OUTPUT_FOLDER, "converted_slides.pptx")
15
- file_name = get_file_name(full_path=pdf_path)
16
- run_pdf2text(paper_pdf_path=pdf_path, save_json_name=file_name + '.json') ## pdf to json file
17
- run_paper2slides(paper_json_name=file_name + '.json', model='llama3_70b', logo_path=logo_path, save_file_name=output_path) ## json file to slides
18
- return output_path
 
 
 
 
19
 
20
- def convert_pdf_to_slides(pdf_path):
21
  """
22
- Convert a PDF file to a slide deck.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  Args:
25
- pdf_path (str): Path to the input PDF file
26
-
 
27
  Returns:
28
- str: Path to the created slide deck file
29
  """
30
- print(f"Processing PDF: {pdf_path}")
31
-
32
- # Ensure the output directory exists
33
- os.makedirs(OUTPUT_FOLDER, exist_ok=True)
34
 
35
- # Create output filename in the temporary directory
36
- output_filename = f"{Path(pdf_path).stem}_slides.pptx"
37
- output_path = os.path.join(OUTPUT_FOLDER, output_filename)
38
 
39
- output_path = paper2slides(pdf_path, 'logo.png')
40
-
41
- print(f"Conversion complete. Slides saved to: {output_path}")
42
- return output_path
43
 
44
- def process_upload(pdf_file):
45
  """
46
- Process the uploaded PDF file and return the path to the generated slide deck.
47
 
48
  Args:
49
- pdf_file (tempfile): The uploaded PDF file
50
-
51
  Returns:
52
- str: Path to the slide deck for download
53
  """
54
- if pdf_file is None:
55
- return None
56
-
57
- try:
58
- # Convert the PDF to slides
59
- slide_path = convert_pdf_to_slides(pdf_file.name)
60
- return slide_path
61
- except Exception as e:
62
- print(f"Error during conversion: {str(e)}")
63
- return None
64
-
65
- # Create the Gradio interface
66
- with gr.Blocks(title="PDF to Slides Converter") as app:
67
- gr.Markdown("# PDF to Slides Converter")
68
- gr.Markdown("Upload a PDF file to convert it into a slide deck presentation.")
69
 
70
- with gr.Row():
71
- with gr.Column():
72
- pdf_input = gr.File(label="Upload PDF File", file_types=[".pdf"])
73
- convert_btn = gr.Button("Convert to Slides", variant="primary")
74
 
75
- with gr.Column():
76
- output_file = gr.File(label="Download Slides")
77
- status = gr.Markdown("Upload a PDF and click 'Convert to Slides'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- convert_btn.click(
80
- fn=process_upload,
81
- inputs=[pdf_input],
82
- outputs=[output_file],
83
- api_name="convert"
84
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- pdf_input.change(
87
- fn=lambda x: "PDF uploaded. Click 'Convert to Slides' to process." if x else "Upload a PDF and click 'Convert to Slides'",
88
- inputs=[pdf_input],
89
- outputs=[status]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  )
91
 
92
- # Launch the app
93
- if __name__ == "__main__":
94
- app.launch()
95
 
96
- # Optional: Clean up the temporary directory when the app exits
97
- import shutil
98
- shutil.rmtree(OUTPUT_FOLDER, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
  import os
3
+ import sys
4
+ from contextlib import contextmanager, redirect_stdout
5
+ from io import StringIO
6
+ from typing import Callable, Generator, Optional, List, Dict
7
+ import requests
8
+ import json
9
+ from consts import AUTO_SEARCH_KEYWORD, SEARCH_TOOL_INSTRUCTION, RELATED_QUESTIONS_TEMPLATE_SEARCH, SEARCH_TOOL_INSTRUCTION, RAG_TEMPLATE, GOOGLE_SEARCH_ENDPOINT, DEFAULT_SEARCH_ENGINE_TIMEOUT, RELATED_QUESTIONS_TEMPLATE_NO_SEARCH
10
+ import re
11
+ import asyncio
12
+ import random
13
 
14
+ import streamlit as st
15
+ import yaml
16
 
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
19
+ repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
20
+
21
+ sys.path.append(kit_dir)
22
+ sys.path.append(repo_dir)
23
+
24
+
25
+ from visual_env_utils import are_credentials_set, env_input_fields, initialize_env_variables, save_credentials
26
+
27
+ logging.basicConfig(level=logging.INFO)
28
+ GOOGLE_API_KEY = st.secrets["google_api_key"]
29
+ GOOGLE_CX = st.secrets["google_cx"]
30
+ BACKUP_KEYS = [st.secrets["backup_key_1"], st.secrets["backup_key_2"], st.secrets["backup_key_3"], st.secrets["backup_key_4"], st.secrets["backup_key_5"]]
31
+
32
+ CONFIG_PATH = os.path.join(current_dir, "config.yaml")
33
+
34
+ USER_AGENTS = [
35
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
36
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
37
+ ]
38
+
39
def load_config():
    """Read the YAML configuration file at CONFIG_PATH and return the parsed mapping."""
    with open(CONFIG_PATH, 'r') as cfg_file:
        parsed = yaml.safe_load(cfg_file)
    return parsed
42
+
43
+
44
+ config = load_config()
45
+ prod_mode = config.get('prod_mode', False)
46
+ additional_env_vars = config.get('additional_env_vars', None)
47
+
48
@contextmanager
def st_capture(output_func: Callable[[str], None]) -> Generator:
    """
    Context manager that captures stdout inside the block and streams the
    accumulated text to a Streamlit output callback after every write.

    Args:
        output_func: callable invoked with the full captured text so far
            each time something is written to stdout.
    Yields:
        Generator: yields once; stdout is restored when the block exits.
    """
    with StringIO() as sink, redirect_stdout(sink):
        original_write = sink.write

        def tee_write(text: str) -> int:
            # Write through to the buffer, then push the whole buffer contents
            # to the callback so the UI always shows the complete output.
            written = original_write(text)
            output_func(sink.getvalue())
            return written

        sink.write = tee_write  # type: ignore
        yield
67
+
68
async def run_samba_api_inference(query, system_prompt = None, ignore_context=False, max_tokens_to_generate=None, num_seconds_to_sleep=1, over_ride_key=None):
    """Call the SambaNova chat-completions endpoint and return the reply text.

    Args:
        query (str): the user turn appended as the final message.
        system_prompt (str | None): optional system message prepended to the conversation.
        ignore_context (bool): when True, do not replay st.session_state.chat_history.
        max_tokens_to_generate (int | None): forwarded as the "max_tokens" payload field.
        num_seconds_to_sleep (int): delay before retrying on 429/504 responses.
        over_ride_key (str | None): API key used instead of the session's SAMBANOVA_API_KEY.

    Returns:
        str: the assistant message content, or an "Invalid Key!" notice on failure.
    """
    # First construct messages
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    if not ignore_context:
        # chat_history is stored as repeating (question, answer, status) triples
        # (see handle_userinput), so stride-3 slices recover the user/assistant turns.
        # NOTE(review): handle_userinput appends answers as (response, links)
        # tuples, so "content" here receives a tuple — confirm this is intended.
        for ques, ans in zip(
            st.session_state.chat_history[::3],
            st.session_state.chat_history[1::3],
        ):
            messages.append({"role": "user", "content": ques})
            messages.append({"role": "assistant", "content": ans})
    messages.append({"role": "user", "content": query})

    # Create payloads
    payload = {
        "messages": messages,
        "model": config.get("model")
    }
    if max_tokens_to_generate is not None:
        payload["max_tokens"] = max_tokens_to_generate

    # Prefer an explicit override key (used by the backup-key retry path below)
    # over the key stored in the Streamlit session.
    if over_ride_key is None:
        api_key = st.session_state.SAMBANOVA_API_KEY
    else:
        api_key = over_ride_key
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    try:
        # requests is blocking, so run it in the default executor to avoid
        # stalling the asyncio event loop.
        post_response = await asyncio.get_event_loop().run_in_executor(None, lambda: requests.post(config.get("url"), json=payload, headers=headers, stream=True))
        post_response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        # 401/503: treated as an invalid key; surface a message and bail out.
        if post_response.status_code in {401, 503}:
            st.info(f"Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/.")
            return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."
        # 429/504: rate-limited or gateway timeout — sleep, then recurse with a
        # randomly chosen backup key.
        # NOTE(review): this retry has no attempt cap, so a persistently
        # rate-limited endpoint recurses indefinitely — consider a max-retries guard.
        if post_response.status_code in {429, 504}:
            await asyncio.sleep(num_seconds_to_sleep)
            return await run_samba_api_inference(query, over_ride_key=random.choice(BACKUP_KEYS)) # Retry the request
        else:
            print(f"Request failed with status code: {post_response.status_code}. Error: {e}")
            return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."

    response_data = json.loads(post_response.text)

    return response_data["choices"][0]["message"]["content"]
117
+
118
def extract_query(text):
    """
    Extract the search query from a tool-call string of the form query="...".

    Returns the text inside the quotes, or None when no quoted query is found.
    """
    found = re.search(r'query="(.*?)"', text)
    return found.group(1) if found else None
126
+
127
def extract_text_between_brackets(text):
    """Return every substring enclosed in square brackets, non-greedy, in order."""
    return re.findall(r'\[(.*?)\]', text)
131
 
132
def search_with_google(query: str):
    """
    Search with google and return the contexts.

    Args:
        query (str): search terms forwarded to the Google Custom Search API.

    Returns:
        list: up to 5 result items (dicts containing e.g. 'title', 'link',
        'snippet'); empty when the query produced no hits.

    Raises:
        Exception: when the search endpoint responds with a non-OK status.
    """
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query,
        "num": 5,
    }
    response = requests.get(
        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
    )

    if not response.ok:
        raise Exception(response.status_code, "Search engine error.")
    json_content = response.json()

    # The Custom Search API omits the "items" key entirely when there are zero
    # results; fall back to an empty list instead of raising KeyError.
    contexts = json_content.get("items", [])[:5]

    return contexts
153
+
154
async def get_related_questions(query, contexts = None):
    """
    Ask the model for follow-up questions related to *query*.

    Args:
        query (str): the user's question.
        contexts (list | None): optional search-result dicts; their snippets
            are embedded into the system prompt when present.

    Returns:
        list: related questions parsed from the model output, or [] when the
        output cannot be parsed as JSON.
    """
    if contexts:
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
            context="\n\n".join([c["snippet"] for c in contexts])
        )
    else:
        # When no search is performed, use the dedicated no-search prompt.
        # (The original used the search template here even though the
        # no-search template was imported for this purpose.)
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_NO_SEARCH

    related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)

    try:
        return json.loads(related_questions_raw)
    except (json.JSONDecodeError, TypeError):
        # The model sometimes wraps the JSON list in prose; recover the
        # bracketed payload before giving up. extract_text_between_brackets
        # strips the brackets, so restore them to form valid JSON again.
        # (The original passed the *list* of matches to json.loads, which
        # always raised, making this fallback return [] unconditionally.)
        bracketed = extract_text_between_brackets(related_questions_raw)
        if not bracketed:
            return []
        try:
            return json.loads(f"[{bracketed[0]}]")
        except json.JSONDecodeError:
            return []
173
+
174
def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
    """
    Process citations in the response and replace them with numbered icons.

    Each ``[citation:N]`` marker becomes ``<sup>[N]</sup>`` so the displayed
    number always matches the Nth link emitted by generate_citation_links.
    (The previous sequential renumbering broke that mapping whenever citations
    repeated or appeared out of order.)

    Args:
        response (str): The original response with citations.
        search_result_contexts (List[Dict]): The search results with context
            information (unused; kept for interface compatibility).

    Returns:
        str: The processed response with numbered icons for citations.
    """
    return re.sub(r'\[citation:(\d+)\]', lambda m: f'<sup>[{m.group(1)}]</sup>', response)
 
 
 
191
 
192
def generate_citation_links(search_result_contexts: List[Dict]) -> str:
    """
    Generate HTML for citation links.

    Args:
        search_result_contexts (List[Dict]): The search results with context
            information; 'title' and 'link' keys are used when present.

    Returns:
        str: HTML string with numbered citation links, one <p> per source.
    """
    rendered = [
        '<p>[{idx}] <a href="{link}" target="_blank">{title}</a></p>'.format(
            idx=idx,
            link=ctx.get('link', '#'),
            title=ctx.get('title', 'No title'),
        )
        for idx, ctx in enumerate(search_result_contexts, 1)
    ]
    return ''.join(rendered)
209
+
 
 
210
 
211
async def run_auto_search_pipe(query):
    """Route *query* through web search when needed, otherwise answer directly.

    Returns:
        tuple: (answer_text, citation_links_html, related_questions);
        citation_links_html is "" when no search was performed.
    """
    # Eagerly start the no-search answer and its related questions so they are
    # already in flight if the router decides a search is unnecessary.
    full_context_answer = asyncio.create_task(run_samba_api_inference(query))
    related_questions_no_search = asyncio.create_task(get_related_questions(query))

    # First call Llama3.1 8B with special system prompt for auto search
    with st.spinner('Checking if web search is needed...'):
        auto_search_result = await run_samba_api_inference(query, SEARCH_TOOL_INSTRUCTION, True, max_tokens_to_generate=100)

    # If Llama3.1 8B returns a search query then run search pipeline
    if AUTO_SEARCH_KEYWORD in auto_search_result:
        st.session_state.search_performed = True
        # search
        # NOTE(review): the two eager tasks above are never awaited or
        # cancelled on this branch; their results are silently discarded.
        with st.spinner('Searching the internet...'):
            search_result_contexts = search_with_google(extract_query(auto_search_result))

        # RAG response
        with st.spinner('Generating response based on web search...'):
            rag_system_prompt = RAG_TEMPLATE.format(
                context="\n\n".join(
                    [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(search_result_contexts)]
                )
            )

            # Run the final answer and the related-questions call concurrently.
            model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
            related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))
            # Process citations and generate links
            citation_links = generate_citation_links(search_result_contexts)

            model_response_complete = await model_response
            processed_response = process_citations(model_response_complete, search_result_contexts)
            related_questions_complete = await related_questions


            return processed_response, citation_links, related_questions_complete

    # If Llama3.1 8B returns an answer directly, then please query Llama 405B to get the best possible answer
    else:
        st.session_state.search_performed = False
        result = await full_context_answer
        related_questions = await related_questions_no_search
        return result, "", related_questions
252
+
253
+
254
def handle_userinput(user_question: Optional[str]) -> None:
    """
    Handle user input and generate a response, also update chat UI in streamlit app

    Runs the auto-search pipeline for a new question, appends a
    (question, (answer, links), status) triple to the session chat history,
    then re-renders the full conversation plus any related-question buttons.

    Args:
        user_question (str): The user's question or input; None on reruns with
            no new input, in which case only the history is re-rendered.
    """
    if user_question:
        # Clear any existing related question buttons
        if 'related_questions' in st.session_state:
            st.session_state.related_questions = []

        async def run_search():
            return await run_auto_search_pipe(user_question)

        response, citation_links, related_questions = asyncio.run(run_search())
        # Status line recorded alongside the answer so the history shows
        # whether this turn used web search.
        if st.session_state.search_performed:
            search_or_not_text = "🔍 Web search was performed for this query."
        else:
            search_or_not_text = "📚 This response was generated from the model's knowledge."

        # History layout: repeating triples (question, (answer, links), status).
        # run_samba_api_inference's stride-3 slicing depends on this layout.
        st.session_state.chat_history.append(user_question)
        st.session_state.chat_history.append((response, citation_links))
        st.session_state.chat_history.append(search_or_not_text)

        # Store related questions in session state
        st.session_state.related_questions = related_questions

    # Re-render the whole conversation from the stored triples.
    for ques, ans, search_or_not_text in zip(
        st.session_state.chat_history[::3],
        st.session_state.chat_history[1::3],
        st.session_state.chat_history[2::3],
    ):
        with st.chat_message('user'):
            st.write(f'{ques}')

        with st.chat_message(
            'ai',
            avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
        ):
            # ans is an (answer_html, citation_links_html) pair.
            st.markdown(f'{ans[0]}', unsafe_allow_html=True)
            if ans[1]:
                st.markdown("### Sources", unsafe_allow_html=True)
                st.markdown(ans[1], unsafe_allow_html=True)
            st.info(search_or_not_text)
    # Offer follow-up questions as buttons that pre-fill the chat input.
    if len(st.session_state.related_questions) > 0:
        st.markdown("### Related Questions")
        for question in st.session_state.related_questions:
            if st.button(question):
                setChatInputValue(question)
303
+
304
def setChatInputValue(chat_input_value: str) -> None:
    """
    Programmatically fill Streamlit's chat input box via injected JavaScript.

    Args:
        chat_input_value (str): text to place into the chat input field.
    """
    # json.dumps yields a safely quoted/escaped JS string literal; the previous
    # bare f-string interpolation produced broken JavaScript whenever the value
    # contained a double quote or backslash.
    safe_value = json.dumps(chat_input_value)
    js = f"""
    <script>
    function insertText(dummy_var_to_force_repeat_execution) {{
        var chatInput = parent.document.querySelector('textarea[data-testid="stChatInputTextArea"]');
        var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
        nativeInputValueSetter.call(chatInput, {safe_value});
        var event = new Event('input', {{ bubbles: true}});
        chatInput.dispatchEvent(event);
    }}
    insertText(3);
    </script>
    """
    st.components.v1.html(js)
318
+
319
def main() -> None:
    """Streamlit entry point: page setup, credential sidebar, and chat loop."""
    st.set_page_config(
        page_title='Auto Web Search Demo',
        page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
    )

    initialize_env_variables(prod_mode, additional_env_vars)

    # Initialize session state on first load; chat is disabled until an API
    # key is available in the session.
    if 'input_disabled' not in st.session_state:
        if 'SAMBANOVA_API_KEY' in st.session_state:
            st.session_state.input_disabled = False
        else:
            st.session_state.input_disabled = True
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'search_performed' not in st.session_state:
        st.session_state.search_performed = False
    if 'related_questions' not in st.session_state:
        st.session_state.related_questions = []

    st.title(' Auto Web Search')
    st.subheader('Powered by :orange[SambaNova Cloud] and Llama405B')

    with st.sidebar:
        st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')

        # Credential entry / clearing; saving re-enables the chat input.
        if not are_credentials_set(additional_env_vars):
            api_key, additional_vars = env_input_fields(additional_env_vars)
            if st.button('Save Credentials'):
                message = save_credentials(api_key, additional_vars, prod_mode)
                st.session_state.input_disabled = False
                st.success(message)
                st.rerun()

        else:
            st.success('Credentials are set')
            if st.button('Clear Credentials'):
                save_credentials('', {var: '' for var in (additional_env_vars or [])}, prod_mode)
                st.session_state.input_disabled = True
                st.rerun()


        # Example-query buttons pre-fill the chat input via injected JS.
        if are_credentials_set(additional_env_vars):
            with st.expander('**Example Queries With Search**', expanded=True):
                if st.button('What is the population of Virginia?'):
                    setChatInputValue(
                        'What is the population of Virginia?'
                    )
                if st.button('SNP 500 stock market moves'):
                    setChatInputValue('SNP 500 stock market moves')
                if st.button('What is the weather in Palo Alto?'):
                    setChatInputValue(
                        'What is the weather in Palo Alto?'
                    )
            with st.expander('**Example Queries No Search**', expanded=True):
                if st.button('write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'):
                    setChatInputValue(
                        'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'
                    )
                if st.button('Write a python program to find the longest root to leaf path in a tree, and some test cases for it.'):
                    setChatInputValue('Write a python program to find the longest root to leaf path in a tree, and some test cases for it.')

            st.markdown('**Reset chat**')
            st.markdown('**Note:** Resetting the chat will clear all interactions history')
            if st.button('Reset conversation'):
                st.session_state.chat_history = []
                # NOTE(review): sources_history is assigned here but not read
                # anywhere visible in this file — confirm it is still needed.
                st.session_state.sources_history = []
                if 'related_questions' in st.session_state:
                    st.session_state.related_questions = []
                st.toast('Interactions reset. The next response will clear the history on the screen')

    # Add a footer with the GitHub citation
    footer_html = """
    <style>
    .footer {
        position: fixed;
        right: 10px;
        bottom: 10px;
        width: auto;
        background-color: transparent;
        color: grey;
        text-align: right;
        padding: 10px;
        font-size: 16px;
    }
    </style>
    <div class="footer">
        Inspired by: <a href="https://github.com/leptonai/search_with_lepton" target="_blank">search_with_lepton</a>
    </div>
    """
    st.markdown(footer_html, unsafe_allow_html=True)

    # Chat input stays disabled until credentials are saved.
    user_question = st.chat_input('Ask something', disabled=st.session_state.input_disabled, key='TheChatInput')
    handle_userinput(user_question)
414
+
415
+
416
+
417
+ if __name__ == '__main__':
418
+ main()