# Whatson/query.py
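"""Whatson: a Sherlock Holmes chatbot.

Gradio app that answers questions about the Holmes stories with
retrieval-augmented generation: chunks are fetched from an AstraDB vector
store, stitched back into readable passages, and fed to GPT-4o through
LangChain. Voice input is transcribed with Whisper; answers can be read
aloud with ElevenLabs.

Configuration comes from the environment (see the lookups below):
PROMPT_TEMPLATE, ASTRA_DB_COLLECTION, ASTRA_DB_APPLICATION_TOKEN,
ASTRA_DB_API_ENDPOINT, and optionally APP_TOKENS; the OpenAI and
ElevenLabs clients read their API keys from their usual environment
variables.
"""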
import os
import time
from json import loads as json_loads

import gradio as gr
from langchain_astradb import AstraDBVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from elevenlabs.client import ElevenLabs
from openai import OpenAI
# The system prompt is kept out of the repo; it is read from the environment.
prompt_template = os.environ.get("PROMPT_TEMPLATE")
prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])

# Set to False to run the UI without any OpenAI/AstraDB calls (offline mode).
AI = True
def ai_setup():
    global llm, prompt_chain, oai_client
    if AI:
        oai_client = OpenAI()
        llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
        embedding = OpenAIEmbeddings()
        vstore = AstraDBVectorStore(
            embedding=embedding,
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = vstore.as_retriever(search_kwargs={'k': 10})
    else:
        # Offline mode: serve previously pickled documents instead of AstraDB.
        retriever = RunnableLambda(just_read)

    # The chain stops at the prompt on purpose: chat() extracts the rendered
    # system message from it and then streams from the LLM itself.
    prompt_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | RunnableLambda(format_context)
        | prompt
        # | llm
        # | StrOutputParser()
    )
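# A minimal sketch of how chat() consumes the chain (not executed here;
# the question is a placeholder):
#   rendered = prompt_chain.invoke("Who keeps geese?")  # retrieval + formatting
#   system_msg = rendered.messages[0]                   # the filled-in system prompt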
def group_and_sort(documents):
    """Group retrieved chunks by story title, then stitch each group's
    fragments back into one passage, merging overlaps and marking gaps."""
    grouped = {}
    for document in documents:
        title = document.metadata["Title"]
        docs = grouped.get(title, [])
        grouped[title] = docs
        docs.append((document.page_content, document.metadata["range"]))

    # Order each story's fragments by their start offset.
    for title, values in grouped.items():
        values.sort(key=lambda doc: doc[1][0])

    # Merge the sorted fragments into a single text per story.
    for title in grouped:
        text = ''
        prev_last = 0
        for fragment, (start, last) in grouped[title]:
            if start < prev_last:
                text += fragment[prev_last - start:]  # overlap: skip the duplicated prefix
            elif start == prev_last:
                text += fragment                      # contiguous: append as-is
            else:
                text += ' [...] '                     # gap between fragments
                text += fragment
            prev_last = last
        grouped[title] = text
    return grouped
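# For example, two chunks of the same story with ranges (0, 100) and (80, 180)
# are stitched into one passage with the 20 overlapping characters dropped,
# while a later chunk at (300, 400) is appended after a ' [...] ' gap marker.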
def format_context(pipeline_state):
    """Replace the list of retrieved Documents in the pipeline state with a
    single formatted context string, grouped per story title."""
    context = ''
    documents = group_and_sort(pipeline_state["context"])
    for title, text in documents.items():
        context += f"\nTitle: {title}\n"
        context += text
        context += '\n\n---\n'
    pipeline_state["context"] = context
    return pipeline_state
def just_read(pipeline_state):
    """Offline stand-in for the retriever: return documents from a local pickle."""
    import pickle
    fname = "docs.pickle"
    return pickle.load(open(fname, "rb"))
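# docs.pickle is assumed to hold a pickled list of LangChain Documents carrying
# the same "Title"/"range" metadata the live retriever returns, e.g. dumped
# once while AI mode was on:
#   pickle.dump(retriever.invoke("any question"), open("docs.pickle", "wb"))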
def new_state():
    return gr.State({
        "user": None,
        "system": None,
    })
def auth(token, state):
    """Resolve the access token to a user name; with no APP_TOKENS configured,
    everyone is let in as "anonymous"."""
    tokens = os.environ.get("APP_TOKENS")
    if not tokens:
        state["user"] = "anonymous"
    else:
        tokens = json_loads(tokens)
        state["user"] = tokens.get(token, None)
    return "", state
AUTH_JS = """function auth_js(token, state) {
    if (!!document.location.hash) {
        token = document.location.hash  // note: includes the leading '#'
        document.location.hash = ""
    }
    return [token, state]
}
"""
def not_authenticated(state):
    answer = (state is None) or (not state['user'])
    if answer:
        gr.Warning("You need to authenticate first")
    return answer
def chat(message, history, state):
    if not_authenticated(state):
        yield "You need to authenticate first"
    elif AI:
        if not history:
            # First turn: run retrieval once and cache the rendered system
            # prompt, so follow-up turns reuse the same context.
            system_prompt = prompt_chain.invoke(message)
            system_prompt = system_prompt.messages[0]
            state["system"] = system_prompt
        else:
            system_prompt = state["system"]

        # Rebuild the full conversation from Gradio's (human, ai) history pairs.
        messages = [system_prompt]
        for human, ai in history:
            messages.append(HumanMessage(human))
            messages.append(AIMessage(ai))
        messages.append(HumanMessage(message))

        # Stream the answer, yielding the accumulated text after each chunk.
        answer = ''
        for response in llm.stream(messages):
            answer += response.content
            yield answer
    else:
        yield f"{time.ctime()}: You said: {message}"
def on_audio(path, state):
    if not_authenticated(state):
        return (gr.update(), None)
    if not path:
        return (gr.update(), None)
    if AI:
        # Transcribe the recording with Whisper; response_format="text"
        # makes the API return the transcript as a plain string.
        with open(path, "rb") as audio_file:
            text = oai_client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text",
            )
    else:
        text = f"{time.ctime()}: You said something"
    # Put the transcript in the chat textbox and clear the recorder.
    return (text, None)
def play_last(history, state):
    """Speak the chatbot's last answer through ElevenLabs, streaming the audio."""
    if not_authenticated(state):
        return
    if len(history):
        # Whatson's ElevenLabs voice; stream the generated audio straight to the player.
        voice_id = "IINmogebEQykLiDoSkd0"
        text = history[-1][1]
        lab11 = ElevenLabs()
        whatson = lab11.voices.get(voice_id)
        response = lab11.generate(text=text, voice=whatson, stream=True)
        yield from response
def gr_main():
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6",          # ChatBot.svelte / .message-row.panel.user-row
        background_fill_secondary="#6272a4",  # ChatBot.svelte / .message-row.panel.bot-row
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")

    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme
    ) as app:
        state = new_state()
        chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
        iface = gr.ChatInterface(
            chat,
            chatbot=chatbot,
            title="Sherlock Holmes stories",
            submit_btn=gr.Button(
                "Submit",
                variant="primary",
                scale=1,
                min_width=150,
                elem_id="submit_btn",
                render=False
            ),
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            # examples=[
            #     ["I arrived late last night and found a dead goose in my bed"],
            #     ["Help please sir. I'm about to get married, to the most lovely lady, "
            #      "and I just received a letter threatening me to make public some things "
            #      "of my past I'd rather keep quiet, unless I don't marry"],
            # ],
            additional_inputs=[state])

        with gr.Row():
            # Microphone input: when a recording lands, transcribe it into the
            # chat textbox, then auto-click Submit from JS.
            mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                show_label=False,
                format="mp3",
                waveform_options=gr.WaveformOptions(sample_rate=16000))
            mic.change(
                on_audio, [mic, state], [iface.textbox, mic]
            ).then(
                lambda x: None,
                js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}',
                inputs=iface.textbox
            )
            # Output player for the spoken version of the last answer.
            player = gr.Audio(
                show_label=False,
                show_download_button=True,
                visible=True,
                autoplay=True,
                streaming=True)
            play_btn = gr.Button("Play last")
            play_btn.click(play_last, [chatbot, state], player)

        # Hidden textbox that receives the token from the URL fragment on load.
        token = gr.Textbox(visible=False)
        app.load(auth,
                 [token, state],
                 [token, state],
                 js=AUTH_JS)

    app.launch(show_api=False)
if __name__ == "__main__":
    ai_setup()
    gr_main()
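# Example local run (all values are placeholders; the template presumably
# references the {context} and {question} variables the chain supplies):
#   export PROMPT_TEMPLATE='You are Dr. Watson... Context: {context} Question: {question}'
#   export OPENAI_API_KEY=sk-...
#   python query.py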