"""Whatson / query.py

Gradio chat interface over Sherlock Holmes stories: retrieval from an AstraDB
vector store via LangChain, with streamed answers from OpenAI's gpt-4o.
"""
import os
import pickle

import gradio as gr
from langchain_astradb import AstraDBVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
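
# Configuration is read from the environment. OPENAI_API_KEY is picked up
# implicitly by OpenAIEmbeddings / ChatOpenAI; the rest are read explicitly:
#   PROMPT_TEMPLATE             system prompt template for the chat
#   ASTRA_DB_COLLECTION         AstraDB collection with the indexed stories
#   ASTRA_DB_APPLICATION_TOKEN  AstraDB credentials
#   ASTRA_DB_API_ENDPOINT       AstraDB API endpoint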
prompt_template = os.environ.get("PROMPT_TEMPLATE")
prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
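
# PROMPT_TEMPLATE is presumably written with {context} and {question}
# placeholders (an assumption, based on the dict ai_setup() pipes into it).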

AI = False  # when False, skip AstraDB and serve cached documents from disk

def ai_setup():
    global llm, prompt_chain
    llm = ChatOpenAI(model="gpt-4o", temperature=0.8)
    if AI:
        # Live mode: retrieve the 10 most relevant chunks from AstraDB.
        embedding = OpenAIEmbeddings()
        vstore = AstraDBVectorStore(
            embedding=embedding,
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = vstore.as_retriever(search_kwargs={'k': 10})
    else:
        # Offline mode: return pre-pickled documents instead of querying AstraDB.
        retriever = RunnableLambda(just_read)
    # The chain stops at the rendered prompt: chat() streams the LLM itself,
    # so the llm / StrOutputParser stages stay commented out.
    prompt_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | RunnableLambda(format_context)
        | prompt
        # | llm
        # | StrOutputParser()
    )

def group_and_sort(documents):
    """Group retrieved chunks by story title and stitch each group back into
    one passage, merging overlaps and marking gaps with ' [...] '."""
    grouped = {}
    for document in documents:
        title = document.metadata["Title"]
        docs = grouped.get(title, [])
        grouped[title] = docs
        docs.append((document.page_content, document.metadata["range"]))

    # Sort each story's fragments by their start offset.
    for title, values in grouped.items():
        values.sort(key=lambda doc: doc[1][0])

    for title in grouped:
        text = ''
        prev_last = 0
        for fragment, (start, last) in grouped[title]:
            if start < prev_last:
                # Overlapping fragment: append only the unseen tail.
                text += fragment[prev_last - start:]
            elif start == prev_last:
                text += fragment
            else:
                # Gap between fragments: mark the elision.
                text += ' [...] '
                text += fragment
            prev_last = last
        grouped[title] = text
    return grouped
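
# Sketch of the stitching behavior (illustrative values, not from the corpus):
# fragments ("abcdef", (0, 6)) and ("defghi", (3, 9)) overlap and merge into
# "abcdefghi", while ("xyz", (20, 23)) would be appended after a ' [...] ' gap.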

def format_context(pipeline_state):
    """Render the retrieved documents into a single context string and pass
    the state on to the prompt."""
    context = ''
    documents = group_and_sort(pipeline_state["context"])
    for title, text in documents.items():
        context += f"\nTitle: {title}\n"
        context += text
        context += '\n\n---\n'
    pipeline_state["context"] = context
    return pipeline_state

def just_read(pipeline_state):
    # Receives the raw question in place of a retriever call; ignores it and
    # returns the cached documents (offline mode).
    fname = "docs.pickle"
    with open(fname, "rb") as f:
        return pickle.load(f)
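
# docs.pickle is presumably a list of LangChain Documents captured from an
# earlier live retrieval run (an assumption; nothing in this file creates it).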

def new_state():
    # Per-session state: caches the rendered system prompt across chat turns.
    return gr.State({
        "system": None,
    })

def chat(message, history, state):
    if not history:
        # First turn: run retrieval once and cache the rendered system prompt
        # so later turns don't hit the retriever again.
        system_prompt = prompt_chain.invoke(message)
        system_prompt = system_prompt.messages[0]
        state["system"] = system_prompt
    else:
        system_prompt = state["system"]

    # Rebuild the whole conversation for the model.
    messages = [system_prompt]
    for human, ai in history:
        messages.append(HumanMessage(human))
        messages.append(AIMessage(ai))
    messages.append(HumanMessage(message))

    # Stream the growing answer back to Gradio chunk by chunk.
    response_text = ''
    for chunk in llm.stream(messages):
        response_text += chunk.content
        yield response_text

def gr_main():
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6",          # ChatBot.svelte / .message-row.panel.user-row
        background_fill_secondary="#6272a4",  # ChatBot.svelte / .message-row.panel.bot-row
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")
    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme
    ) as app:
        state = new_state()
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(show_label=False, render=False, scale=1),
            title="Sherlock Holmes stories",
            examples=[
                ["I arrived late last night and found a dead goose in my bed"],
                ["Help please sir. I'm about to get married, to the most lovely lady, "
                 "and I just received a letter threatening to make public some things "
                 "of my past I'd rather keep quiet, unless I don't marry"],
            ],
            additional_inputs=[state])
    app.launch(show_api=False)

if __name__ == "__main__":
    ai_setup()
    gr_main()
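
# To run locally (assumption: the environment variables listed at the top are
# set, and a docs.pickle sits next to this file while AI is False):
#   python query.py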