zolicsaki commited on
Commit
0adfdb6
·
verified ·
1 Parent(s): 4b55fdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -393
app.py CHANGED
@@ -1,418 +1,98 @@
1
- import logging
2
  import os
3
- import sys
4
- from contextlib import contextmanager, redirect_stdout
5
- from io import StringIO
6
- from typing import Callable, Generator, Optional, List, Dict
7
- import requests
8
- import json
9
- from consts import AUTO_SEARCH_KEYWORD, SEARCH_TOOL_INSTRUCTION, RELATED_QUESTIONS_TEMPLATE_SEARCH, SEARCH_TOOL_INSTRUCTION, RAG_TEMPLATE, GOOGLE_SEARCH_ENDPOINT, DEFAULT_SEARCH_ENGINE_TIMEOUT, RELATED_QUESTIONS_TEMPLATE_NO_SEARCH
10
- import re
11
- import asyncio
12
- import random
13
 
14
- import streamlit as st
15
- import yaml
16
 
17
- current_dir = os.path.dirname(os.path.abspath(__file__))
18
- kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
19
- repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
20
-
21
- sys.path.append(kit_dir)
22
- sys.path.append(repo_dir)
23
-
24
-
25
- from visual_env_utils import are_credentials_set, env_input_fields, initialize_env_variables, save_credentials
26
-
27
- logging.basicConfig(level=logging.INFO)
28
- GOOGLE_API_KEY = st.secrets["google_api_key"]
29
- GOOGLE_CX = st.secrets["google_cx"]
30
- BACKUP_KEYS = [st.secrets["backup_key_1"], st.secrets["backup_key_2"], st.secrets["backup_key_3"], st.secrets["backup_key_4"], st.secrets["backup_key_5"]]
31
-
32
- CONFIG_PATH = os.path.join(current_dir, "config.yaml")
33
-
34
- USER_AGENTS = [
35
- "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
36
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
37
- ]
38
-
39
- def load_config():
40
- with open(CONFIG_PATH, 'r') as yaml_file:
41
- return yaml.safe_load(yaml_file)
42
-
43
-
44
- config = load_config()
45
- prod_mode = config.get('prod_mode', False)
46
- additional_env_vars = config.get('additional_env_vars', None)
47
-
48
- @contextmanager
49
- def st_capture(output_func: Callable[[str], None]) -> Generator:
50
- """
51
- context manager to catch stdout and send it to an output streamlit element
52
- Args:
53
- output_func (function to write terminal output in
54
- Yields:
55
- Generator:
56
- """
57
- with StringIO() as stdout, redirect_stdout(stdout):
58
- old_write = stdout.write
59
-
60
- def new_write(string: str) -> int:
61
- ret = old_write(string)
62
- output_func(stdout.getvalue())
63
- return ret
64
-
65
- stdout.write = new_write # type: ignore
66
- yield
67
-
68
- async def run_samba_api_inference(query, system_prompt = None, ignore_context=False, max_tokens_to_generate=None, num_seconds_to_sleep=1, over_ride_key=None):
69
- # First construct messages
70
- messages = []
71
- if system_prompt is not None:
72
- messages.append({"role": "system", "content": system_prompt})
73
-
74
- if not ignore_context:
75
- for ques, ans in zip(
76
- st.session_state.chat_history[::3],
77
- st.session_state.chat_history[1::3],
78
- ):
79
- messages.append({"role": "user", "content": ques})
80
- messages.append({"role": "assistant", "content": ans})
81
- messages.append({"role": "user", "content": query})
82
-
83
- # Create payloads
84
- payload = {
85
- "messages": messages,
86
- "model": config.get("model")
87
- }
88
- if max_tokens_to_generate is not None:
89
- payload["max_tokens"] = max_tokens_to_generate
90
-
91
- if over_ride_key is None:
92
- api_key = st.session_state.SAMBANOVA_API_KEY
93
- else:
94
- api_key = over_ride_key
95
- headers = {
96
- "Authorization": f"Bearer {api_key}",
97
- "Content-Type": "application/json"
98
- }
99
-
100
- try:
101
- post_response = await asyncio.get_event_loop().run_in_executor(None, lambda: requests.post(config.get("url"), json=payload, headers=headers, stream=True))
102
- post_response.raise_for_status()
103
- except requests.exceptions.HTTPError as e:
104
- if post_response.status_code in {401, 503}:
105
- st.info(f"Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/.")
106
- return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."
107
- if post_response.status_code in {429, 504}:
108
- await asyncio.sleep(num_seconds_to_sleep)
109
- return await run_samba_api_inference(query, over_ride_key=random.choice(BACKUP_KEYS)) # Retry the request
110
- else:
111
- print(f"Request failed with status code: {post_response.status_code}. Error: {e}")
112
- return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."
113
-
114
- response_data = json.loads(post_response.text)
115
-
116
- return response_data["choices"][0]["message"]["content"]
117
-
118
- def extract_query(text):
119
- # Regular expression to capture the query within the quotes
120
- match = re.search(r'query="(.*?)"', text)
121
 
