File size: 2,240 Bytes
99ef4f7
 
 
fe666f0
 
99ef4f7
 
 
 
 
 
 
 
a7ebef3
edb6021
99ef4f7
d291d49
99ef4f7
 
e8037fe
99ef4f7
 
 
 
 
f443e49
99ef4f7
 
 
 
4e84bce
99ef4f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
import os
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.schema import StrOutputParser
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# FastAPI application onto which the LangServe routes are mounted further down.
app = FastAPI()

# Optional: to keep downloaded Hugging Face weights somewhere specific, set
# os.environ["TRANSFORMERS_CACHE"] to a directory path here, before the models
# below are created.

# Remote generation model accessed through HuggingFaceHub. The API token is read
# from the HF_TOKEN environment variable (a missing variable raises KeyError at
# import time). Temperature 0.01 makes sampling nearly greedy; each answer is
# capped at 250 newly generated tokens.
hf_llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    model_kwargs=dict(temperature=0.01, max_new_tokens=250),
)

# Embedding model used to vectorise queries for the FAISS lookup.
# NOTE(review): presumably this must be the same model that produced the saved
# "langserve_index" vectors — confirm against the indexing script.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# NOTE(review): allow_dangerous_deserialization=True lets FAISS.load_local
# unpickle the stored docstore; only point this at an index directory you built
# yourself or otherwise trust.
faiss_index = FAISS.load_local(
    "langserve_index",
    embeddings_model,
    allow_dangerous_deserialization=True,
)

# Retriever with the default search settings; pass search_kwargs={"k": N} to
# as_retriever() to change how many chunks are returned per question.
retriever = faiss_index.as_retriever()

# Instruction prompt for the PCR-testing tutor. {context} receives the retrieved
# passages as plain text and {question} the user's question unchanged.
prompt_template = """\
Given the context, your task is to answer the user's question. If the answer is unknown, admit it. If the answer is uncertain, try to provide related information that could help the user better understand the topic. Your response should be concise yet comprehensive, limited to 3 sentences. Your primary goal is to assist the user in learning and understanding how to perform a PCR test for COVID-19 more effectively. If you're unsure about the information, suggest that the user consult their instructor for further clarification.

Context:
{context}

Question:
{question}"""

rag_prompt = ChatPromptTemplate.from_template(prompt_template)


def _format_docs(docs):
    """Join the text of retrieved documents into one context string.

    Args:
        docs: Iterable of documents produced by the retriever; each is expected
            to expose a ``page_content`` string.

    Returns:
        The documents' ``page_content`` values separated by blank lines.

    Without this step the prompt's ``{context}`` slot would receive the repr of
    the retriever's list of Document objects, leaking Python object syntax and
    metadata into the model input instead of just the passage text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the raw question out to two branches: the retriever (whose documents are
# flattened to text) and an unchanged passthrough for the question itself.
entry_point_chain = RunnableParallel(
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
)
# Full pipeline: retrieve -> fill prompt -> generate -> return the answer as str.
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()

@app.get("/")
async def redirect_root_to_docs():
    """Redirect requests for the bare root URL to the interactive API docs page."""
    docs_url = "/docs"
    return RedirectResponse(docs_url)


# Mount the RAG chain's LangServe endpoints under the /rag path prefix.
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 8000 and can be overridden
    # through the PORT environment variable (e.g. PORT=7860, the alternative
    # port this script was previously switched to by hand).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))