import os
import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# The Hugging Face API token is read from an environment variable (set MY_HF_TOKEN before running)
HF_TOKEN = os.environ['MY_HF_TOKEN']
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"

# Setup for the document loader and retriever
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
disease_pages = pages[54:72]  # keep only the pages covering the individual disease sections (0-indexed slice)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, add_start_index=True)
splits = text_splitter.split_documents(disease_pages)

modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(
    model_name=modelPath,
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
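
# Optional sanity check (a sketch, not part of the original app; the query string is
# only illustrative): uncomment to confirm the retriever returns relevant chunks
# before wiring it into the QA chain.
# sample_docs = retriever.get_relevant_documents("cholera reporting criteria")
# for doc in sample_docs:
#     print(doc.metadata.get("page"), doc.page_content[:80])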

# Setup for the language model
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)
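
# Optional smoke test (sketch with an illustrative prompt): a direct call verifies the
# endpoint URL and token before the chain is built. Uncomment to try it.
# print(llm.invoke("Say hello in one short sentence."))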

# Template for the question-answering
template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
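
# For reference (illustrative values only): the "stuff" chain fills this template by
# concatenating the retrieved chunks into {context} and the user message into {question},
# roughly equivalent to:
# example_prompt = QA_CHAIN_PROMPT.format(
#     context="...retrieved page excerpts...",
#     question="What are the reporting criteria for cholera?",
# )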

def predict(message, history):
    question = message
    
    # Build the RetrievalQA chain (recreated on every call; it could equally be built once at module level)
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
    )

    # Execute the query; the chain returns the answer plus the retrieved source documents
    result = chain.invoke({"query": question})

    # Pseudo-streaming: the answer is already complete, so yield it back character by character
    partial_message = ""
    for chunk in result['result']:
        partial_message += chunk
        yield partial_message
        
# Chat UI: gr.ChatInterface expects fn(message, history) and renders the values
# yielded by predict incrementally
gr.ChatInterface(fn=predict).launch()