Daoneeee commited on
Commit
01c8a41
·
1 Parent(s): fa456cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -43
app.py CHANGED
@@ -11,68 +11,67 @@ from langchain.chains import ConversationalRetrievalChain
11
  from htmlTemplates import css, bot_template, user_template
12
  from langchain.llms import HuggingFaceHub, LlamaCpp, CTransformers # For loading transformer models.
13
  from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
14
- import tempfile # 임시 파일을 생성하기 위한 라이브러리입니다.
15
  import os
16
 
 
17
  # PDF 문서로부터 텍스트를 추출하는 함수입니다.
18
  def get_pdf_text(pdf_docs):
19
- temp_dir = tempfile.TemporaryDirectory()
20
- temp_filepath = os.path.join(temp_dir.name, pdf_docs.name)
21
- with open(temp_filepath, "wb") as f:
22
- f.write(pdf_docs.getvalue())
23
- pdf_loader = PyPDFLoader(temp_filepath)
24
- pdf_doc = pdf_loader.load()
25
- return pdf_doc
26
-
27
-
28
- # 텍스트 파일을 처리하는 함수입니다.
29
- def get_text_file(docs):
30
- text = docs.getvalue().decode("utf-8")
 
 
31
  return [text]
32
 
33
 
34
- # CSV 파일을 처리하는 함수입니다.
35
  def get_csv_file(docs):
36
- import pandas as pd
37
- csv_text = docs.getvalue().decode("utf-8")
38
- csv_data = pd.read_csv(pd.compat.StringIO(csv_text))
39
- csv_columns = csv_data.columns.tolist()
40
- csv_rows = csv_data.to_dict(orient='records')
41
- csv_texts = [', '.join([f"{col}: {row[col]}" for col in csv_columns]) for row in csv_rows]
42
- return csv_texts
43
 
44
 
45
- # JSON 파일을 처리하는 함수입니다.
46
  def get_json_file(docs):
47
- import json
48
- json_text = docs.getvalue().decode("utf-8")
49
- json_data = json.loads(json_text)
50
- json_texts = [item.get('text', '') for item in json_data]
51
- return json_texts
52
 
53
 
54
  # 문서들을 처리하여 텍스트 청크로 나누는 함수입니다.
55
  def get_text_chunks(documents):
56
  text_splitter = RecursiveCharacterTextSplitter(
57
- chunk_size=1000,
58
- chunk_overlap=200,
59
- length_function=len
60
  )
61
- return text_splitter.split_documents(documents)
 
 
62
 
63
 
64
  # 텍스트 청크들로부터 벡터 스토어를 생성하는 함수입니다.
65
  def get_vectorstore(text_chunks):
 
 
66
  embeddings = OpenAIEmbeddings()
67
- vectorstore = FAISS.from_documents(text_chunks, embeddings)
68
- return vectorstore
 
69
 
70
 
71
- # 대화 체인을 생성하는 함수입니다.
72
  def get_conversation_chain(vectorstore):
73
  gpt_model_name = 'gpt-3.5-turbo'
