initial
app.py
ADDED
@@ -0,0 +1,71 @@
from omegaconf import OmegaConf
from query import VectaraQuery
import streamlit as st
import os

def isTrue(x) -> bool:
    # Interpret a boolean, or an env-var style string ("true"/"True"), as a bool.
    if isinstance(x, bool):
        return x
    return x.strip().lower() == 'true'
def launch_bot():
    def generate_response(question, role, topic):
        response = vq.submit_query(question, role, topic)
        return response

    if 'cfg' not in st.session_state:
        cfg = OmegaConf.create({
            'customer_id': str(os.environ['VECTARA_CUSTOMER_ID']),
            'corpus_id': str(os.environ['VECTARA_CORPUS_ID']),
            'api_key': str(os.environ['VECTARA_API_KEY']),
            'prompt_name': 'vectara-experimental-summary-ext-2023-12-11-large',
            'topic': 'standardized testing in education',
            'human_role': 'in opposition to',
            'bot_role': 'in support of'
        })
        st.session_state.cfg = cfg
        st.session_state.vq = VectaraQuery(cfg.api_key, cfg.customer_id, cfg.corpus_id, cfg.prompt_name)

    cfg = st.session_state.cfg
    vq = st.session_state.vq
    st.set_page_config(page_title="Debate Bot", layout="wide")

    # left side content
    with st.sidebar:
        st.markdown(f"## Welcome to Debate Bot.\n\n\n"
                    f"You are {cfg.human_role} '{cfg.topic}'.\n\n")

        st.markdown("---")
        st.markdown(
            "## How does this work?\n"
            "This app was built with [Vectara](https://vectara.com).\n"
        )
        st.markdown("---")

    if "messages" not in st.session_state.keys():
        st.session_state.messages = [{"role": "assistant", "content": f"Please make your first statement {cfg.human_role} '{cfg.topic}'"}]

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # User-provided prompt
    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

    # Generate a new response if last message is not from assistant
    if st.session_state.messages[-1]["role"] != "assistant":
        with st.chat_message("assistant"):
            stream = generate_response(prompt, cfg.bot_role, cfg.topic)
            response = st.write_stream(stream)
            message = {"role": "assistant", "content": response}
            st.session_state.messages.append(message)

if __name__ == "__main__":
    launch_bot()
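Note: running this app outside the Space requires the three environment variables read above (VECTARA_CUSTOMER_ID, VECTARA_CORPUS_ID, VECTARA_API_KEY) to be set; with those exported, the chat UI launches with "streamlit run app.py".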
query.py
ADDED
@@ -0,0 +1,151 @@
import requests
import json
import re
from urllib.parse import quote

def extract_between_tags(text, start_tag, end_tag):
    start_index = text.find(start_tag)
    end_index = text.find(end_tag, start_index)
    # end_index points at the start of end_tag, so slice up to it directly;
    # subtracting len(end_tag) here would truncate the end of the snippet.
    return text[start_index+len(start_tag):end_index]
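For illustration, a minimal sketch (hypothetical input) of the helper above, using the same %START_SNIPPET%/%END_SNIPPET% markers that get_body configures below:

text = "before %START_SNIPPET%the matched sentence%END_SNIPPET% after"
assert extract_between_tags(text, "%START_SNIPPET%", "%END_SNIPPET%") == "the matched sentence"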
class VectaraQuery():
    def __init__(self, api_key: str, customer_id: str, corpus_id: str, prompt_name: str = None):
        self.customer_id = customer_id
        self.corpus_id = corpus_id
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-large"
        self.conv_id = None

    def get_body(self, user_response: str, role: str, topic: str):
        corpora_key_list = [{
            'customer_id': self.customer_id, 'corpus_id': self.corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
        }]

        prompt = f'''
        [
            {{
                "role": "system",
                "content": "You are a professional debate bot. You are provided with search results related to {topic}
                and respond to the previous arguments made so far. Be sure to provide a thoughtful and convincing reply.
                Never mention search results explicitly in your response.
                Do not base your response on information or knowledge that is not in the search results.
                Respond while demonstrating respect to the other party and the topic. Limit your responses to not more than 3 paragraphs."
            }},
            {{
                "role": "user",
                "content": "
                #foreach ($qResult in $vectaraQueryResults)
                Search result $esc.java(${{foreach.index}}+1): $esc.java(${{qResult.getText()}})
                #end
                "
            }},
            {{
                "role": "user",
                "content": "provide a convincing reply {role} {topic}.
                Consider the search results as relevant information with which to form your response.
                Do not repeat earlier arguments and make sure your new response is coherent with the previous arguments, and responsive to the last argument: {user_response}."
            }}
        ]
        '''

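        # Note: the #foreach / $esc.java constructs in the prompt above are not
        # Python; they are Vectara's server-side prompt-template syntax
        # (Velocity-style), expanded against the retrieved results in
        # $vectaraQueryResults before the summarizer runs.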
        return {
            'query': [
                {
                    'query': "how would you respond?",
                    'start': 0,
                    'numResults': 50,
                    'corpusKey': corpora_key_list,
                    'context_config': {
                        'sentences_before': 2,
                        'sentences_after': 2,
                        'start_tag': "%START_SNIPPET%",
                        'end_tag': "%END_SNIPPET%",
                    },
                    'rerankingConfig':
                    {
                        'rerankerId': 272725718,
                        'mmrConfig': {
                            'diversityBias': 0.3
                        }
                    },
                    'summary': [
                        {
                            'responseLang': 'eng',
                            'maxSummarizedResults': 7,
                            'summarizerPromptName': self.prompt_name,
                            'promptText': prompt,
                            'chat': {
                                'store': True,
                                'conversationId': self.conv_id
                            },
                        }
                    ]
                }
            ]
        }

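    # The request body above configures hybrid search with a small lexical
    # weight (lambda = 0.025), two sentences of context around each matching
    # snippet, MMR reranking (diversityBias = 0.3) for result diversity, and a
    # chat-enabled summarizer driven by the custom prompt.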
    def get_headers(self):
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "customer-id": self.customer_id,
            "x-api-key": self.api_key,
            "grpc-timeout": "60S"
        }

    def submit_query(self, query_str: str, role: str, topic: str):

        endpoint = "https://api.vectara.io/v1/stream-query"
        body = self.get_body(query_str, role, topic)

        response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True)
        if response.status_code != 200:
            print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
            # This method is a generator, so the error message must be yielded;
            # a plain return value would never reach st.write_stream.
            yield "Sorry, something went wrong in my brain. Please try again later."
            return

        chunks = []
        accumulated_text = ""  # Initialize text accumulation
        pattern_max_length = 50  # hold back this many chars so a citation marker split across chunks can still be scrubbed
        for line in response.iter_lines():
            if line:  # filter out keep-alive new lines
                data = json.loads(line.decode('utf-8'))
                res = data['result']
                response_set = res['responseSet']
                if response_set is None:
                    # grab next summary chunk and yield it as output
                    summary = res.get('summary', None)
                    if summary is None or len(summary) == 0:
                        continue
                    else:
                        chat = summary.get('chat', None)
                        if chat and chat.get('status', None):
                            st_code = chat['status']
                            print(f"Chat query failed with code {st_code}")
                            if st_code == 'RESOURCE_EXHAUSTED':
                                self.conv_id = None
                                yield 'Sorry, the number of Vectara chat turns exceeds the plan limit.'
                                return
                            yield 'Sorry, something went wrong in my brain. Please try again later.'
                            return
                        conv_id = chat.get('conversationId', None) if chat else None
                        if conv_id:
                            self.conv_id = conv_id

                    chunk = summary['text']
                    accumulated_text += chunk  # Append current chunk to accumulation
                    if len(accumulated_text) > pattern_max_length:
                        # scrub citation markers like [3], then fix any space left before a period
                        accumulated_text = re.sub(r"\[\d+\]", "", accumulated_text)
                        accumulated_text = re.sub(r"\s+\.", ".", accumulated_text)
                        out_chunk = accumulated_text[:-pattern_max_length]
                        chunks.append(out_chunk)
                        yield out_chunk
                        accumulated_text = accumulated_text[-pattern_max_length:]

                    if summary['done']:
                        break

        # yield the last piece
        if len(accumulated_text) > 0:
            accumulated_text = re.sub(r" \[\d+\]\.", ".", accumulated_text)
            chunks.append(accumulated_text)
            yield accumulated_text

        return ''.join(chunks)
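As a usage note, a minimal sketch (assuming the same environment variables app.py reads) of driving VectaraQuery directly, without the Streamlit UI:

import os
from query import VectaraQuery

vq = VectaraQuery(
    api_key=os.environ["VECTARA_API_KEY"],
    customer_id=os.environ["VECTARA_CUSTOMER_ID"],
    corpus_id=os.environ["VECTARA_CORPUS_ID"],
)
# submit_query is a generator; consume it chunk by chunk, as st.write_stream does.
for chunk in vq.submit_query("Opening statement, please.", "in support of",
                             "standardized testing in education"):
    print(chunk, end="", flush=True)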