import os
import gradio as gr
import openai
from dotenv import load_dotenv
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
load_dotenv()
## Read the OpenAI API key from the 'my_secret' environment variable
os.environ["OPENAI_API_KEY"] = os.getenv('my_secret')
openai.api_key = os.getenv('my_secret')
## Load PDF file
loader = PyPDFLoader("docs.pdf")
documents = loader.load()
## Split Document
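# CharacterTextSplitter splits each page on its default "\n\n" separator and merges the pieces into chunks of roughly 800 characters with no overlap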
text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
## Embed the chunks and index them in a Chroma vector store
embeddings = OpenAIEmbeddings()
vector_store = Chroma.from_documents(texts, embeddings)
retriever = vector_store.as_retriever(search_kwargs={"k": 2})
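# The retriever runs a similarity search and returns the 2 most relevant chunks for each query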
## Build LLM Chain
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)  # Modify model_name if you have access to GPT-4
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True)
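# This base chain uses the default prompt; it is rebuilt below with the custom Korean prompt via chain_type_kwargs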
system_template="""Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer with references ("SOURCES"), use "SOURCES" in capital letters regardless of the number of sources.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format:"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
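# The prompt is handed to the chain through chain_type_kwargs; the "stuff" chain joins the retrieved chunks into the {summaries} variable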
############################
## Quick check that the chain works locally
chain_type_kwargs = {"prompt": prompt}
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)  # Modify model_name if you have access to GPT-4
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs=chain_type_kwargs,
)
query = "ν–‰λ³΅ν•œ μΈμƒμ΄λž€?"
result = chain(query)
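# The result dict contains 'answer', 'sources', and (because return_source_documents=True) 'source_documents'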
for doc in result['source_documents']:
    print('λ‚΄μš© : ' + doc.page_content[0:100].replace('\n', ' '))
    print('파일 : ' + doc.metadata['source'])
    print('νŽ˜μ΄μ§€ : ' + str(doc.metadata['page']))
##############################
## Define response method
def respond(message, chat_history):  # Handles the chatbot's response to a user message
    result = chain(message)
    bot_message = result['answer']
    for i, doc in enumerate(result['source_documents']):
        bot_message += '[' + str(i + 1) + '] ' + doc.metadata['source'] + '(' + str(doc.metadata['page']) + ') '
    chat_history.append((message, bot_message))  # Append the user message and the bot response to the chat history
    return "", chat_history  # Clear the textbox and return the updated chat history
## Build Gradio App
with gr.Blocks(theme='gstaff/sketch') as demo:  # Build the interface with gr.Blocks()
    gr.Markdown("# μ•ˆλ…•ν•˜μ„Έμš”. 세이노와 λŒ€ν™”ν•΄λ³΄μ„Έμš”.")  # Greeting: "Hello. Try talking with Sayno."
    chatbot = gr.Chatbot(label="μ±„νŒ…μ°½")  # Chat window component (label: "Chat window")
    msg = gr.Textbox(label="μž…λ ₯")  # Textbox for the user's message (label: "Input")
    clear = gr.Button("μ΄ˆκΈ°ν™”")  # Button to clear the chat (label: "Reset")
    msg.submit(respond, [msg, chatbot], [msg, chatbot])  # Submitting the textbox calls respond()
    clear.click(lambda: None, None, chatbot, queue=False)  # Clicking "Reset" clears the chat history
demo.launch(debug=True)  # Launch the app; users can type a message in the textbox and clear the history with the button