gerasdf
History with doc_ids. This version doesn't work: to get the _ids for the documents I'd have to implement either a new retriever that answers with the documents and their _ids, or implement the vector search manually in the pipeline. I'm dropping it for now; I'll just re-do the vector search when loading a history (see the sketch below).
f5924e7
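A possible way out, sketched below (not what this commit does): langchain_astradb's AstraDBVectorStore exposes similarity_search_with_score_id(), which returns (document, score, id) triples, so the _ids could be captured with a small wrapper instead of a new retriever class. This assumes the method is available in the installed version; retrieve_with_ids is a hypothetical helper, not code from this repo:

def retrieve_with_ids(vstore, question, k=10):
    # Sketch only: fetch documents together with their Astra _ids.
    # Assumes AstraDBVectorStore.similarity_search_with_score_id exists
    # (check the installed langchain_astradb version before relying on it).
    hits = vstore.similarity_search_with_score_id(question, k=k)
    docs = [doc for doc, _score, _id in hits]
    ids = [doc_id for _doc, _score, doc_id in hits]
    return {"context": docs, "doc_ids": ids, "question": question}

Wrapped in a RunnableLambda, this could stand in for the {"context": retriever, "question": RunnablePassthrough()} dict at the head of prompt_chain below.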
import gradio as gr
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_astradb import AstraDBChatMessageHistory, AstraDBStore, AstraDBVectorStore
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from elevenlabs import VoiceSettings
from elevenlabs.client import ElevenLabs
from openai import OpenAI
from json import loads as json_loads, dumps as json_dumps
import itertools
import time
import os
AI = True

# Backport of itertools.batched (added in Python 3.12) for older interpreters
if not hasattr(itertools, "batched"):
    def batched(iterable, n):
        "Batch data into lists of length n. The last batch may be shorter."
        # batched('ABCDEFG', 3) --> ABC DEF G
        it = iter(iterable)
        while True:
            batch = list(itertools.islice(it, n))
            if not batch:
                return
            yield batch
    itertools.batched = batched
def ai_setup():
    global llm, prompt_chain, oai_client
    if AI:
        oai_client = OpenAI()
        llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
        embedding = OpenAIEmbeddings()
        vstore = AstraDBVectorStore(
            embedding=embedding,
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = vstore.as_retriever(search_kwargs={'k': 10})
        prompt_template = os.environ.get("PROMPT_TEMPLATE")
        # The first system message only carries the doc_ids, so chat() can
        # persist them; the second is the actual prompt template
        prompt = ChatPromptTemplate.from_messages([
            ('system', "{doc_ids}"),
            ('system', prompt_template)])
        prompt_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | RunnableLambda(format_context)
            | prompt
            # | llm
            # | StrOutputParser()
        )
    else:
        # Offline mode: serve previously pickled documents instead of hitting Astra
        retriever = RunnableLambda(just_read)
def group_and_sort(documents):
    grouped = {}
    for document in documents:
        title = document.metadata["Title"]
        docs = grouped.setdefault(title, [])
        docs.append((document.page_content, document.metadata["range"]))
    for title, values in grouped.items():
        values.sort(key=lambda doc: doc[1][0])
    for title in grouped:
        text = ''
        prev_last = 0
        for fragment, (start, last) in grouped[title]:
            if start < prev_last:
                text += fragment[prev_last-start:]
            elif start == prev_last:
                text += fragment
            else:
                text += ' [...] '
                text += fragment
            prev_last = last
        grouped[title] = text
    return grouped
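# Worked example (illustrative, not from the original source): two fragments of
# the same title, "Sherlock Holmes took" with range (0, 20) and
# "Holmes took his revolver" with range (9, 33), overlap by 11 characters;
# only fragment[20-9:] == " his revolver" is appended, so the stitched text is
# "Sherlock Holmes took his revolver". A fragment starting past prev_last
# would instead be joined with ' [...] '.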
def format_context(pipeline_state):
    """Format the retrieved documents into a single context string and pass the state on"""
    context = ''
    documents = group_and_sort(pipeline_state["context"])
    for title, text in documents.items():
        context += f"\nTitle: {title}\n"
        context += text
        context += '\n\n---\n'
    # Placeholder: the plain retriever doesn't expose the documents' _ids
    # (see the commit message), so real ids can't be filled in here yet
    doc_ids = [1, 2, 3, 4, 5]
    pipeline_state["context"] = context
    pipeline_state["doc_ids"] = json_dumps(doc_ids)
    return pipeline_state
def just_read(pipeline_state):
    # Offline stand-in for the retriever: load documents pickled on a previous run
    import pickle
    fname = "docs.pickle"
    return pickle.load(open(fname, "rb"))
def new_state():
    return gr.State({
        "user": None,
        "system": None,
        "history": None,
    })

def session_id(state: dict, request: gr.Request) -> str:
    return f'{state["user"]}_{request.session_hash}'
class History:
    store = None

    def __init__(self, name: str, user: str, session_id: str, id: str = None):
        self.session_id = session_id
        self.name = name
        self.user = user
        self.astra_history = None
        if id:
            self.id = id
        else:
            self.id = f"{user}_{session_id}"
            self.create()

    @classmethod
    def get_store(cls):
        if cls.store is None:
            cls.store = AstraDBStore(
                collection_name=f'{os.environ.get("ASTRA_DB_COLLECTION")}_sessions',
                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
            )
        return cls.store

    @classmethod
    def from_dict(cls, id: str, data: dict):
        name = data.get("name", f":{id}")
        return cls(name, user=data["user"], id=id, session_id=data["session"])

    @classmethod
    def get_histories(cls, user: str):
        store = cls.get_store()
        histories = []
        keys = list(store.yield_keys(prefix=f"{user}_"))
        for id, history in zip(keys, store.mget(keys)):
            histories.append(cls.from_dict(id=id, data=history))
        return histories

    @classmethod
    def load(cls, id: str):
        data = cls.get_store().mget([id])
        return cls.from_dict(id, data[0])

    def __str__(self):
        return f"{self.id}:{self.name}"

    def create(self):
        history = {
            'session': self.session_id,
            'user': self.user,
            'timestamp': time.asctime(time.gmtime()),
            'name': self.name
        }
        self.get_store().mset([(self.id, history)])

    @staticmethod
    def get_history_collection_name():
        return f'{os.environ.get("ASTRA_DB_COLLECTION")}_chat_history'

    def get_astra_history(self):
        if self.astra_history is None:
            self.astra_history = AstraDBChatMessageHistory(
                session_id=self.id,
                collection_name=self.get_history_collection_name(),
                token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
                api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
            )
        return self.astra_history

    def add(self, type: str, message):
        if type == "system":
            self.get_astra_history().add_message(message)
        elif type == "user":
            self.get_astra_history().add_user_message(message)
        elif type == "ai":
            self.get_astra_history().add_ai_message(message)

    def messages(self):
        return self.get_astra_history().messages

    def clear(self):
        self.get_astra_history().clear()

    def delete(self):
        self.clear()
        self.get_store().mdelete([self.id])
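# Usage sketch (illustrative, not from the original source; assumes the
# ASTRA_DB_* env vars are set):
#
#   h = History("goose in my bed", user="alice", session_id="abc123")
#   h.add("user", "I found a dead goose in my bed")
#   [str(x) for x in History.get_histories("alice")]  # ['alice_abc123:goose in my bed']
#   same = History.load("alice_abc123")               # re-hydrated from the store
#   same.delete()                                     # clears messages and the session record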
def auth(token, state, request: gr.Request):
    tokens = os.environ.get("APP_TOKENS")
    if not tokens:
        state["user"] = "anonymous"
    else:
        tokens = json_loads(tokens)
        state["user"] = tokens.get(token, None)
    return "", state

# Client-side hook: if the page URL carries a token in its hash, use it as the
# credential and clear the hash from the address bar
AUTH_JS = """function auth_js(token, state) {
    if (!!document.location.hash) {
        token = document.location.hash
        document.location.hash=""
    }
    return [token, state]
}
"""
def not_authenticated(state):
    answer = (state is None) or (not state['user'])
    if answer:
        gr.Warning("You need to authenticate first")
    return answer

def list_histories(state):
    if not_authenticated(state):
        return gr.update()
    histories = History.get_histories(state["user"])
    answer = [(h.name, h.id) for h in histories]
    return gr.update(choices=answer, value=None)
def add_history(state, request, type, message, name: str = None):
    if not state["history"]:
        name = name or message[:60]
        state["history"] = History(
            name=name,
            user=state["user"],
            session_id=request.session_hash
        )
    state["history"].add(type, message)

def load_history(state, history_id):
    state["history"] = History.load(history_id)
    history = [m.content for m in state["history"].messages()]
    # Pair the stored messages up as (user, ai) turns; note the leading
    # "system" doc_ids record, if present, is not skipped here
    history = list(itertools.batched(history, 2))
    if len(history) and len(history[-1]) == 1:
        # A trailing unpaired message is a user question that was never answered
        user_input = history[-1][0]
        history = history[:-1]
    else:
        user_input = ''
    return state, history, history, user_input  # state, Chatbot, ChatInterface.state, ChatInterface.textbox
def chat(message, history, state, request: gr.Request):
    if not_authenticated(state):
        yield "You need to authenticate first"
    else:
        if AI:
            if not history:
                system_prompts = prompt_chain.invoke(message)
                system_prompt = system_prompts.messages[1]
                state["system"] = system_prompt
                # Persist only the doc_ids rather than the full system prompt:
                # Astra has a limit on document size
                doc_ids = system_prompts.messages[0].content
                add_history(state, request, "system", doc_ids, name=message)
            else:
                system_prompt = state["system"]
            add_history(state, request, "user", message)
            messages = [system_prompt]
            for human, ai in history:
                messages.append(HumanMessage(human))
                messages.append(AIMessage(ai))
            messages.append(HumanMessage(message))
            answer = ''
            for response in llm.stream(messages):
                answer += response.content
                yield answer + '…'
        else:
            add_history(state, request, "user", message)
            msg = f"{time.ctime()}: You said: {message}"
            answer = ' '
            for word in msg.split():
                answer += f' {word}'
                yield answer + '…'
                time.sleep(0.05)
        yield answer
        add_history(state, request, "ai", answer)
def on_audio(path, state):
    if not_authenticated(state):
        return (gr.update(), None)
    if not path:
        return (gr.update(), None)
    if AI:
        # Transcribe the recorded audio with Whisper
        text = oai_client.audio.transcriptions.create(
            model="whisper-1",
            file=open(path, "rb"),
            response_format="text"
        )
    else:
        text = f"{time.ctime()}: You said something"
    return (text, None)
def play_last(history, state):
    if not_authenticated(state):
        pass
    else:
        if len(history):
            voice_id = "IINmogebEQykLiDoSkd0"
            text = history[-1][1]
            lab11 = ElevenLabs()
            whatson = lab11.voices.get(voice_id)
            # Stream the synthesized audio chunks straight to the player
            response = lab11.generate(text=text, voice=whatson, stream=True)
            yield from response

def chat_change(history):
    # Enable the play button only once the last answer has finished streaming
    # (streaming answers end with '…' until complete)
    if history:
        if not history[-1][1]:
            return gr.update(interactive=False)
        elif history[-1][1][-1] != '…':
            return gr.update(interactive=True)
    return gr.update()  # play_last_btn

TEXT_TALK = "🎤 Talk"
TEXT_STOP = "⏹ Stop"
def gr_setup():
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6",          # ChatBot.svelte / .user / .message-row.panel.user-row . neutral_500 -> neutral_200
        background_fill_secondary="#6272a4",  # ChatBot.svelte / .bot / .message-row.panel.bot-row . neutral_500 -> neutral_400
        background_fill_primary="#818eb6",    # DropdownOptions.svelte / item
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")

    with gr.Blocks(
            title="Sherlock Holmes stories",
            fill_height=True,
            theme=theme,
            css="footer {visibility: hidden}"
            ) as app:
        state = new_state()
        chatbot = gr.Chatbot(show_label=False, render=False, scale=1)
        gr.HTML('<h1 style="text-align: center">Sherlock Holmes stories</h1>')
        history_choice = gr.Dropdown(
            choices=[("History", "History")],
            value="History",
            show_label=False,
            container=False,
            interactive=True,
            filterable=True)
        iface = gr.ChatInterface(
            chat,
            chatbot=chatbot,
            title=None,
            submit_btn=gr.Button(
                "Send",
                variant="primary",
                scale=1,
                min_width=150,
                elem_id="submit_btn",
                render=False
            ),
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            # examples=[
            #     ["I arrived late last night and found a dead goose in my bed"],
            #     ["Help please sir. I'm about to get married, to the most lovely lady,"
            #      "and I just received a letter threatening me to make public some things"
            #      "of my past I'd rather keep quiet, unless I don't marry"],
            # ],
            additional_inputs=[state])
        with gr.Row():
            player = gr.Audio(
                visible=False,
                show_label=False,
                show_download_button=False,
                show_share_button=False,
                autoplay=True,
                streaming=True,
                interactive=False)
            mic = gr.Audio(
                sources=["microphone"],
                type="filepath",
                show_label=False,
                format="mp3",
                elem_id="microphone",
                visible=False,
                waveform_options=gr.WaveformOptions(sample_rate=16000, show_recording_waveform=False))
            start_stop_rec = gr.Button(TEXT_TALK, size="lg")
            play_last_btn = gr.Button("🔊 Play last", size="lg", interactive=False)

        play_last_btn.click(
            play_last,
            [chatbot, state], player)
        chatbot.change(chat_change, inputs=chatbot, outputs=play_last_btn)
        # The visible Talk/Stop button proxies clicks to the hidden mic widget
        start_stop_rec.click(
            lambda x: x,
            inputs=start_stop_rec,
            outputs=start_stop_rec,
            js=f'''function (text) {{
                if (text == "{TEXT_TALK}") {{
                    document.getElementById("microphone").querySelector(".record-button").click()
                    return ["{TEXT_STOP}"]
                }} else {{
                    document.getElementById("microphone").querySelector(".stop-button").click()
                    return ["{TEXT_TALK}"]
                }}
            }}'''
        )
        # When a recording lands, transcribe it into the textbox and auto-submit
        mic.change(
            on_audio, [mic, state], [iface.textbox, mic]
        ).then(
            lambda x: None,
            inputs=iface.textbox,
            js='function (text){if (text) document.getElementById("submit_btn").click(); return [text]}'
        )
        history_choice.focus(
            list_histories,
            inputs=state,
            outputs=history_choice
        )
        history_choice.input(
            load_history,
            inputs=[state, history_choice],
            outputs=[state, chatbot, iface.chatbot_state, iface.textbox])

        token = gr.Textbox(visible=False)
        app.load(auth,
                 [token, state],
                 [token, state],
                 js=AUTH_JS)

    app.queue(default_concurrency_limit=None, api_open=False)
    return app
if __name__ == "__main__":
    ai_setup()
    app = gr_setup()
    app.launch(show_api=False)