122
- # If a match is found, return the query, otherwise return None
123
- if match:
124
- return match.group(1)
125
- return None
126
-
127
- def extract_text_between_brackets(text):
128
- # Using regular expressions to find all text between brackets
129
- matches = re.findall(r'\[(.*?)\]', text)
130
- return matches
131
 
132
- def search_with_google(query: str):
133
  """
134
- Search with google and return the contexts.
135
- """
136
- params = {
137
- "key": GOOGLE_API_KEY,
138
- "cx": GOOGLE_CX,
139
- "q": query,
140
- "num": 5,
141
- }
142
- response = requests.get(
143
- GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
144
- )
145
-
146
- if not response.ok:
147
- raise Exception(response.status_code, "Search engine error.")
148
- json_content = response.json()
149
-
150
- contexts = json_content["items"][:5]
151
-
152
- return contexts
153
-
154
- async def get_related_questions(query, contexts = None):
155
- if contexts:
156
- related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
157
- context="\n\n".join([c["snippet"] for c in contexts])
158
- )
159
- else:
160
- # When no search is performed, use a generic prompt
161
- related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH
162
-
163
- related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)
164
-
165
- try:
166
- return json.loads(related_questions_raw)
167
- except:
168
- try:
169
- extracted_related_questions = extract_text_between_brackets(related_questions_raw)
170
- return json.loads(extracted_related_questions)
171
- except:
172
- return []
173
-
174
- def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
175
- """
176
- Process citations in the response and replace them with numbered icons.
177
 
178
  Args:
179
- response (str): The original response with citations.
180
- search_result_contexts (List[Dict]): The search results with context information.
181
-
182
  Returns:
183
- str: The processed response with numbered icons for citations.
184
  """
185
- citations = re.findall(r'\[citation:(\d+)\]', response)
 
 
 
186
 
187
- for i, citation in enumerate(citations, 1):
188
- response = response.replace(f'[citation:{citation}]', f'<sup>[{i}]</sup>')
 
189
 
190
- return response
 
 
 
191
 
192
- def generate_citation_links(search_result_contexts: List[Dict]) -> str:
193
  """
194
- Generate HTML for citation links.
195
 
196
  Args:
197
- search_result_contexts (List[Dict]): The search results with context information.
198
-
199
  Returns:
200
- str: HTML string with numbered citation links.
201
  """
202
- citation_links = []
203
- for i, context in enumerate(search_result_contexts, 1):
204
- title = context.get('title', 'No title')
205
- link = context.get('link', '#')
206
- citation_links.append(f'<p>[{i}] <a href="{link}" target="_blank">{title}</a></p>')
207
 
