Whatson / query.py
gerasdf
enabled-disable play last while generating
c038d5b
raw
history blame
11.7 kB
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from elevenlabs import VoiceSettings
from elevenlabs.client import ElevenLabs
from openai import OpenAI
from json import loads as json_loads
import time
import os
# Text of the system prompt, read from the PROMPT_TEMPLATE environment variable
# (os.environ.get returns None when the variable is unset).
# NOTE(review): presumably the template references the "context" and "question"
# keys produced by the chain built in ai_setup() — confirm against the deployed value.
prompt_template = os.environ.get("PROMPT_TEMPLATE")
# Chat prompt made of a single system message built from that template.
prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
# Switch read by ai_setup(), chat() and on_audio(): True uses the OpenAI / Astra DB
# backends, False uses the local stand-ins coded in those functions.
AI = True
def ai_setup():
    """Initialise the module-level ``prompt_chain`` (and, when ``AI`` is on, the clients).

    With ``AI`` enabled this creates the OpenAI client (``oai_client``), the chat
    model (``llm``) and an Astra DB vector-store retriever returning 10 documents.
    With ``AI`` disabled only a retriever wrapping ``just_read`` is created, and
    ``llm`` / ``oai_client`` are left unassigned.

    In both cases ``prompt_chain`` is built as: retrieve context for the question,
    format it with ``format_context``, then render ``prompt``. The chain stops at
    the rendered prompt; the model is invoked separately by the caller.
    """
    global llm, prompt_chain, oai_client

    if not AI:
        retriever = RunnableLambda(just_read)
    else:
        oai_client = OpenAI()
        llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
        vector_store = AstraDBVectorStore(
            embedding=OpenAIEmbeddings(),
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = vector_store.as_retriever(search_kwargs={'k': 10})

    # The incoming question is sent to the retriever and also passed through unchanged.
    chain_inputs = {"context": retriever, "question": RunnablePassthrough()}
    prompt_chain = chain_inputs | RunnableLambda(format_context) | prompt
def group_and_sort(documents):
    """Merge document fragments into one text per title.

    Args:
        documents: iterable of objects exposing ``page_content`` (the fragment
            text) and ``metadata`` with a ``"Title"`` entry and a ``"range"``
            entry holding a ``(start, last)`` pair. The offsets are used as
            string positions: overlap is removed by slicing the fragment text,
            and a fragment starting exactly at the previous ``last`` is treated
            as contiguous.

    Returns:
        dict mapping each title (in first-seen order) to its merged text.
        Fragments are ordered by ``start``; overlapping parts are emitted once
        and every gap (including one before the first fragment when it does not
        start at 0) is marked with ``' [...] '``.
    """
    fragments_by_title = {}
    for document in documents:
        entry = (document.page_content, document.metadata["range"])
        fragments_by_title.setdefault(document.metadata["Title"], []).append(entry)

    merged = {}
    for title, fragments in fragments_by_title.items():
        fragments.sort(key=lambda item: item[1][0])
        pieces = []
        covered_up_to = 0
        for fragment, (start, last) in fragments:
            if start > covered_up_to:
                # Nothing covers the span between the previous text and this fragment.
                pieces.append(' [...] ')
                pieces.append(fragment)
            else:
                # Drop the prefix that was already emitted (empty slice when the
                # fragment lies entirely inside the covered span).
                pieces.append(fragment[covered_up_to - start:])
            # Never move the covered offset backwards: a fragment contained in an
            # earlier one must not make later fragments look like gaps/duplicates.
            covered_up_to = max(covered_up_to, last)
        merged[title] = ''.join(pieces)
    return merged
def format_context(pipeline_state):
    """Turn the retrieved documents under ``"context"`` into one prompt-ready string.

    The documents are merged per title with ``group_and_sort``; each title then
    becomes a ``Title:`` header followed by its text and a ``---`` separator.
    ``pipeline_state`` is updated in place and also returned so the function can
    sit in the middle of a runnable chain.
    """
    merged_by_title = group_and_sort(pipeline_state["context"])
    sections = [
        f"\nTitle: {title}\n" + text + '\n\n---\n'
        for title, text in merged_by_title.items()
    ]
    pipeline_state["context"] = ''.join(sections)
    return pipeline_state
def just_read(pipeline_state):
    """Return the object stored in ``docs.pickle`` in the current working directory.

    Used as the retriever stand-in when ``AI`` is disabled.

    Args:
        pipeline_state: value handed over by the chain; it is ignored.

    Returns:
        Whatever object was pickled into ``docs.pickle``.

    Raises:
        OSError: if the file cannot be opened.
    """
    import pickle

    fname = "docs.pickle"
    # NOTE(review): unpickling can execute arbitrary code — this file must only
    # ever come from a trusted source.
    with open(fname, "rb") as docs_file:  # close the handle even if loading fails
        return pickle.load(docs_file)
def new_state():
    """Create the session state component with ``user``, ``system`` and ``history`` unset."""
    initial_values = dict.fromkeys(("user", "system", "history"))
    return gr.State(initial_values)
def session_id(state: dict, request: gr.Request) -> str:
    """Build the session key ``<user>_<session hash>`` for the current request."""
    user = state["user"]
    session_hash = request.session_hash
    return f'{user}_{session_hash}'
# Session store, created on first successful authentication and then reused.
store = None


def auth(token, state, request: gr.Request):
    """Resolve the user for ``token`` and record the session.

    When the ``APP_TOKENS`` environment variable is unset or empty everybody is
    ``"anonymous"``. Otherwise it is parsed as a JSON object mapping tokens to
    user names, and an unknown token leaves ``state["user"]`` as ``None``.

    For an identified user a record with the user, session hash and a UTC
    timestamp is written to the ``<collection>_sessions`` Astra DB store.

    Returns:
        ``("", state)`` — an empty string for the token field plus the updated state.
    """
    global store

    raw_tokens = os.environ.get("APP_TOKENS")
    if raw_tokens:
        known_tokens = json_loads(raw_tokens)
        state["user"] = known_tokens.get(token, None)
    else:
        state["user"] = "anonymous"

    if not state["user"]:
        return "", state

    if store is None:
        store = AstraDBStore(
            collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )

    session_key = session_id(state, request)
    session_record = {
        'user': state["user"],
        'session': request.session_hash,
        'timestamp': time.asctime(time.gmtime()),
    }
    store.mset([(session_key, session_record)])
    return "", state
# JavaScript run in the browser before auth() (it is passed as `js=` to app.load in
# gr_main): when the page URL has a fragment, that fragment replaces the token and
# is then cleared from the address bar; the (token, state) pair is forwarded on.
# NOTE(review): `document.location.hash` includes the leading '#', so the keys in
# APP_TOKENS presumably include it as well — confirm against the deployed value.
AUTH_JS = """function auth_js(token, state) {
if (!!document.location.hash) {
token = document.location.hash
document.location.hash=""
}
return [token, state]
}
"""
def not_authenticated(state):
    """Return True (after showing a UI warning) when no user is set in ``state``.

    A missing state object or a falsy ``state['user']`` both count as
    unauthenticated; otherwise False is returned and no warning is raised.
    """
    if state is not None and state['user']:
        return False
    gr.Warning("You need to authenticate first")
    return True
def add_history(state, request, type, message):
    """Append ``message`` to the session's persisted chat history.

    The history object is created on first use (keyed by ``session_id``) and
    cached in ``state["history"]``.

    Args:
        state: session state dict; its ``"history"`` entry is read and may be set.
        request: current request, only used to derive the session id.
        type: ``"system"``, ``"user"`` or ``"ai"``; any other value is ignored.
        message: the message to store.
    """
    if not state["history"]:
        state["history"] = AstraDBChatMessageHistory(
            session_id=session_id(state, request),
            collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history',
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )

    # Map the message kind to the history method that stores it.
    method_by_kind = {
        "system": "add_message",
        "user": "add_user_message",
        "ai": "add_ai_message",
    }
    method_name = method_by_kind.get(type)
    if method_name is not None:
        getattr(state["history"], method_name)(message)
def chat(message, history, state, request:gr.Request):
    """Generator chat handler: stream the answer to ``message``.

    Unauthenticated sessions get a single refusal message. Otherwise:

    * With ``AI`` on, the system prompt is rendered through ``prompt_chain`` on
      the first message (empty ``history``) and cached in ``state["system"]``;
      later turns reuse it. The full conversation is sent to ``llm`` and the
      growing answer is yielded with a trailing ``…`` while streaming.
    * With ``AI`` off, a timestamped echo of the message is streamed word by word.

    The complete answer (without the trailing ``…``) is yielded last, and both
    the user message and the answer are written to the chat history.
    """
    if not_authenticated(state):
        yield "You need to authenticate first"
        return

    if AI:
        if history:
            system_prompt = state["system"]
        else:
            rendered_prompt = prompt_chain.invoke(message)
            system_prompt = rendered_prompt.messages[0]
            state["system"] = system_prompt
        add_history(state, request, "user", message)

        # Rebuild the conversation: system prompt, previous turns, new question.
        messages = [system_prompt]
        for human, ai in history:
            messages += [HumanMessage(human), AIMessage(ai)]
        messages.append(HumanMessage(message))

        answer = ''
        for chunk in llm.stream(messages):
            answer += chunk.content
            yield answer + '…'
    else:
        add_history(state, request, "user", message)
        echo = f"{time.ctime()}: You said: {message}"
        answer = ' '
        for word in echo.split():
            answer += f' {word}'
            yield answer + '…'
            time.sleep(0.05)

    yield answer
    add_history(state, request, "ai", answer)
def on_audio(path, state):
    """Transcribe a recorded audio file into text for the chat textbox.

    Args:
        path: file path of the recording, or a falsy value when there is none.
        state: session state, used for the authentication check.

    Returns:
        A ``(textbox_value, mic_value)`` tuple. The second item is always
        ``None``. The first is a no-op ``gr.update()`` when the session is not
        authenticated or there is no recording; otherwise it is the
        transcription (or a timestamped placeholder when ``AI`` is off).
    """
    if not_authenticated(state) or not path:
        return (gr.update(), None)

    if AI:
        # Open the recording in a context manager so the handle is always closed.
        with open(path, "rb") as audio_file:
            text = oai_client.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format="text"
            )
    else:
        text = f"{time.ctime()}: You said something"
    return (text, None)
