import os

import gradio as gr
from langchain_community.llms import HuggingFaceTextGenInference
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

# Hugging Face token (assumes MY_HF_TOKEN is set in the environment) and TGI endpoint
HF_TOKEN = os.environ["MY_HF_TOKEN"]
ENDPOINT_URL = "https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf"

# Load the PDF and keep only the pages covering the diseases of interest
loader = PyPDFLoader("2023_법정감염병진단_신고기준.pdf")
pages = loader.load_and_split()
disease_pages = pages[54:72]

# Split the selected pages into overlapping chunks for retrieval
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=200, add_start_index=True
)
splits = text_splitter.split_documents(disease_pages)

# Multilingual sentence-embedding model, run on CPU
modelPath = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embeddings = HuggingFaceEmbeddings(
    model_name=modelPath,
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": False},
)

# Build the vector store and retriever
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

# Language model served through a Text Generation Inference endpoint
llm = HuggingFaceTextGenInference(
    inference_server_url=ENDPOINT_URL,
    max_new_tokens=1024,
    top_k=50,
    temperature=0.1,
    repetition_penalty=1.03,
    server_kwargs={
        "headers": {
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        }
    },
)

# Prompt template for question answering; {context} is filled in by the retriever
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Keep the answer as concise as possible.
{context}
Question: {question}
Helpful Answer:"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)


def predict(message):
    question = message

    # Create a RetrievalQA chain that stuffs the retrieved chunks into the prompt
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    )

    # Execute the query
    result = chain({"query": question})

    # Yield the answer incrementally so Gradio can render it as a stream
    partial_message = ""
    for chunk in result["result"]:
        partial_message += chunk
        yield partial_message


iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(placeholder="Chat with me!", label="Your Message"),
    outputs=gr.Text(label="Response"),
    live=False,
    title="Infectious-Disease-Diagnosis-Chatbot",
    description="This is a Gradio UI demo that consumes a TGI endpoint serving the LLaMA 2 70B-Chat model.",
    examples=[["발열과 구토 증상이 있는데, 어떤 감염병이야?"]],
    theme="default",  # Choose a theme that fits your UI preference
)

iface.launch()