208
- return ''.join(citation_links)
209
-
210
-
211
- async def run_auto_search_pipe(query):
212
- full_context_answer = asyncio.create_task(run_samba_api_inference(query))
213
- related_questions_no_search = asyncio.create_task(get_related_questions(query))
214
-
215
- # First call Llama3.1 8B with special system prompt for auto search
216
- with st.spinner('Checking if web search is needed...'):
217
- auto_search_result = await run_samba_api_inference(query, SEARCH_TOOL_INSTRUCTION, True, max_tokens_to_generate=100)
218
-
219
- # If Llama3.1 8B returns a search query then run search pipeline
220
- if AUTO_SEARCH_KEYWORD in auto_search_result:
221
- st.session_state.search_performed = True
222
- # search
223
- with st.spinner('Searching the internet...'):
224
- search_result_contexts = search_with_google(extract_query(auto_search_result))
225
-
226
- # RAG response
227
- with st.spinner('Generating response based on web search...'):
228
- rag_system_prompt = RAG_TEMPLATE.format(
229
- context="\n\n".join(
230
- [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(search_result_contexts)]
231
- )
232
- )
233
-
234
- model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
235
- related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))
236
- # Process citations and generate links
237
- citation_links = generate_citation_links(search_result_contexts)
238
-
239
- model_response_complete = await model_response
240
- processed_response = process_citations(model_response_complete, search_result_contexts)
241
- related_questions_complete = await related_questions
242
-
243
-
244
- return processed_response, citation_links, related_questions_complete
245
 
246
- # If Llama3.1 8B returns an answer directly, then please query Llama 405B to get the best possible answer
247
- else:
248
- st.session_state.search_performed = False
249
- result = await full_context_answer
250
- related_questions = await related_questions_no_search
251
- return result, "", related_questions
252
-
253
-
254
- def handle_userinput(user_question: Optional[str]) -> None:
255
- """
256
- Handle user input and generate a response, also update chat UI in streamlit app
257
- Args:
258
- user_question (str): The user's question or input.
259
- """
260
- if user_question:
261
- # Clear any existing related question buttons
262
- if 'related_questions' in st.session_state:
263
- st.session_state.related_questions = []
264
-
265
- async def run_search():
266
- return await run_auto_search_pipe(user_question)
267
 
268
- response, citation_links, related_questions = asyncio.run(run_search())
269
- if st.session_state.search_performed:
270
- search_or_not_text = "🔍 Web search was performed for this query."
271
- else:
272
- search_or_not_text = "📚 This response was generated from the model's knowledge."
273
-
274
- st.session_state.chat_history.append(user_question)
275
- st.session_state.chat_history.append((response, citation_links))
276
- st.session_state.chat_history.append(search_or_not_text)
277
-
278
- # Store related questions in session state
279
- st.session_state.related_questions = related_questions
280
-
281
- for ques, ans, search_or_not_text in zip(
282
- st.session_state.chat_history[::3],
283
- st.session_state.chat_history[1::3],
284
- st.session_state.chat_history[2::3],
285
- ):
286
- with st.chat_message('user'):
287
- st.write(f'{ques}')
288
 
289
- with st.chat_message(
290
- 'ai',
291
- avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
292
- ):
293
- st.markdown(f'{ans[0]}', unsafe_allow_html=True)
294
- if ans[1]:
295
- st.markdown("### Sources", unsafe_allow_html=True)
296
- st.markdown(ans[1], unsafe_allow_html=True)
297
- st.info(search_or_not_text)
298
- if len(st.session_state.related_questions) > 0:
299
- st.markdown("### Related Questions")
300
- for question in st.session_state.related_questions:
301
- if st.button(question):
302
- setChatInputValue(question)
303
-
304
- def setChatInputValue(chat_input_value: str) -> None:
305
- js = f"""
306
- <script>
307
- function insertText(dummy_var_to_force_repeat_execution) {{
308
- var chatInput = parent.document.querySelector('textarea[data-testid="stChatInputTextArea"]');
309
- var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
310
- nativeInputValueSetter.call(chatInput, "{chat_input_value}");
311
- var event = new Event('input', {{ bubbles: true}});
312
- chatInput.dispatchEvent(event);
313
- }}
314
- insertText(3);
315
- </script>
316
- """
317
- st.components.v1.html(js)
318
-
319
- def main() -> None:
320
- st.set_page_config(
321
- page_title='Auto Web Search Demo',
322
- page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
323
  )
