# devai-demo / conversation.py
from langchain.document_loaders import TextLoader
import pinecone
from langchain.vectorstores import Pinecone
import os
from transformers import AutoTokenizer, AutoModel
from langchain.agents import AgentExecutor
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.chat_models import ChatOpenAI
import torch
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (AgentTokenBufferMemory)
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.prompts import MessagesPlaceholder
import gradio as gr
import time
from db_func import insert_one
def get_bert_embeddings(sentence):
    """Embed a sentence as the [CLS]-token vector of the BGE encoder."""
    # truncation=True guards against inputs longer than the model's 512-token limit.
    input_ids = tokenizer.encode(sentence, return_tensors="pt", truncation=True)
    with torch.no_grad():
        output = model(input_ids)
    # Indexing [0, 0] drops the batch dimension and selects the [CLS] token,
    # returning the flat list of floats that Pinecone's embedding function expects.
    return output.last_hidden_state[0, 0, :].numpy().tolist()
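# Illustrative only: for bge-base-en-v1.5 the vector has 768 dimensions, e.g.
#   get_bert_embeddings("retinal detachment")  # -> [0.012, -0.234, ...] (768 floats)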
model_name = "BAAI/bge-base-en-v1.5"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Versioned system prompt for the agent.
with open("prompts/version_2.txt", "r") as f:
    prompt_file = f.read()
pinecone.init(
    api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
    environment=os.getenv("PINECONE_ENV"),  # next to api key in console
)
index_name = "ophtal-knowledge-base"
index = pinecone.Index(index_name)
# "text" is the metadata key under which each document's raw content is stored.
vectorstore = Pinecone(index, get_bert_embeddings, "text")
retriever = vectorstore.as_retriever()
tool = create_retriever_tool(
    retriever,
    "search_ophtal-knowledge-base",
    "Searches and returns documents regarding the ophtal-knowledge-base.",
)
tools = [tool]
system_message = SystemMessage(content=prompt_file)
memory_key = "history"
llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4", temperature=0.2)
# Token-buffer memory keeps the recent conversation within the context window.
memory = AgentTokenBufferMemory(memory_key=memory_key, llm=llm)
prompt = OpenAIFunctionsAgent.create_prompt(
    system_message=system_message,
    extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)
# Build the executor explicitly so the custom prompt is actually used;
# return_intermediate_steps=True exposes the retrieved documents to run().
agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(
    agent=agent, tools=tools, memory=memory, verbose=False, return_intermediate_steps=True
)
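# Illustrative call with a hypothetical query:
#   agent_executor({"input": "What causes angle-closure glaucoma?"})
# returns a dict with "output" (the answer), the buffered "history", and
# "intermediate_steps" pairing each tool call with the documents it retrieved.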
# Set by auth_function at login; recorded with each logged interaction.
user_name = None
def run(input_):
    output = agent_executor({"input": input_})
    output_text = output["output"]
    source_text = ""
    doc_text = ""
    if len(output["intermediate_steps"]) > 0:
        # The first intermediate step holds the retriever tool's result:
        # the documents fetched from Pinecone for this query.
        documents = output["intermediate_steps"][0][1]
        sources = []
        docs = []
        for doc in documents:
            # Deduplicate by source file so each reference is listed once.
            if doc.metadata["source"] not in sources:
                sources.append(doc.metadata["source"])
                docs.append(doc.page_content)
        for i in range(len(sources)):
            # Strip file extensions and path/boilerplate fragments to get a
            # readable citation title, then collapse double spaces.
            temp = (
                sources[i]
                .replace(".pdf", "")
                .replace(".txt", "")
                .replace("AAO", "")
                .replace("2022-2023", "")
                .replace("data/book", "")
                .replace("text", "")
                .replace("  ", " ")
            )
            source_text += f"{i+1}. {temp}\n"
            doc_text += f"{i+1}. {docs[i]}\n"
        output_text = f"{output_text} \n\nSources: \n{source_text}\n\nDocuments: \n{doc_text}"
    doc_to_insert = {
        "user": user_name,
        "input": input_,
        "output": output_text,
        "source": source_text,
        "documents": doc_text,
    }
    # Log the interaction to the database (see db_func.insert_one).
    insert_one(doc_to_insert)
    return output_text
def make_conversation(message, history):
    # Stream the answer character by character for a typing effect in Gradio.
    text_ = run(message)
    for i in range(len(text_)):
        time.sleep(0.001)
        yield text_[: i + 1]
def auth_function(username, password):
    # Demo auth: the password must equal the username.
    # Record the user globally so run() can attach it to logged interactions.
    global user_name
    user_name = username
    return username == password
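# Minimal launch sketch (assumed wiring; the original file never calls
# gr.ChatInterface or launch(), which likely happens elsewhere in the Space):
if __name__ == "__main__":
    demo = gr.ChatInterface(make_conversation)
    demo.launch(auth=auth_function)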