# task_Chatbot / app.py
# Last update: commit 2b10cea "Update app.py" (author: Daoneeee), 4.67 kB
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
import tempfile
import os
def get_pdf_text(pdf_docs):
    """Load an uploaded PDF and return the documents produced by PyPDFLoader.

    PyPDFLoader reads from a filesystem path, so the uploaded bytes are first
    copied into a file inside a temporary directory.

    Args:
        pdf_docs: Uploaded file object exposing ``name`` and ``getvalue()``
            (e.g. a Streamlit ``UploadedFile``).

    Returns:
        The result of ``PyPDFLoader.load()`` for the copied file (presumably
        one LangChain document per page -- confirm for the installed version).
    """
    # Using the directory as a context manager removes it (and the copied PDF)
    # as soon as loading finishes, instead of relying on garbage collection.
    with tempfile.TemporaryDirectory() as temp_dir:
        # basename() keeps a client-supplied name from pointing outside temp_dir.
        temp_filepath = os.path.join(temp_dir, os.path.basename(pdf_docs.name))
        with open(temp_filepath, "wb") as f:
            f.write(pdf_docs.getvalue())
        pdf_loader = PyPDFLoader(temp_filepath)
        # load() runs while the file still exists; the directory is deleted on exit.
        return pdf_loader.load()
def get_text_file(docs):
    """Decode an uploaded plain-text file.

    Args:
        docs: Uploaded file object whose ``getvalue()`` returns the raw bytes.

    Returns:
        A one-element list holding the file contents decoded as UTF-8.
    """
    raw_bytes = docs.getvalue()
    decoded = raw_bytes.decode("utf-8")
    return [decoded]
def get_csv_file(docs):
    """Convert an uploaded CSV file into one text line per data row.

    Each row is rendered as ``"col1: value1, col2: value2, ..."``, using the
    header row for the column names.

    Args:
        docs: Uploaded file object whose ``getvalue()`` returns UTF-8 bytes.

    Returns:
        A list of strings, one per data row (empty for a header-only file).
    """
    import io
    import pandas as pd
    csv_text = docs.getvalue().decode("utf-8")
    # ``pd.compat.StringIO`` was removed from pandas; the stdlib
    # ``io.StringIO`` is the supported way to feed in-memory text to read_csv.
    csv_data = pd.read_csv(io.StringIO(csv_text))
    csv_columns = csv_data.columns.tolist()
    return [
        ', '.join(f"{col}: {row[col]}" for col in csv_columns)
        for row in csv_data.to_dict(orient='records')
    ]
def get_json_file(docs):
    """Extract text entries from an uploaded JSON file.

    The top-level value may be a list of records or a single value; a
    non-list value is treated as a one-record list. Object records contribute
    their ``"text"`` field (``''`` when the key is absent). String records
    are used as-is. Any other record type (numbers, booleans, null, nested
    lists) is skipped.

    Args:
        docs: Uploaded file object whose ``getvalue()`` returns UTF-8 bytes.

    Returns:
        A list with one entry per object or string record, in input order.
    """
    import json
    json_text = docs.getvalue().decode("utf-8")
    json_data = json.loads(json_text)
    # Iterating a top-level object or string directly would yield its keys or
    # characters, which have no ``.get`` -- wrap it as a single record instead.
    records = json_data if isinstance(json_data, list) else [json_data]
    json_texts = []
    for item in records:
        if isinstance(item, dict):
            json_texts.append(item.get('text', ''))
        elif isinstance(item, str):
            json_texts.append(item)
    return json_texts
