# NOTE: stray "Last commit not found" text here was an extraction artifact, not code.
import gradio as gr | |
from langchain_astradb import AstraDBVectorStore | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough, RunnableLambda | |
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage | |
from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
from json import loads as json_loads | |
import os | |
# System-prompt template read at import time; presumably contains {context} and
# {question} slots, since prompt_chain feeds it a dict with those keys — TODO confirm.
# NOTE(review): if PROMPT_TEMPLATE is unset this is None and from_messages will raise.
prompt_template = os.environ.get("PROMPT_TEMPLATE")
prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])
# Toggle: True = retrieve from the Astra DB vector store; False = load a local
# pickled document dump via just_read() (offline/dev mode).
AI = True
def ai_setup():
    """Initialize the global chat model and the retrieval prompt chain.

    When AI is True, documents come from an Astra DB vector store configured
    via environment variables; otherwise just_read() supplies a pickled
    document dump for offline runs.  Sets the module globals `llm` and
    `prompt_chain`.
    """
    global llm, prompt_chain

    # Streaming chat model used directly by chat().
    llm = ChatOpenAI(model="gpt-4o", temperature=0.8)

    if AI:
        vstore = AstraDBVectorStore(
            embedding=OpenAIEmbeddings(),
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = vstore.as_retriever(search_kwargs={'k': 10})
    else:
        # Offline mode: serve pre-fetched documents instead of querying Astra.
        retriever = RunnableLambda(just_read)

    # The chain deliberately stops at the prompt: chat() extracts the rendered
    # system message from the result and drives the LLM stream itself.
    prompt_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | RunnableLambda(format_context)
        | prompt
        # | llm
        # | StrOutputParser()
    )
def group_and_sort(documents):
    """Merge retrieved document fragments into one text per story title.

    Each document is expected to carry metadata["Title"] and
    metadata["range"], a (start, end) character span within the source story.
    Fragments sharing a title are sorted by start offset and stitched
    together: overlapping spans contribute only their unseen tail, adjacent
    spans are concatenated directly, and any gap (including one before a
    fragment that does not start at offset 0) is marked with ' [...] '.

    Returns a dict mapping title -> merged text.
    """
    # Group (fragment, (start, end)) pairs by title.
    grouped = {}
    for document in documents:
        title = document.metadata["Title"]
        grouped.setdefault(title, []).append(
            (document.page_content, document.metadata["range"])
        )

    merged = {}
    for title, fragments in grouped.items():
        fragments.sort(key=lambda frag: frag[1][0])  # by start offset
        parts = []
        prev_last = 0
        for fragment, (start, last) in fragments:
            if start < prev_last:
                # Overlap with the previous span: skip the already-seen prefix.
                parts.append(fragment[prev_last - start:])
            elif start == prev_last:
                parts.append(fragment)
            else:
                # Gap between spans — mark the elision.
                parts.append(' [...] ')
                parts.append(fragment)
            prev_last = last
        merged[title] = ''.join(parts)
    return merged
def format_context(pipeline_state):
    """Render the retrieved documents into a single context string.

    Replaces pipeline_state["context"] (a list of retrieved documents) with
    one formatted text: each story's merged fragments preceded by its title
    and followed by a '---' divider.  Mutates and returns pipeline_state so
    it can sit mid-chain as a RunnableLambda.

    (The previous docstring claimed this "prints" the state — it does not.)
    """
    documents = group_and_sort(pipeline_state["context"])
    sections = [
        f"\nTitle: {title}\n{text}\n\n---\n"
        for title, text in documents.items()
    ]
    pipeline_state["context"] = ''.join(sections)
    return pipeline_state
def just_read(pipeline_state, fname="docs.pickle"):
    """Load pre-retrieved documents from a pickle file.

    Stand-in retriever for offline runs (AI == False): it ignores the
    incoming pipeline state/question and returns whatever document list was
    pickled.  `fname` defaults to the historical hard-coded path so existing
    callers are unaffected.

    NOTE: pickle.load is only safe on trusted, locally produced files.
    """
    import pickle
    # 'with' guarantees the file handle is closed (the original leaked it).
    with open(fname, "rb") as f:
        return pickle.load(f)
