k-historybot-2 / app.py
lugiiing's picture
Update app.py
eef9ead verified
import os
import openai
import streamlit as st
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.document_loaders import DirectoryLoader
# ํƒ€์ดํ‹€ ์ ์šฉ, # ํŠน์ˆ˜ ์ด๋ชจํ‹ฐ์ฝ˜ ์‚ฝ์ž… ์˜ˆ์‹œ
# emoji: https://streamlit-emoji-shortcodes-streamlit-app-gwckff.streamlit.app/
st.title(':robot_face:ํ•œ๊ตญ์‚ฌ๋ด‡? ๋„ˆ ์ •๋ง ๋˜‘๋˜‘ํ•˜๋‹ˆ?')
# ์บก์…˜ ์ ์šฉ
st.caption('ํ•œ๊ตญ์‚ฌ ๊ต๊ณผ์„œ๋ฅผ ์ฝ๊ณ  ์ดํ•ดํ•˜์—ฌ ๋‹ต๋ณ€์„ ์ œ๊ณตํ•˜๋Š” ๋กœ๋ด‡์ž…๋‹ˆ๋‹ค. ๋กœ๋ด‡์ด ๊ต๊ณผ์„œ๋ฅผ ์ œ๋Œ€๋กœ ์ดํ•ดํ–ˆ๋Š”์ง€ ํ™•์ธํ•ด๋ณด์„ธ์š”!')
# ๋งˆํฌ๋‹ค์šด ๋ถ€๊ฐ€์„ค๋ช…
st.markdown('###### ์งˆ๋ฌธ, ์š”์•ฝ ๋“ฑ ๋‹ค์–‘ํ•œ ๋ถ€ํƒ์„ ํ•ด ๋ณด์„ธ์š”! ๊ต๊ณผ์„œ์˜ ์–ด๋–ค ๋ถ€๋ถ„์„ ์ฐธ๊ณ ํ–ˆ๋Š”์ง€ ๋น„๊ตํ•˜๋ฉฐ ํ•œ๊ตญ์‚ฌ๋ด‡:robot_face:์ด ๊ต๊ณผ์„œ๋ฅผ ์ œ๋Œ€๋กœ ์ดํ•ดํ–ˆ๋Š”์ง€ ํ™•์ธํ•ด๋ณด์„ธ์š”!:sparkles:')
api_key = st.text_input(label='OpenAI API ํ‚ค๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”', type='password')
if api_key:
# OpenAI API๋ฅผ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•œ ์ฒ˜๋ฆฌ ๊ณผ์ •์„ ํ•จ์ˆ˜๋กœ ์ •์˜
def initialize_openai_processing(api_key):
#client = OpenAI()
#OpenAI.api_key = api_key
loader = DirectoryLoader('./khistory_data', glob="*.txt", loader_cls=TextLoader)
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
persist_directory = 'db'
#embedding = OpenAIEmbeddings()
embedding = OpenAIEmbeddings(api_key=api_key) # API ํ‚ค๋ฅผ ์ƒ์„ฑ์ž์— ์ „๋‹ฌ
vectordb = Chroma.from_documents(
documents=texts,
embedding=embedding,
persist_directory=persist_directory)
vectordb.persist()
vectordb = None
vectordb = Chroma(
persist_directory=persist_directory,
embedding_function=embedding)
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
#llm=OpenAI(),
llm=OpenAI(api_key=api_key),
chain_type="stuff",
retriever=retriever,
return_source_documents=True)
return embedding, vectordb, qa_chain
# ํ•จ์ˆ˜ ํ˜ธ์ถœ๋กœ ์ดˆ๊ธฐํ™” ๊ณผ์ • ์ˆ˜ํ–‰
embedding, vectordb, qa_chain = initialize_openai_processing(api_key)
# ํ…์ŠคํŠธ ์ž…๋ ฅ
query = st.text_input(
label='ํ•œ๊ตญ์‚ฌ๋ด‡์—๊ฒŒ ์งˆ๋ฌธํ•ด๋ณด์„ธ์š”!',
placeholder='์˜ˆ์‹œ: ๋™ํ•™ ๋†๋ฏผ ์šด๋™์€ ์™œ ์ผ์–ด๋‚ฌ๋‚˜์š”?'
)
# ๋ฒ„ํŠผ ํด๋ฆญ
button = st.button(':robot_face:ํ•œ๊ตญ์‚ฌ๋ด‡์—๊ฒŒ ๋ฌผ์–ด๋ณด๊ธฐ')
if button:
llm_response = qa_chain(query)
#process_llm_response(llm_response)
result = llm_response.get('result')
source_documents1 = llm_response.get('source_documents')[0]
source_documents2 = llm_response.get('source_documents')[1]
source_documents3 = llm_response.get('source_documents')[2]
st.write('๊ฒฐ๊ณผ: ', f'{result}')
st.write('๊ต๊ณผ์„œ ๋‚ด์šฉ 1: 'f'{source_documents1}')
st.write('๊ต๊ณผ์„œ ๋‚ด์šฉ 2: 'f'{source_documents2}')
st.write('๊ต๊ณผ์„œ ๋‚ด์šฉ 3: 'f'{source_documents3}')