|
import logging |
|
import os |
|
import sys |
|
from contextlib import contextmanager, redirect_stdout |
|
from io import StringIO |
|
from typing import Callable, Generator, Optional, List, Dict |
|
import requests |
|
import json |
|
from consts import AUTO_SEARCH_KEYWORD, SEARCH_TOOL_INSTRUCTION, RELATED_QUESTIONS_TEMPLATE_SEARCH, SEARCH_TOOL_INSTRUCTION, RAG_TEMPLATE, GOOGLE_SEARCH_ENDPOINT, DEFAULT_SEARCH_ENGINE_TIMEOUT, RELATED_QUESTIONS_TEMPLATE_NO_SEARCH |
|
import re |
|
import asyncio |
|
import random |
|
|
|
import streamlit as st |
|
import yaml |
|
|
|
# Directory containing this script, its parent directory (kit_dir) and the
# directory one level above that (repo_dir).
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))

# Make both directories importable. The membership check keeps sys.path from
# accumulating duplicate entries when this module is executed more than once
# in the same process (e.g. on a reload or a script re-run).
for _extra_path in (kit_dir, repo_dir):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
|
|
|
|
|
from visual_env_utils import are_credentials_set, env_input_fields, initialize_env_variables, save_credentials |
|
|
|
# Route INFO-level (and more severe) records through the root logger.
logging.basicConfig(level=logging.INFO)

# Credentials read from Streamlit's secrets store. The Google values are sent
# as the "key" and "cx" parameters of the search request.
GOOGLE_API_KEY = st.secrets["google_api_key"]
GOOGLE_CX = st.secrets["google_cx"]

# Spare API keys (secrets backup_key_1 .. backup_key_5); one is chosen at
# random when an inference request is retried.
BACKUP_KEYS = [st.secrets[f"backup_key_{slot}"] for slot in range(1, 6)]

# YAML settings file that sits next to this script.
CONFIG_PATH = os.path.join(current_dir, "config.yaml")