def play_last(history, state):
    """Stream synthesized speech for the most recent answer in ``history``.

    Args:
        history: chat history as a sequence of ``(user, answer)`` pairs; may be
            empty or ``None``.
        state: session state, used for the authentication check.

    Yields:
        The audio chunks produced by the ElevenLabs streaming generator. Nothing
        is yielded when the session is not authenticated, when there is no
        history, or when the last answer is empty.
    """
    # `not history` also covers None, on which len() would raise.
    if not_authenticated(state) or not history:
        return

    text = history[-1][1]
    if not text:
        # No answer text yet: nothing to synthesize.
        return

    voice_id = "IINmogebEQykLiDoSkd0"
    lab11 = ElevenLabs()
    whatson = lab11.voices.get(voice_id)
    yield from lab11.generate(text=text, voice=whatson, stream=True)
def chat_chage(history):
    """Decide whether the "play last" button should be clickable.

    Returns an update that disables the button while the last answer is still
    empty, enables it once the last answer no longer ends with ``…`` (the
    streaming marker), and leaves it untouched in every other case.
    """
    if not history:
        return gr.update()

    last_answer = history[-1][1]
    if not last_answer:
        return gr.update(interactive=False)
    if last_answer[-1] != '…':
        return gr.update(interactive=True)
    return gr.update()