74
- llm = ChatOpenAI(model_name=gpt_model_name)
75
- memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
 
 
 
 
76
  conversation_chain = ConversationalRetrievalChain.from_llm(
77
  llm=llm,
78
  retriever=vectorstore.as_retriever(),
@@ -83,14 +82,18 @@ def get_conversation_chain(vectorstore):
83
 
84
  # 사용자 입력을 처리하는 함수입니다.
85
  def handle_userinput(user_question):
 
86
  response = st.session_state.conversation({'question': user_question})
 
87
  st.session_state.chat_history = response['chat_history']
88
 
89
  for i, message in enumerate(st.session_state.chat_history):
90
  if i % 2 == 0:
91
- st.write(f"<div>{message.content}</div>", unsafe_allow_html=True)
 
92
  else:
93
- st.write(f"<div>{message.content}</div>", unsafe_allow_html=True)
 
94
 
95
 
96
  def main():
@@ -123,22 +126,30 @@ def main():
123
  doc_list = []
124
 
125
  for file in docs:
 
126
  if file.type == 'text/plain':
 
127
  doc_list.extend(get_text_file(file))
128
- elif file.type == 'application/pdf':
 
129
  doc_list.extend(get_pdf_text(file))
130
  elif file.type == 'text/csv':
 
131
  doc_list.extend(get_csv_file(file))
132
  elif file.type == 'application/json':
 
133
  doc_list.extend(get_json_file(file))
134
 
 
135
  text_chunks = get_text_chunks(doc_list)
 
 
136
  vectorstore = get_vectorstore(text_chunks)
137
- st.session_state.conversation = get_conversation_chain(vectorstore)
138
 
139
- if user_question and st.session_state.conversation: # 대화 체인이 있을 때만 사용자 입력 처리
140
- handle_userinput(user_question)
 
141
 
142
 
143
  if __name__ == '__main__':
144
- main()
 
11
  from htmlTemplates import css, bot_template, user_template
12
  from langchain.llms import HuggingFaceHub, LlamaCpp, CTransformers # For loading transformer models.
13
  from langchain.document_loaders import PyPDFLoader, TextLoader, JSONLoader, CSVLoader
14
+ import tempfile # 임시 파일을 생성하기 위한 라이브러리입니다.
15
  import os
16
 
17
+
18
  # PDF 문서로부터 텍스트를 추출하는 함수입니다.
19
  def get_pdf_text(pdf_docs):
20
+ temp_dir = tempfile.TemporaryDirectory() # 임시 디렉토리를 생성합니다.
21
+ temp_filepath = os.path.join(temp_dir.name, pdf_docs.name) # 임시 파일 경로를 생성합니다.
22
+ with open(temp_filepath, "wb") as f: # 임시 파일을 바이너리 쓰기 모드로 엽니다.
23
+ f.write(pdf_docs.getvalue()) # PDF 문서의 내용을 임시 파일에 씁니다.
24
+ pdf_loader = PyPDFLoader(temp_filepath) # PyPDFLoader를 사용해 PDF를 로드합니다.
25
+ pdf_doc = pdf_loader.load() # 텍스트를 추출합니다.
26
+ return pdf_doc # 추출한 텍스트를 반환합니다.
27
+
28
+
29
+ # 과제
30
+ # 아래 텍스트 추출 함수를 작성
31
+
32
+ def get_text_file(doc):
33
+ text = doc.getvalue().decode("utf-8")
34
  return [text]
35
 
36
 
 
37
  def get_csv_file(docs):
38
+ pass
 
 
 
 
 
 
39
 
40
 
 
41
  def get_json_file(docs):
42
+ pass
 
 
 
 
43
 
44
 
45
  # 문서들을 처리하여 텍스트 청크로 나누는 함수입니다.
46
  def get_text_chunks(documents):
47
  text_splitter = RecursiveCharacterTextSplitter(
48
+ chunk_size=1000, # 청크의 크기를 지정합니다.
49
+ chunk_overlap=200, # 청크 사이의 중복을 지정합니다.
50
+ length_function=len # 텍스트의 길이를 측정하는 함수를 지정합니다.
51
  )
52
+
53
+ documents = text_splitter.split_documents(documents) # 문서들을 청크로 나눕니다
54
+ return documents # 나눈 청크를 반환합니다.
55
 
56
 
57
  # 텍스트 청크들로부터 벡터 스토어를 생성하는 함수입니다.
58
  def get_vectorstore(text_chunks):
59
+ # OpenAI 임베딩 모델을 로드합니다. (Embedding models - Ada v2)
60
+
61
  embeddings = OpenAIEmbeddings()
62
+ vectorstore = FAISS.from_documents(text_chunks, embeddings) # FAISS 벡터 스토어를 생성합니다.
63
+
64
+ return vectorstore # 생성된 벡터 스토어를 반환합니다.
65
 
66
 
 
67
  def get_conversation_chain(vectorstore):
68
  gpt_model_name = 'gpt-3.5-turbo'
69
+ llm = ChatOpenAI(model_name=gpt_model_name) # gpt-3.5 모델 로드
70
+
71
+ # 대화 기록을 저장하기 위한 메모리를 생성합니다.
72
+ memory = ConversationBufferMemory(
73
+ memory_key='chat_history', return_messages=True)
74
+ # 대화 검색 체인을 생성합니다.
75
  conversation_chain = ConversationalRetrievalChain.from_llm(
76
  llm=llm,
77
  retriever=vectorstore.as_retriever(),
 
82
 
83
  # 사용자 입력을 처리하는 함수입니다.
84
  def handle_userinput(user_question):
85
+ # 대화 체인을 사용하여 사용자 질문에 대한 응답을 생성합니다.
86
  response = st.session_state.conversation({'question': user_question})
87
+ # 대화 기록을 저장합니다.
88
  st.session_state.chat_history = response['chat_history']
89
 
90
  for i, message in enumerate(st.session_state.chat_history):
91
  if i % 2 == 0:
92
+ st.write(user_template.replace(
93
+ "{{MSG}}", message.content), unsafe_allow_html=True)
94
  else:
95
+ st.write(bot_template.replace(
96
+ "{{MSG}}", message.content), unsafe_allow_html=True)
97
 
98
 
99
  def main():
 
126
  doc_list = []
127
 
128
  for file in docs:
129
+ print('file - type : ', file.type)
130
  if file.type == 'text/plain':
131
+ # file is .txt
132
  doc_list.extend(get_text_file(file))
133
+ elif file.type in ['application/octet-stream', 'application/pdf']:
134
+ # file is .pdf
135
  doc_list.extend(get_pdf_text(file))
136
  elif file.type == 'text/csv':
137
+ # file is .csv
138
  doc_list.extend(get_csv_file(file))
139
  elif file.type == 'application/json':
140
+ # file is .json
141
  doc_list.extend(get_json_file(file))
142
 
143
+ # get the text chunks
144
  text_chunks = get_text_chunks(doc_list)
145
+
146
+ # create vector store
147
  vectorstore = get_vectorstore(text_chunks)
 
148
 
149
+ # create conversation chain
150
+ st.session_state.conversation = get_conversation_chain(
151
+ vectorstore)
152
 
153
 
154
  if __name__ == '__main__':
155
+ main()