324
-
325
-
326
- initialize_env_variables(prod_mode, additional_env_vars)
327
-
328
- if 'input_disabled' not in st.session_state:
329
- if 'SAMBANOVA_API_KEY' in st.session_state:
330
- st.session_state.input_disabled = False
331
- else:
332
- st.session_state.input_disabled = True
333
- if 'chat_history' not in st.session_state:
334
- st.session_state.chat_history = []
335
- if 'search_performed' not in st.session_state:
336
- st.session_state.search_performed = False
337
- if 'related_questions' not in st.session_state:
338
- st.session_state.related_questions = []
339
-
340
- st.title(' Auto Web Search')
341
- st.subheader('Powered by :orange[SambaNova Cloud] and Llama405B')
342
-
343
- with st.sidebar:
344
- st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')
345
-
346
- if not are_credentials_set(additional_env_vars):
347
- api_key, additional_vars = env_input_fields(additional_env_vars)
348
- if st.button('Save Credentials'):
349
- message = save_credentials(api_key, additional_vars, prod_mode)
350
- st.session_state.input_disabled = False
351
- st.success(message)
352
- st.rerun()
353
-
354
- else:
355
- st.success('Credentials are set')
356
- if st.button('Clear Credentials'):
357
- save_credentials('', {var: '' for var in (additional_env_vars or [])}, prod_mode)
358
- st.session_state.input_disabled = True
359
- st.rerun()
360
-
361
-
362
- if are_credentials_set(additional_env_vars):
363
- with st.expander('**Example Queries With Search**', expanded=True):
364
- if st.button('What is the population of Virginia?'):
365
- setChatInputValue(
366
- 'What is the population of Virginia?'
367
- )
368
- if st.button('SNP 500 stock market moves'):
369
- setChatInputValue('SNP 500 stock market moves')
370
- if st.button('What is the weather in Palo Alto?'):
371
- setChatInputValue(
372
- 'What is the weather in Palo Alto?'
373
- )
374
- with st.expander('**Example Queries No Search**', expanded=True):
375
- if st.button('write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'):
376
- setChatInputValue(
377
- 'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'
378
- )
379
- if st.button('Write a python program to find the longest root to leaf path in a tree, and some test cases for it.'):
380
- setChatInputValue('Write a python program to find the longest root to leaf path in a tree, and some test cases for it.')
381
-
382
- st.markdown('**Reset chat**')
383
- st.markdown('**Note:** Resetting the chat will clear all interactions history')
384
- if st.button('Reset conversation'):
385
- st.session_state.chat_history = []
386
- st.session_state.sources_history = []
387
- if 'related_questions' in st.session_state:
388
- st.session_state.related_questions = []
389
- st.toast('Interactions reset. The next response will clear the history on the screen')
390
-
391
- # Add a footer with the GitHub citation
392
- footer_html = """
393
- <style>
394
- .footer {
395
- position: fixed;
396
- right: 10px;
397
- bottom: 10px;
398
- width: auto;
399
- background-color: transparent;
400
- color: grey;
401
- text-align: right;
402
- padding: 10px;
403
- font-size: 16px;
404
- }
405
- </style>
406
- <div class="footer">
407
- Inspired by: <a href="https://github.com/leptonai/search_with_lepton" target="_blank">search_with_lepton</a>
408
- </div>
409
- """
410
- st.markdown(footer_html, unsafe_allow_html=True)
411
-
412
- user_question = st.chat_input('Ask something', disabled=st.session_state.input_disabled, key='TheChatInput')
413
- handle_userinput(user_question)
414
-
415
 
 
 
 
 
 
 
 
 
 
416
 
417
- if __name__ == '__main__':
418
- main()
 
 
 
