|
import gradio as gr |
|
|
|
from langchain_astradb import AstraDBVectorStore |
|
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough, RunnableLambda |
|
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage |
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
|
|
|
import os |
|
|
|
# System prompt template, supplied via the environment. It is formatted with
# the dict produced by format_context(), i.e. the keys "context" and "question".
prompt_template = os.environ.get("PROMPT_TEMPLATE")
if prompt_template is None:
    # Fail fast with an actionable message rather than letting a None template
    # reach ChatPromptTemplate and fail there with an unrelated-looking error.
    raise RuntimeError("The PROMPT_TEMPLATE environment variable must be set")

prompt = ChatPromptTemplate.from_messages([("system", prompt_template)])

# Retrieval backend switch read by ai_setup(): True queries the Astra DB
# vector store, False loads the documents pickled in docs.pickle (just_read).
AI = False
|
|
|
def ai_setup():
    """Create the module-level ``llm`` and ``prompt_chain`` used by chat().

    ``prompt_chain`` turns a user question into the filled system prompt:
    it retrieves documents (up to 10 from the Astra DB vector store when
    ``AI`` is true, otherwise the pickled documents returned by just_read),
    merges them into text with format_context, and renders ``prompt``.
    """
    global llm, prompt_chain

    llm = ChatOpenAI(model="gpt-4o", temperature=0.8)

    if not AI:
        # Offline mode: every question gets the same locally pickled documents.
        retriever = RunnableLambda(just_read)
    else:
        # Connection settings come from the environment (None when unset).
        store = AstraDBVectorStore(
            embedding=OpenAIEmbeddings(),
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = store.as_retriever(search_kwargs={'k': 10})

    # The dict is coerced into a parallel step: "context" gets the retrieved
    # documents, "question" the untouched user input.
    inputs = {"context": retriever, "question": RunnablePassthrough()}
    prompt_chain = inputs | RunnableLambda(format_context) | prompt
|
|
|
def group_and_sort(documents):
    """Merge retrieved chunks into one text per source title.

    Each document needs ``page_content`` plus ``metadata["Title"]`` and
    ``metadata["range"]``. The range is a ``(start, end)`` pair of character
    offsets into the source text, with ``end`` exclusive, so ``page_content``
    covers ``start:end``.

    Within a title, chunks are ordered by ``start``. Overlapping text is
    emitted only once. A gap between consecutive chunks, or before a first
    chunk that does not begin at offset 0, is marked with " [...] ".

    Returns:
        dict mapping title -> merged text, titles in first-seen order.
    """
    grouped = {}
    for document in documents:
        start, end = document.metadata["range"]
        grouped.setdefault(document.metadata["Title"], []).append(
            (start, end, document.page_content)
        )

    merged = {}
    for title, chunks in grouped.items():
        # Stable sort: chunks with equal starts keep their retrieval order.
        chunks.sort(key=lambda chunk: chunk[0])
        parts = []
        covered = 0  # exclusive end offset of the text emitted so far
        for start, end, fragment in chunks:
            if start < covered:
                # Overlap: skip the prefix that has already been emitted.
                parts.append(fragment[covered - start:])
            else:
                if start > covered:
                    parts.append(' [...] ')
                parts.append(fragment)
            # max(): a chunk nested inside an earlier one must not move the
            # covered end backwards, or later chunks would be duplicated.
            covered = max(covered, end)
        merged[title] = ''.join(parts)
    return merged
|
|
|
def format_context(pipeline_state):
    """Render the retrieved documents in ``pipeline_state`` as prompt text.

    ``pipeline_state`` is the dict passed along the chain built in ai_setup():
    "context" holds the retrieved documents and "question" the user input.
    The documents are merged per title by group_and_sort. Each title becomes
    a section starting with "Title: <title>", followed by its text and a
    "---" separator line.

    Returns:
        A new dict with the same keys and "context" replaced by the rendered
        string. The input dict is not modified.
    """
    documents = group_and_sort(pipeline_state["context"])
    context = ''.join(
        f"\nTitle: {title}\n{text}\n\n---\n" for title, text in documents.items()
    )
    return {**pipeline_state, "context": context}
|
|
|
def just_read(pipeline_state):
    """Offline retriever: return the object pickled in ``docs.pickle``.

    Used in place of the vector-store retriever when ``AI`` is false. The
    incoming ``pipeline_state`` (the question) is ignored, so every query
    gets the same result. The path is resolved against the current working
    directory. The pickle presumably holds the document list that
    group_and_sort expects (Title/range metadata); verify against whatever
    produced the file.

    NOTE: ``pickle.load`` can execute arbitrary code. Only use a
    docs.pickle that comes from a trusted source.
    """
    import pickle

    fname = "docs.pickle"
    # Context manager so the file handle is closed deterministically.
    with open(fname, "rb") as fh:
        return pickle.load(fh)
|
|
|
def new_state():
    """Return the Gradio state component backing one chat session.

    Its value is a dict with a single key, "system". It starts as None, and
    chat() stores the system prompt built on a conversation's first turn there.
    """
    session_defaults = {"system": None}
    return gr.State(session_defaults)
|
|
|
def chat(message, history, state):
    """Stream the assistant's reply to ``message`` (Gradio ChatInterface handler).

    On the first turn, the user's message is run through ``prompt_chain``
    (retrieval + prompt template) to build the system prompt. That prompt is
    cached in ``state["system"]`` and reused on later turns, so the retrieved
    context stays fixed for the whole conversation.

    Args:
        message: the new user message.
        history: previous turns as ``(user, assistant)`` pairs.
        state: the session dict created by new_state().

    Yields:
        The reply accumulated so far, so the UI can render it progressively.
    """
    system_prompt = state.get("system")
    if not history or system_prompt is None:
        # Context is always built from the conversation's first user message.
        # The None check recovers from a missing cached prompt mid-conversation
        # instead of sending None to the model.
        first_message = history[0][0] if history else message
        system_prompt = prompt_chain.invoke(first_message).messages[0]
        state["system"] = system_prompt

    messages = [system_prompt]
    for human, ai in history:
        messages.append(HumanMessage(human))
        messages.append(AIMessage(ai))
    messages.append(HumanMessage(message))

    # `reply` rather than `all`, which would shadow the builtin.
    reply = ''
    for chunk in llm.stream(messages):
        reply += chunk.content
        yield reply
|
|
|
def gr_main():
    """Build the "Sherlock Holmes stories" chat UI and launch the Gradio app.

    ai_setup() must have been called first: the chat() handler wired in here
    uses the ``llm`` and ``prompt_chain`` globals it creates.
    """
    # NOTE(review): the original theme id was destroyed by e-mail obfuscation
    # of its "name@version" part. The overrides below use Dracula palette
    # colours (e.g. #6272a4), so the Dracula hub theme is assumed. Confirm it,
    # and pin "@<version>" if a specific release is required.
    theme = gr.Theme.from_hub("freddyaboulton/dracula_revamped")
    theme.set(
        color_accent_soft="#818eb6",
        background_fill_secondary="#6272a4",
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill",
    )

    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme,
    ) as app:
        state = new_state()
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(show_label=False, render=False, scale=1),
            title="Sherlock Holmes stories",
            examples=[
                ["I arrived late last night and found a dead goose in my bed"],
                # Adjacent literals are joined with no separator, so each
                # piece ends with a space to keep the words apart.
                ["Help please sir. I'm about to get married, to the most lovely lady, "
                 "and I just received a letter threatening me to make public some things "
                 "of my past I'd rather keep quiet, unless I don't marry"],
            ],
            additional_inputs=[state],
        )
    app.launch(show_api=False)
|
if __name__ == "__main__":
    # Order matters: ai_setup() creates the `llm` and `prompt_chain` globals
    # that the chat handler registered by gr_main() relies on.
    ai_setup()
    gr_main()