import os
import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
# Read the Hugging Face API token from the environment (e.g., a Space secret named MY_HF_TOKEN)
HF_TOKEN = os.environ['MY_HF_TOKEN']
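# Serverless Inference API endpoint for the Llama-2-70B chat model used as the generator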
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"
# Setup for the document loader and retriever
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
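# Keep only the page range used for retrieval (assumed to cover the relevant disease criteria)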
disease_pages = pages[54:72]
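# Split the selected pages into overlapping ~1000-character chunks for retrieval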
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
splits = text_splitter.split_documents(disease_pages)
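# Multilingual sentence-embedding model (the source PDF is in Korean), run on CPU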
modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(model_name=modelPath, model_kwargs={'device':'cpu'}, encode_kwargs={'normalize_embeddings': False})
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
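# Return the 4 most similar chunks for each query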
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
# Setup for the language model
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)
# Prompt template for question answering over the retrieved context
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
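# Chat handler: answers each user message with retrieval-augmented generation over the PDF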
def predict(message, history):
    question = message
    # The retriever supplies the {context} for the prompt, so no manual context is needed
    # Create a RetrievalQA chain that stuffs the retrieved chunks into the prompt
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    )
    # Execute the query
    result = chain({"query": question})
    # Yield the answer character by character to simulate streaming in the chat UI
    partial_message = ""
    for chunk in result['result']:
        partial_message += chunk
        yield partial_message
# Chat UI: ChatInterface passes (message, history) to predict and renders the streamed answer
gr.ChatInterface(fn=predict).launch()