1
  import os
2
+ import gradio as gr
3
+ import tempfile
4
+ from pathlib import Path
5
+ from utils import get_file_name, run_pdf2text, run_paper2slides
 
 
 
 
 
 
6
 
7
+ # Create a temporary directory for outputs
8
+ OUTPUT_FOLDER = tempfile.mkdtemp(prefix="pdf_slides_")
9
 
10
+ def paper2slides(pdf_path: str, logo_path: str='logo.npg'):
11
+ # Ensure the output directory exists
12
+ os.makedirs(OUTPUT_FOLDER, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ output_path = os.path.join(OUTPUT_FOLDER, "converted_slides.pptx")
15
+ file_name = get_file_name(full_path=pdf_path)
16
+ run_pdf2text(paper_pdf_path=pdf_path, save_json_name=file_name + '.json') ## pdf to json file
17
+ run_paper2slides(paper_json_name=file_name + '.json', model='llama3_70b', logo_path=logo_path, save_file_name=output_path) ## json file to slides
18
+ return output_path
 
 
 
 
19
 
20
+ def convert_pdf_to_slides(pdf_path):
21
  """
22
+ Convert a PDF file to a slide deck.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  Args:
25
+ pdf_path (str): Path to the input PDF file
26
+
 
27
  Returns:
28
+ str: Path to the created slide deck file
29
  """
30
+ print(f"Processing PDF: {pdf_path}")
31
+
32
+ # Ensure the output directory exists
33
+ os.makedirs(OUTPUT_FOLDER, exist_ok=True)
34
 
35
+ # Create output filename in the temporary directory
36
+ output_filename = f"{Path(pdf_path).stem}_slides.pptx"
37
+ output_path = os.path.join(OUTPUT_FOLDER, output_filename)
38
 
39
+ output_path = paper2slides(pdf_path, 'logo.png')
40
+
41
+ print(f"Conversion complete. Slides saved to: {output_path}")
42
+ return output_path
43
 
44
+ def process_upload(pdf_file):
45
  """
46
+ Process the uploaded PDF file and return the path to the generated slide deck.
47
 
48
  Args:
49
+ pdf_file (tempfile): The uploaded PDF file
50
+
51
  Returns:
52
+ str: Path to the slide deck for download
53
  """
54
+ if pdf_file is None:
55
+ return None
 
 
 
56
 
57
+ try:
58
+ # Convert the PDF to slides
59
+ slide_path = convert_pdf_to_slides(pdf_file.name)
60
+ return slide_path
61
+ except Exception as e:
62
+ print(f"Error during conversion: {str(e)}")
63
+ return None
64
+
65
+ # Create the Gradio interface
66
+ with gr.Blocks(title="PDF to Slides Converter") as app:
67
+ gr.Markdown("# PDF to Slides Converter")
68
+ gr.Markdown("Upload a PDF file to convert it into a slide deck presentation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ with gr.Row():
71
+ with gr.Column():
72
+ pdf_input = gr.File(label="Upload PDF File", file_types=[".pdf"])
73
+ convert_btn = gr.Button("Convert to Slides", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ with gr.Column():
76
+ output_file = gr.File(label="Download Slides")
77
+ status = gr.Markdown("Upload a PDF and click 'Convert to Slides'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ convert_btn.click(
80
+ fn=process_upload,
81
+ inputs=[pdf_input],
82
+ outputs=[output_file],
83
+ api_name="convert"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ pdf_input.change(
87
+ fn=lambda x: "PDF uploaded. Click 'Convert to Slides' to process." if x else "Upload a PDF and click 'Convert to Slides'",
88
+ inputs=[pdf_input],
89
+ outputs=[status]
90
+ )
91
+
92
+ # Launch the app
93
+ if __name__ == "__main__":
94
+ app.launch()
95
 
96
+ # Optional: Clean up the temporary directory when the app exits
97
+ import shutil
98
+ shutil.rmtree(OUTPUT_FOLDER, ignore_errors=True)