File size: 2,150 Bytes
99ef4f7
 
 
fe666f0
 
99ef4f7
 
 
 
 
 
 
 
a7ebef3
edb6021
99ef4f7
 
 
 
e8037fe
99ef4f7
 
 
 
 
f443e49
99ef4f7
 
 
 
5c5dd07
99ef4f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
import os
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.schema import StrOutputParser
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# ASGI application object served by uvicorn; LangServe routes are attached to it below.
app = FastAPI()

# Optional: to relocate the Hugging Face model cache, set TRANSFORMERS_CACHE
# before the models below are created, e.g.
#     os.environ["TRANSFORMERS_CACHE"] = "/path/to/cache/"

# Mixtral instruct model queried remotely through the Hugging Face Hub.
# HF_TOKEN is read with [] so a missing token raises KeyError at import time.
# Sampling: temperature 0.01 (almost deterministic), at most 250 new tokens.
hf_llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    model_kwargs=dict(temperature=0.01, max_new_tokens=250),
)

# Sentence-embedding model used to embed incoming queries. Presumably this is
# the same model the FAISS index was built with (vectors from a different
# model would not be comparable) -- confirm when rebuilding the index.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# Load the prebuilt index from the "langserve_index" directory (resolved
# against the working directory). FAISS.load_local unpickles the stored
# docstore, which is why the dangerous-deserialization opt-in is required:
# only point this at an index directory you produced yourself and trust.
faiss_index = FAISS.load_local(
    "langserve_index",
    embeddings_model,
    allow_dangerous_deserialization=True,
)

# Retriever with default search settings; pass search_kwargs={"k": 2} to
# as_retriever() to cap how many documents each query returns.
retriever = faiss_index.as_retriever()

# Instructions + retrieved context + user question. The wording is sent to the
# model verbatim, so edits here directly change model behaviour.
prompt_template = """\
Use the provided context to answer the user's question. If the context is not relevant or the question is general, respond as you would in a normal conversation. If you don't know the answer, try to provide related information that might help the user understand the topic better. The answer you provided should be concise but comprehensive. Your goal is to help the user learn and understand how to perform the PCR-test for COVID-19 better.

Context:
{context}

Question:
{question}"""

# NOTE(review): hf_llm is a completion-style LLM, so LangChain flattens this
# chat prompt into plain text (with a role prefix) rather than applying
# Mixtral-Instruct's own chat template -- confirm this is acceptable.
rag_prompt = ChatPromptTemplate.from_template(prompt_template)


def _format_docs(docs):
    """Join the text of retrieved documents into a single context string.

    The retriever yields a list of Document objects; formatting that list
    directly into ``{context}`` would insert its ``str()`` (object reprs plus
    metadata) instead of the passages themselves.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the incoming question out to both prompt variables: the retriever (then
# _format_docs) fills {context}, RunnablePassthrough forwards the raw question
# into {question}.
entry_point_chain = RunnableParallel(
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
)
# question -> filled prompt -> LLM -> plain-string answer.
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()

@app.get("/")
async def redirect_root_to_docs():
    """Send requests for the bare root URL to the FastAPI docs page at /docs."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)


# Publish the RAG chain under the /rag prefix; LangServe generates the
# invoke/batch/stream/playground endpoints beneath it.
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 8000 and can be overridden
    # through the PORT environment variable (e.g. PORT=7860 for Hugging Face
    # Spaces) instead of editing the code.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))