import os
import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
# Read the Hugging Face access token from the environment (e.g., set as a Space secret)
HF_TOKEN = os.environ['MY_HF_TOKEN']
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"
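# Note: meta-llama/Llama-2-70b-chat-hf is a gated model, so the token must belong to
# an account that has been granted access to it.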
# Setup for the document loader and retriever
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
# Keep only the disease-related pages (zero-based page indices 54-71)
disease_pages = pages[54:72]
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
splits = text_splitter.split_documents(disease_pages)
modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(model_name=modelPath, model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': False})
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
# Return the 4 most similar chunks for each query
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
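# Note: the embeddings and the Chroma index are rebuilt in memory on every startup.
# If that becomes slow, Chroma.from_documents also accepts a persist_directory
# argument so the index can be written to disk and reused across restarts.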
# Setup for the language model
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)
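# The low temperature keeps answers close to the retrieved context, and
# max_new_tokens bounds the length of each generated answer.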
# Template for the question-answering prompt
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
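# When the chain runs, {context} is filled with the retrieved chunks and {question}
# with the user's query; the model is told to answer only from that context.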
def predict(message, history):
    question = message
    # Build the RetrievalQA chain; the retriever fills in {context}, so no manual
    # context string is needed here
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
    )
    # Execute the query
    result = chain({"query": question})
    # Yield the answer character by character so the UI renders it like a stream
    partial_message = ""
    for chunk in result['result']:
        partial_message += chunk
        yield partial_message
# ChatInterface matches predict's (message, history) signature and renders the
# yielded partial answers as a streaming chat response
gr.ChatInterface(fn=predict).launch()
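# Assumed runtime dependencies for this Space: gradio, langchain, langchain-community,
# chromadb, sentence-transformers, pypdf, and the text-generation client.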