Daoneeee committed on
Commit 4ca555a · 1 Parent(s): 8bc6aeb

Update app.py

Files changed (1)
app.py +34 -74
app.py CHANGED
@@ -1,21 +1,16 @@
 import streamlit as st
 from dotenv import load_dotenv
-from PyPDF2 import PdfReader
-from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
-from langchain.vectorstores import FAISS, Chroma
-from langchain.embeddings import HuggingFaceEmbeddings  # General embeddings from HuggingFace models.
+from langchain.vectorstores import FAISS
 from langchain.chat_models import ChatOpenAI
 from langchain.memory import ConversationBufferMemory
 from langchain.chains import ConversationalRetrievalChain
-from htmlTemplates import css, bot_template, user_template
-from langchain.llms import HuggingFaceHub, LlamaCpp, CTransformers  # For loading transformer models.
 from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
-import tempfile  # Library for creating temporary files.
+import tempfile
 import os
 
 
-# Function that extracts text from a PDF document.
 def get_pdf_text(pdf_docs):
     temp_dir = tempfile.TemporaryDirectory()
     temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)
@@ -23,72 +18,50 @@ def get_pdf_text(pdf_docs):
         f.write(pdf_docs.getvalue())
     pdf_loader = PyPDFLoader(temp_filepath)
     pdf_doc = pdf_loader.load()
+    return pdf_doc
+
+
+def get_text_file(docs):
+    text_loader = TextLoader(docs.name)
+    text = text_loader.load()
+    return [text]
+
+
+def get_csv_file(docs):
+    csv_loader = CSVLoader(docs.name)
+    csv_text = csv_loader.load()
+    return csv_text.values.tolist()
+
+
+def get_json_file(docs):
+    json_loader = JSONLoader(docs.name)
+    json_text = json_loader.load()
+    return [json_text]
+
 
-    text_list = []  # List to store the text of each page
-
-    for page_num in range(len(pdf_doc)):
-        text_list.append(pdf_doc.get_page_text(page_num))
-
-    return text_list
-
-# Assignment:
-# write the text-extraction functions below
-
-def get_text_file(text_docs):
-    text_content = text_docs.getvalue().decode("utf-8")
-    return [text_content]
-
-def get_csv_file(csv_docs):
-    csv_content = csv_docs.getvalue().decode("utf-8")
-    csv_data = pd.read_csv(pd.compat.StringIO(csv_content))
-    text_list = []
-    for column in csv_data.columns:
-        text_list.extend(csv_data[column].astype(str).tolist())
-    return text_list
-
-def get_json_file(json_docs):
-    json_content = json_docs.getvalue().decode("utf-8")
-    json_data = json.loads(json_content)
-    text_list = []
-    for key, value in json_data.items():
-        if isinstance(value, str):
-            text_list.append(value)
-        elif isinstance(value, list):
-            text_list.extend(value)
-        elif isinstance(value, dict):
-            text_list.extend(value.values())
-    return text_list
-
-# Function that processes the documents and splits them into text chunks.
 def get_text_chunks(documents):
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=1000,  # Sets the chunk size.
-        chunk_overlap=200,  # Sets the overlap between chunks.
-        length_function=len  # Sets the function used to measure text length.
+        chunk_size=1000,
+        chunk_overlap=200,
+        length_function=len
     )
 
-    documents = text_splitter.split_documents(documents)  # Split the documents into chunks
-    return documents  # Return the resulting chunks.
+    documents = text_splitter.split_documents(documents)
+    return documents
 
 
-# Function that creates a vector store from the text chunks.
 def get_vectorstore(text_chunks):
-    # Load the OpenAI embedding model. (Embedding models - Ada v2)
-
     embeddings = OpenAIEmbeddings()
-    vectorstore = FAISS.from_documents(text_chunks, embeddings)  # Create the FAISS vector store.
-
-    return vectorstore  # Return the created vector store.
+    vectorstore = FAISS.from_documents(text_chunks, embeddings)
+    return vectorstore
 
 
 def get_conversation_chain(vectorstore):
     gpt_model_name = 'gpt-3.5-turbo'
-    llm = ChatOpenAI(model_name = gpt_model_name)  # Load the gpt-3.5 model
-
-    # Create memory to store the conversation history.
+    llm = ChatOpenAI(model_name=gpt_model_name)
+
     memory = ConversationBufferMemory(
         memory_key='chat_history', return_messages=True)
-    # Create the conversational retrieval chain.
     conversation_chain = ConversationalRetrievalChain.from_llm(
         llm=llm,
         retriever=vectorstore.as_retriever(),
@@ -96,11 +69,9 @@ def get_conversation_chain(vectorstore):
     )
     return conversation_chain
 
-# Function that handles user input.
+
 def handle_userinput(user_question):
-    # Use the conversation chain to generate a response to the user's question.
     response = st.session_state.conversation({'question': user_question})
-    # Store the conversation history.
     st.session_state.chat_history = response['chat_history']
 
     for i, message in enumerate(st.session_state.chat_history):
@@ -135,34 +106,23 @@ def main():
 
     st.subheader("Your documents")
     docs = st.file_uploader(
-        "Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
+        "Upload your files here and click on 'Process'", accept_multiple_files=True)
     if st.button("Process"):
         with st.spinner("Processing"):
-            # get pdf text
             doc_list = []
 
             for file in docs:
-                print('file - type : ', file.type)
                 if file.type == 'text/plain':
-                    # file is .txt
                     doc_list.extend(get_text_file(file))
                 elif file.type in ['application/octet-stream', 'application/pdf']:
-                    # file is .pdf
                    doc_list.extend(get_pdf_text(file))
                 elif file.type == 'text/csv':
-                    # file is .csv
                     doc_list.extend(get_csv_file(file))
                 elif file.type == 'application/json':
-                    # file is .json
                     doc_list.extend(get_json_file(file))
 
-            # get the text chunks
             text_chunks = get_text_chunks(doc_list)
-
-            # create vector store
             vectorstore = get_vectorstore(text_chunks)
-
-            # create conversation chain
             st.session_state.conversation = get_conversation_chain(
                 vectorstore)
 
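A note on the new loader helpers: as committed they are unlikely to run against Streamlit uploads. `UploadedFile.name` is only the original filename, not a path on disk; `CSVLoader.load()` returns a list of `Document` objects, so `csv_text.values.tolist()` would raise an `AttributeError`; `JSONLoader` also requires a `jq_schema` argument; and the extra list wrapping (`return [text]`, `return [json_text]`) hands `get_text_chunks` nested lists rather than `Document`s. Below is a minimal sketch of a working variant, reusing the temp-file pattern `get_pdf_text` already follows (the `_save_upload` helper and the `jq_schema='.'` choice are illustrative assumptions, not part of this commit):

import os
import tempfile
from langchain.document_loaders import TextLoader, CSVLoader, JSONLoader

def _save_upload(docs):
    # Streamlit's UploadedFile lives in memory, but LangChain's loaders expect
    # a real file path, so persist the bytes first (same pattern as get_pdf_text).
    temp_dir = tempfile.TemporaryDirectory()
    temp_filepath = os.path.join(temp_dir.name, docs.name)
    with open(temp_filepath, "wb") as f:
        f.write(docs.getvalue())
    return temp_dir, temp_filepath  # return temp_dir so it stays alive during load()

def get_text_file(docs):
    temp_dir, path = _save_upload(docs)
    return TextLoader(path).load()  # load() already returns a list of Documents

def get_csv_file(docs):
    temp_dir, path = _save_upload(docs)
    return CSVLoader(path).load()  # one Document per CSV row

def get_json_file(docs):
    temp_dir, path = _save_upload(docs)
    # JSONLoader needs a jq_schema; '.' (treat the whole file as one record)
    # is an assumption about the data layout. text_content=False allows
    # non-string jq results. Requires the jq package to be installed.
    return JSONLoader(path, jq_schema=".", text_content=False).load()

With each helper returning a flat list of `Document` objects, the `doc_list.extend(...)` calls in `main()` build the uniform list that `split_documents` expects.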