# Browser-style User-Agent header values.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
]
|
|
|
def load_config() -> dict:
    """Load the application settings from the YAML file at ``CONFIG_PATH``.

    Returns:
        dict: The parsed YAML content. An empty dict is returned when the
        file contains no document, so callers can always use ``.get``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding so parsing does not depend on the platform default.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
        # safe_load returns None for an empty file; the module calls
        # ``config.get(...)`` right after loading, which would fail on None.
        return yaml.safe_load(yaml_file) or {}
|
|
|
|
|
# Parse the settings file once at import time. ``config`` itself is also read
# by the inference helper (keys "model" and "url").
config = load_config()

# Values handed to the credential helpers in main().
prod_mode = config.get('prod_mode', False)
additional_env_vars = config.get('additional_env_vars')
|
|
|
@contextmanager
def st_capture(output_func: Callable[[str], None]) -> Generator:
    """Redirect ``stdout`` into a buffer and mirror it to ``output_func``.

    While the context is active, everything written to ``stdout`` is stored
    in an in-memory buffer. After each write, ``output_func`` is called with
    the complete text accumulated so far (not only the newest chunk).

    Args:
        output_func: Callable receiving the accumulated output text, e.g.
            the writer of a Streamlit element.

    Yields:
        None: Control is handed to the body of the ``with`` statement.
    """
    with StringIO() as buffer:
        with redirect_stdout(buffer):
            raw_write = buffer.write

            def mirrored_write(chunk: str) -> int:
                # Store the chunk first, then publish the whole buffer.
                written = raw_write(chunk)
                output_func(buffer.getvalue())
                return written

            # Shadow the buffer's write method on this instance only.
            buffer.write = mirrored_write
            yield
|
|
|
async def run_samba_api_inference(query, system_prompt = None, ignore_context=False, max_tokens_to_generate=None, num_seconds_to_sleep=1, over_ride_key=None):
    """Send a chat request to the configured endpoint and return the reply text.

    The request body holds the optional system prompt, the earlier
    question/answer pairs from ``st.session_state.chat_history`` (unless
    ``ignore_context`` is set) and finally ``query`` as the last user message.

    Args:
        query: The user message to answer.
        system_prompt: Optional system message placed first.
        ignore_context: When true, earlier chat turns are not included.
        max_tokens_to_generate: Optional value for the ``max_tokens`` field.
        num_seconds_to_sleep: Pause before retrying after a 429/504 response.
        over_ride_key: API key to use instead of
            ``st.session_state.SAMBANOVA_API_KEY``.

    Returns:
        str: The content of the first choice in the response, or an error
        message string when the request fails with an HTTP error status.

    Raises:
        requests.exceptions.RequestException: Failures other than an HTTP
            error status (e.g. connection problems) are not handled here.
    """
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    if not ignore_context:
        history = st.session_state.chat_history
        # The history is a flat list of triples: question, answer, search note.
        for ques, ans in zip(history[::3], history[1::3]):
            # Answers are stored as (response, citation_links) tuples by
            # handle_userinput; only the response text belongs in the prompt.
            if isinstance(ans, (tuple, list)):
                ans = ans[0]
            messages.append({"role": "user", "content": ques})
            messages.append({"role": "assistant", "content": ans})
    messages.append({"role": "user", "content": query})

    payload = {
        "messages": messages,
        "model": config.get("model"),
    }
    if max_tokens_to_generate is not None:
        payload["max_tokens"] = max_tokens_to_generate

    api_key = st.session_state.SAMBANOVA_API_KEY if over_ride_key is None else over_ride_key
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        # requests is blocking, so the call runs in the default executor to
        # keep the event loop free for the other tasks.
        post_response = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: requests.post(config.get("url"), json=payload, headers=headers, stream=True),
        )
        post_response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        if post_response.status_code in {401, 503}:
            st.info("Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/.")
            return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."

        if post_response.status_code in {429, 504}:
            await asyncio.sleep(num_seconds_to_sleep)
            # Retry the *same* request with a backup key: the system prompt,
            # context flag and token limit must be forwarded, otherwise the
            # retry would answer a different prompt than the one requested.
            # NOTE: the number of retries is not bounded.
            return await run_samba_api_inference(
                query,
                system_prompt,
                ignore_context,
                max_tokens_to_generate,
                num_seconds_to_sleep,
                over_ride_key=random.choice(BACKUP_KEYS),
            )

        # Any other HTTP error status is reported with the same message.
        print(f"Request failed with status code: {post_response.status_code}. Error: {e}")
        return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."

    response_data = json.loads(post_response.text)
    return response_data["choices"][0]["message"]["content"]
|
|
|
def extract_query(text):
    """Return the value of the first ``query="..."`` fragment in ``text``.

    The value is everything between the opening quote and the next double
    quote on the same line.

    Args:
        text: The string to scan.

    Returns:
        The quoted value (possibly empty), or ``None`` if no fragment exists.
    """
    found = re.search(r'query="(.*?)"', text)
    return found.group(1) if found else None
|
|
|
def extract_text_between_brackets(text):
    """Collect every ``[...]`` fragment of ``text`` without its brackets.

    Each fragment runs from a ``[`` to the nearest following ``]`` on the
    same line.

    Args:
        text: The string to scan.

    Returns:
        list: The inner texts in order of appearance; empty if none match.
    """
    return [hit.group(1) for hit in re.finditer(r'\[(.*?)\]', text)]
|
|
|
def search_with_google(query: str):
    """Run a Google search for ``query`` and return the result entries.

    Args:
        query: The search terms.

    Returns:
        list: At most five entries taken from the ``items`` field of the
        JSON response; an empty list when the response has no ``items``.

    Raises:
        Exception: With ``(status_code, "Search engine error.")`` when the
            HTTP response is not OK.
    """
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query,
        "num": 5,
    }
    response = requests.get(
        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
    )

    if not response.ok:
        raise Exception(response.status_code, "Search engine error.")

    # A search without hits has no "items" key in its payload; treat that as
    # "no results" instead of failing with a KeyError.
    return response.json().get("items", [])[:5]
|
|
|
async def get_related_questions(query, contexts = None):
    """Ask the model for follow-up questions related to ``query``.

    Args:
        query: The user's question.
        contexts: Optional search results (dicts with a ``snippet`` entry).
            When given, their snippets are inserted into the search variant
            of the prompt; otherwise the no-search prompt is used.

    Returns:
        list: The questions parsed from the model reply, or an empty list
        when the reply contains no parseable JSON array.
    """
    if contexts:
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
            context="\n\n".join(c.get("snippet", "") for c in contexts)
        )
    else:
        # Without search results there is nothing to substitute into the
        # search template, so the dedicated no-search prompt is used.
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_NO_SEARCH

    related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)
    if not isinstance(related_questions_raw, str):
        return []

    # Preferred case: the whole reply is a JSON array.
    try:
        parsed = json.loads(related_questions_raw)
    except ValueError:
        parsed = None
    if isinstance(parsed, list):
        return parsed

    # Fallback: the array is embedded in surrounding text. The helper returns
    # the text *inside* each pair of brackets, so the brackets are restored
    # before parsing each candidate.
    for candidate in extract_text_between_brackets(related_questions_raw):
        try:
            parsed = json.loads(f"[{candidate}]")
        except ValueError:
            continue
        if parsed:
            return parsed
    return []
|
|
|
def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
    """Turn ``[citation:N]`` markers into superscript ``[N]`` references.

    The number inside each marker is kept as-is. The sources list is numbered
    by search-result position, so renumbering markers by order of appearance
    would make a reference point at the wrong source.

    Args:
        response (str): The model response containing citation markers.
        search_result_contexts (List[Dict]): The search results the response
            was generated from (currently not consulted).

    Returns:
        str: The response with every marker replaced by ``<sup>[N]</sup>``.
    """
    return re.sub(r'\[citation:(\d+)\]', r'<sup>[\1]</sup>', response)
|
|
|
def generate_citation_links(search_result_contexts: List[Dict]) -> str:
    """Build the HTML list of numbered source links.

    Args:
        search_result_contexts (List[Dict]): Search results; each entry may
            provide ``title`` and ``link``. Missing values fall back to
            ``'No title'`` and ``'#'``.

    Returns:
        str: One ``<p>[i] <a ...>title</a></p>`` element per result,
        numbered from 1 in result order, concatenated without separators.
    """
    from html import escape

    citation_links = []
    for i, context in enumerate(search_result_contexts, 1):
        # Titles and URLs come from web search results and the output is
        # rendered as raw HTML, so both are escaped before interpolation.
        title = escape(str(context.get('title', 'No title')))
        link = escape(str(context.get('link', '#')), quote=True)
        citation_links.append(f'<p>[{i}] <a href="{link}" target="_blank">{title}</a></p>')

    return ''.join(citation_links)
|
|
|
|
|
async def run_auto_search_pipe(query):
    """Answer ``query``, using a web search when the model asks for one.

    A plain answer and plain related questions are started right away. In
    parallel the model is asked (with ``SEARCH_TOOL_INSTRUCTION``) whether a
    search is needed. If its reply contains ``AUTO_SEARCH_KEYWORD`` the
    search is run and the answer is generated from the search snippets;
    otherwise the already-started plain results are used.

    Also sets ``st.session_state.search_performed``.

    Args:
        query: The user's question.

    Returns:
        tuple: ``(answer, citation_links_html, related_questions)``. The
        links element is an empty string when no search was performed.
    """
    # Started speculatively so the no-search path does not have to wait for
    # the search decision first.
    full_context_answer = asyncio.create_task(run_samba_api_inference(query))
    related_questions_no_search = asyncio.create_task(get_related_questions(query))

    with st.spinner('Checking if web search is needed...'):
        auto_search_result = await run_samba_api_inference(query, SEARCH_TOOL_INSTRUCTION, True, max_tokens_to_generate=100)

    if AUTO_SEARCH_KEYWORD not in auto_search_result:
        st.session_state.search_performed = False
        result = await full_context_answer
        related_questions = await related_questions_no_search
        return result, "", related_questions

    st.session_state.search_performed = True

    # The speculative results are not used on this path; stop waiting on them.
    full_context_answer.cancel()
    related_questions_no_search.cancel()

    with st.spinner('Searching the internet...'):
        # extract_query returns None when the reply has no query="..." part;
        # fall back to the user's own wording in that case.
        search_query = extract_query(auto_search_result) or query
        search_result_contexts = search_with_google(search_query)

    with st.spinner('Generating response based on web search...'):
        rag_system_prompt = RAG_TEMPLATE.format(
            context="\n\n".join(
                f"[[citation:{i+1}]] {c.get('snippet', '')}"
                for i, c in enumerate(search_result_contexts)
            )
        )

        # Answer and related questions are requested concurrently.
        model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
        related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))

        citation_links = generate_citation_links(search_result_contexts)

        model_response_complete = await model_response
        processed_response = process_citations(model_response_complete, search_result_contexts)
        related_questions_complete = await related_questions

    return processed_response, citation_links, related_questions_complete
|
|
|
|
|
def handle_userinput(user_question: Optional[str]) -> None:
    """Process a new question (if any) and redraw the conversation.

    When ``user_question`` is non-empty the search pipeline is run and three
    entries are appended to ``st.session_state.chat_history``: the question,
    a ``(response, citation_links)`` tuple and a note saying whether a web
    search was performed. The whole history and the related-question buttons
    are then rendered, whether or not a new question was given.

    Args:
        user_question (str): The user's question or input; may be ``None``.
    """
    if user_question:
        if 'related_questions' in st.session_state:
            st.session_state.related_questions = []

        response, citation_links, related_questions = asyncio.run(run_auto_search_pipe(user_question))
        if st.session_state.search_performed:
            search_or_not_text = "🔍 Web search was performed for this query."
        else:
            search_or_not_text = "📚 This response was generated from the model's knowledge."

        st.session_state.chat_history.append(user_question)
        st.session_state.chat_history.append((response, citation_links))
        st.session_state.chat_history.append(search_or_not_text)

        st.session_state.related_questions = related_questions

    # The history is a flat list of (question, answer, note) triples.
    history = st.session_state.chat_history
    for ques, ans, search_or_not_text in zip(history[::3], history[1::3], history[2::3]):
        with st.chat_message('user'):
            st.write(f'{ques}')

        with st.chat_message(
            'ai',
            avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
        ):
            # ans is (response_html, citation_links_html).
            st.markdown(f'{ans[0]}', unsafe_allow_html=True)
            if ans[1]:
                st.markdown("### Sources", unsafe_allow_html=True)
                st.markdown(ans[1], unsafe_allow_html=True)
            st.info(search_or_not_text)

    if st.session_state.related_questions:
        st.markdown("### Related Questions")
        for index, question in enumerate(st.session_state.related_questions):
            # An explicit key keeps the buttons distinct even when two
            # questions have the same text or match another button's label.
            if st.button(question, key=f"related_question_{index}"):
                setChatInputValue(question)
|
|
|
def setChatInputValue(chat_input_value: str) -> None:
    """Put ``chat_input_value`` into the chat input box of the page.

    Renders an HTML component whose script locates the chat text area in the
    parent document, sets its value through the native setter and dispatches
    an ``input`` event.

    Args:
        chat_input_value: The text to place in the chat input.
    """
    # Encode the text as a JavaScript string literal: json.dumps escapes
    # quotes, backslashes and line breaks, and "<" is written as \u003c so
    # the text cannot close the surrounding <script> element.
    js_literal = json.dumps(chat_input_value).replace("<", "\\u003c")
    js = f"""
    <script>
        function insertText(dummy_var_to_force_repeat_execution) {{
            var chatInput = parent.document.querySelector('textarea[data-testid="stChatInputTextArea"]');
            var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
            nativeInputValueSetter.call(chatInput, {js_literal});
            var event = new Event('input', {{ bubbles: true}});
            chatInput.dispatchEvent(event);
        }}
        insertText(3);
    </script>
    """
    st.components.v1.html(js)
|
|
|
def main() -> None:
    """Render the Auto Web Search page and handle one chat interaction.

    Configures the page, seeds the session state, draws the sidebar
    (credentials, example queries, chat reset), adds the footer and passes
    the chat input to ``handle_userinput``.
    """
    st.set_page_config(
        page_title='Auto Web Search Demo',
        page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
    )

    initialize_env_variables(prod_mode, additional_env_vars)

    # Seed session values that do not exist yet.
    state = st.session_state
    if 'input_disabled' not in state:
        # The chat input starts disabled until an API key is in the session.
        state.input_disabled = 'SAMBANOVA_API_KEY' not in state
    if 'chat_history' not in state:
        state.chat_history = []
    if 'search_performed' not in state:
        state.search_performed = False
    if 'related_questions' not in state:
        state.related_questions = []

    st.title(' Auto Web Search')
    st.subheader('Powered by :orange[SambaNova Cloud] and Llama405B')

    # Example prompts, grouped by the expander they are shown in.
    example_groups = (
        (
            '**Example Queries With Search**',
            (
                'What is the population of Virginia?',
                'SNP 500 stock market moves',
                'What is the weather in Palo Alto?',
            ),
        ),
        (
            '**Example Queries No Search**',
            (
                'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.',
                'Write a python program to find the longest root to leaf path in a tree, and some test cases for it.',
            ),
        ),
    )

    with st.sidebar:
        st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')

        if are_credentials_set(additional_env_vars):
            st.success('Credentials are set')
            if st.button('Clear Credentials'):
                blank_vars = {var: '' for var in (additional_env_vars or [])}
                save_credentials('', blank_vars, prod_mode)
                state.input_disabled = True
                st.rerun()
        else:
            api_key, additional_vars = env_input_fields(additional_env_vars)
            if st.button('Save Credentials'):
                message = save_credentials(api_key, additional_vars, prod_mode)
                state.input_disabled = False
                st.success(message)
                st.rerun()

        if are_credentials_set(additional_env_vars):
            # Clicking an example copies its text into the chat input.
            for expander_title, examples in example_groups:
                with st.expander(expander_title, expanded=True):
                    for example in examples:
                        if st.button(example):
                            setChatInputValue(example)

            st.markdown('**Reset chat**')
            st.markdown('**Note:** Resetting the chat will clear all interactions history')
            if st.button('Reset conversation'):
                state.chat_history = []
                state.sources_history = []
                if 'related_questions' in state:
                    state.related_questions = []
                st.toast('Interactions reset. The next response will clear the history on the screen')

    # Attribution footer pinned to the bottom-right corner of the page.
    footer_html = """
    <style>
    .footer {
        position: fixed;
        right: 10px;
        bottom: 10px;
        width: auto;
        background-color: transparent;
        color: grey;
        text-align: right;
        padding: 10px;
        font-size: 16px;
    }
    </style>
    <div class="footer">
        Inspired by: <a href="https://github.com/leptonai/search_with_lepton" target="_blank">search_with_lepton</a>
    </div>
    """
    st.markdown(footer_html, unsafe_allow_html=True)

    user_question = st.chat_input('Ask something', disabled=state.input_disabled, key='TheChatInput')
    handle_userinput(user_question)
|
|
|
|
|
|
|
# Start the app when this file is run as a script (not when it is imported).
if __name__ == "__main__":
    main()