# Labels of the record toggle button in gr_main: TEXT_TALK is its initial label,
# and the click handler's JavaScript swaps between the two to start/stop recording.
TEXT_TALK = "🎤 Talk"
TEXT_STOP = "⏹ Stop"
def gr_main():
    """Build the Gradio application, wire its event handlers and launch it.

    The UI consists of a chat interface backed by ``chat``, a hidden streaming
    audio player fed by ``play_last``, a hidden microphone component driven by
    a "talk/stop" button, and a hidden token textbox used by ``auth`` on page load.
    """
    # NOTE(review): the theme identifier looks mangled ("[email protected]"
    # resembles e-mail obfuscation residue) — confirm the intended hub theme
    # name and version before relying on it.
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6", # ChatBot.svelte / .message-row.panel.user-row
        background_fill_secondary="#6272a4", # ChatBot.svelte / .message-row.panel.bot-row
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")
    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme
    ) as app:
        # Session state, passed to `chat` as an additional input.
        state = new_state()
        # auto_play = gr.Checkbox(False, label="Autoplay", render=False)
        # Created with render=False and handed to ChatInterface below.
        chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
        iface = gr.ChatInterface(
            chat,
            chatbot=chatbot,
            title="Sherlock Holmes stories",
            submit_btn=gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=150,
                elem_id="submit_btn",  # looked up by the JS that runs after transcription
                render=False
            ),
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            # examples=[
            # ["I arrived late last night and found a dead goose in my bed"],
            # ["Help please sir. I'm about to get married, to the most lovely lady,"
            # "and I just received a letter threatening me to make public some things"
            # "of my past I'd rather keep quiet, unless I don't marry"],
            # ],
            additional_inputs=[state])
        with gr.Row():
            # Invisible autoplaying player that receives the audio streamed by play_last.
            player = gr.Audio(
                visible=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                autoplay=True,
                streaming=True,
                interactive=False)
            # Invisible microphone component; its record/stop buttons are clicked
            # from JavaScript by the start_stop_rec handler below.
            mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                show_label=False,
                format="mp3",
                elem_id="microphone",
                visible=False,
                waveform_options=gr.WaveformOptions(sample_rate=16000, show_recording_waveform=False))
            start_stop_rec = gr.Button(TEXT_TALK, size = "lg")
            # Starts disabled; chat_chage toggles it as the chatbot content changes.
            play_last_btn = gr.Button("🔊 Play last", size = "lg", interactive=False)
        play_last_btn.click(
            play_last,
            [chatbot, state], player)
        chatbot.change(chat_chage, inputs=chatbot, outputs=play_last_btn)
        # The JS toggles recording on the hidden mic and returns the new button
        # label; the Python side just passes that label through.
        start_stop_rec.click(
            lambda x:x,
            inputs=start_stop_rec,
            outputs=start_stop_rec,
            js=f'''function (text) {{
if (text == "{TEXT_TALK}") {{
document.getElementById("microphone").querySelector(".record-button").click()
return ["{TEXT_STOP}"]
}} else {{
document.getElementById("microphone").querySelector(".stop-button").click()
return ["{TEXT_TALK}"]
}}
}}'''
        )
        # When the mic value changes: on_audio writes its result to the textbox and
        # resets the mic, then the JS clicks the submit button if the textbox is non-empty.
        mic.change(
            on_audio, [mic, state], [iface.textbox, mic]
        ).then(
            lambda x:None,
            inputs=iface.textbox,
            js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
        )
        # Hidden field carrying the token picked up by AUTH_JS on page load.
        token = gr.Textbox(visible=False)
        app.load(auth,
            [token,state],
            [token,state],
            js=AUTH_JS)
    app.queue(default_concurrency_limit=None, api_open=False)
    app.launch(show_api=False)
# Script entry point: set up the model/retrieval globals, then build and launch the UI.
if __name__ == "__main__":
    ai_setup()
    gr_main()