def get_text_chunks(documents, chunk_size=1000, chunk_overlap=200):
    """Split loaded documents and raw strings into overlapping text chunks.

    The loaders in this module return mixed types: ``get_pdf_text`` returns
    the LangChain documents from PyPDFLoader, while the text, CSV and JSON
    loaders return plain ``str`` values. ``split_documents`` reads
    ``page_content``/``metadata`` from each item, so plain strings are sent
    through ``create_documents`` instead.

    Args:
        documents: Iterable of LangChain documents and/or plain strings.
        chunk_size: Maximum chunk length, measured with ``len`` (characters).
        chunk_overlap: Overlap, in characters, between consecutive chunks.

    Returns:
        A list of chunk documents, ordered to follow the input items.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    chunks = []
    for doc in documents:
        if isinstance(doc, str):
            # Plain strings have no page_content; wrap them as new documents.
            chunks.extend(text_splitter.create_documents([doc]))
        else:
            chunks.extend(text_splitter.split_documents([doc]))
    return chunks
def get_vectorstore(text_chunks):
    """Embed text chunks with OpenAI embeddings and index them in FAISS.

    Args:
        text_chunks: LangChain documents to embed and index.

    Returns:
        The FAISS vector store built by ``FAISS.from_documents``.
    """
    embedding_model = OpenAIEmbeddings()
    return FAISS.from_documents(text_chunks, embedding_model)
def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over ``vectorstore``.

    The chain has three parts:
    - an LLM: ``ChatOpenAI`` with the ``gpt-3.5-turbo`` model;
    - a retriever: the store's ``as_retriever()``;
    - a memory: a ``ConversationBufferMemory`` stored under the
      ``chat_history`` key, configured with ``return_messages=True``.

    Args:
        vectorstore: A vector store exposing ``as_retriever()``.

    Returns:
        The chain built by ``ConversationalRetrievalChain.from_llm``; this
        module calls it with ``{'question': ...}``.
    """
    chat_memory = ConversationBufferMemory(
        memory_key='chat_history',
        return_messages=True,
    )
    chat_llm = ChatOpenAI(model_name='gpt-3.5-turbo')
    return ConversationalRetrievalChain.from_llm(
        llm=chat_llm,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
def handle_userinput(user_question):
    """Send a question to the active conversation chain and render the chat.

    The ``chat_history`` returned by the chain replaces the copy in session
    state, and every message in it is written to the page.

    Args:
        user_question: The question typed by the user.
    """
    import html
    response = st.session_state.conversation({'question': user_question})
    st.session_state.chat_history = response['chat_history']
    for i, message in enumerate(st.session_state.chat_history):
        # Even positions are presumably the user's turns and odd positions the
        # model's replies (question/answer pairs) -- confirm if the memory
        # configuration changes.
        role = "user" if i % 2 == 0 else "bot"
        # The text is user input or model output (which may echo uploaded
        # documents) and is rendered with unsafe_allow_html=True, so it must
        # be escaped to prevent HTML injection.
        safe_content = html.escape(message.content)
        st.write(f'<div class="{role}">{safe_content}</div>', unsafe_allow_html=True)
def _load_uploaded_file(file):
    """Run the loader matching an uploaded file's type.

    The browser-reported MIME type is tried first. If it is not a supported
    type, the file extension is used instead. Reported MIME types vary
    between systems; for example, CSV files may arrive as
    ``application/vnd.ms-excel``.

    Args:
        file: Uploaded file object exposing ``type``, ``name`` and ``getvalue()``.

    Returns:
        The loader's list of documents/strings, or ``None`` if the type is unsupported.
    """
    loaders_by_type = {
        'text/plain': get_text_file,
        'application/pdf': get_pdf_text,
        'text/csv': get_csv_file,
        'application/json': get_json_file,
    }
    types_by_extension = {
        '.txt': 'text/plain',
        '.pdf': 'application/pdf',
        '.csv': 'text/csv',
        '.json': 'application/json',
    }
    loader = loaders_by_type.get(file.type)
    if loader is None:
        extension = os.path.splitext(file.name)[1].lower()
        loader = loaders_by_type.get(types_by_extension.get(extension))
    return None if loader is None else loader(file)


def main():
    """Render the Streamlit UI.

    The sidebar uploads and processes files into a conversation chain; the
    main area takes questions and shows the chat.
    """
    st.set_page_config(page_title="Chat with multiple Files", page_icon=":books:")
    # Initialise per-session state the first time the script runs in a session.
    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None
    st.header("Chat with multiple Files :")
    user_question = st.text_input("Ask a question about your documents:")
    with st.sidebar:
        st.subheader("Your documents")
        docs = st.file_uploader(
            "Upload your files here and click on 'Process'",
            accept_multiple_files=True
        )
        if st.button("Process"):
            if not docs:
                st.warning("Please upload at least one file before processing.")
            else:
                with st.spinner("Processing"):
                    doc_list = []
                    skipped = []
                    for file in docs:
                        loaded = _load_uploaded_file(file)
                        if loaded is None:
                            skipped.append(file.name)
                        else:
                            doc_list.extend(loaded)
                    if skipped:
                        st.warning("Skipped unsupported files: " + ", ".join(skipped))
                    text_chunks = get_text_chunks(doc_list)
                    # FAISS cannot build an index from zero chunks, so stop
                    # here rather than raising inside get_vectorstore.
                    if text_chunks:
                        vectorstore = get_vectorstore(text_chunks)
                        st.session_state.conversation = get_conversation_chain(vectorstore)
                    else:
                        st.warning("No text could be extracted from the uploaded files.")
    # Only answer once a conversation chain has been built from processed files.
    if user_question and st.session_state.conversation:
        handle_userinput(user_question)


# Run the app only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()