def new_state():
    """Create a fresh per-session Gradio state holder.

    Tracks the authenticated user (set by auth()) and the cached system
    prompt (set on the first chat turn), both initially unset.
    """
    initial = {
        "user": None,
        "system": None,
    }
    return gr.State(initial)
def auth(token, state):
    """Resolve a token to a user name and record it in the session state.

    If the APP_TOKENS environment variable is unset, every visitor is let in
    as "anonymous".  Otherwise it must hold a JSON object mapping tokens to
    user names; an unknown token leaves state["user"] set to None.

    Returns ("", state) — the empty string clears the token textbox.
    """
    configured = os.environ.get("APP_TOKENS", None)
    if configured is None:
        user = "anonymous"
    else:
        user = json_loads(configured).get(token, None)
    state["user"] = user
    return "", state
# Client-side hook passed to app.load(): if the page URL carries a fragment
# (e.g. https://app/#mytoken) it is used as the auth token and stripped from
# the address bar before [token, state] reaches the Python auth() handler.
AUTH_JS = """function auth_js(token, state) {
    if (!!document.location.hash) {
        token = document.location.hash
        document.location.hash=""
    }
    return [token, state]
}
"""
def chat(message, history, state):
    """Gradio chat handler: stream an LLM answer for the user's message.

    On the first turn the retrieval chain builds a system prompt from the
    message and caches it in state["system"]; later turns reuse the cached
    prompt.  The conversation history is replayed as LangChain messages and
    the model's reply is yielded incrementally so Gradio renders the stream.
    Unauthenticated sessions (see auth()) only get a warning.
    """
    if (state is None) or (not state['user']):
        gr.Warning("You need to authenticate first")
        yield "You need to authenticate first"
    else:
        if not history:
            # First turn: run retrieval + prompt template on the question.
            system_prompt = prompt_chain.invoke(message)
            system_prompt = system_prompt.messages[0]
            state["system"] = system_prompt
        else:
            system_prompt = state["system"]
        messages = [system_prompt]
        # Rebuild the conversation as LangChain message objects.
        for human, ai in history:
            messages.append(HumanMessage(human))
            messages.append(AIMessage(ai))
        messages.append(HumanMessage(message))
        # Renamed from 'all', which shadowed the builtin all().
        partial = ''
        for response in llm.stream(messages):
            partial += response.content
            yield partial
def gr_main():
    """Assemble the Gradio Blocks app (theme, chat UI, auth hook) and launch it."""
    # Start from a community hub theme and override the chat-bubble colours.
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6", # ChatBot.svelte / .message-row.panel.user-row
        background_fill_secondary="#6272a4", # ChatBot.svelte / .message-row.panel.bot-row
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")
    with gr.Blocks(
            title="Sherlock Holmes stories",
            fill_height=True,
            theme=theme
        ) as app:
        # Per-session state shared by auth() and chat().
        state = new_state()
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(show_label=False, render=False, scale=1),
            title="Sherlock Holmes stories",
            examples=[
                ["I arrived late last night and found a dead goose in my bed"],
                ["Help please sir. I'm about to get married, to the most lovely lady,"
                 "and I just received a letter threatening me to make public some things"
                 "of my past I'd rather keep quiet, unless I don't marry"],
            ],
            additional_inputs=[state])
        # Hidden textbox: receives the token extracted from the URL hash by AUTH_JS.
        token = gr.Textbox(visible=False)
        app.load(auth,
            [token,state],
            [token,state],
            js=AUTH_JS)
    app.launch(show_api=False)
if __name__ == "__main__":
    # Wire up the LLM and retrieval chain before serving the UI.
    ai_setup()
    gr_main()