import os

from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser

app = FastAPI()

# Uncomment to point the Hugging Face model cache at a custom directory.
# os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

# Hosted LLM called through the Hugging Face Inference API; requires HF_TOKEN
# to be set in the environment.
hf_llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    task="text-generation",
    model_kwargs={"temperature": 0.01, "max_new_tokens": 250},
)

# Embedding model used to encode queries; it must match the model the index was built with.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# Load the prebuilt FAISS index from disk and expose it as a retriever.
# Note: recent langchain-community releases may additionally require
# allow_dangerous_deserialization=True when loading a pickled index.
faiss_index = FAISS.load_local("../langserve_index", embeddings_model)
retriever = faiss_index.as_retriever()
# retriever = faiss_index.as_retriever(search_kwargs={"k": 2})  # limit retrieval to the top 2 chunks
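# The index above is assumed to have been built offline with the same embedding
# model, roughly like this (loader and paths are illustrative, not part of this app):
#
#   from langchain_community.document_loaders import TextLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   docs = TextLoader("data/source.txt").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
#   FAISS.from_documents(chunks, embeddings_model).save_local("../langserve_index")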

prompt_template = """\
Use the provided context to answer the user's question. If you don't know the answer, say you don't know.

Context:
{context}

Question:
{question}"""

rag_prompt = ChatPromptTemplate.from_template(prompt_template)

# Fan the incoming question out: the retriever fills {context} while the raw
# question is passed through unchanged to fill {question}.
entry_point_chain = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
)
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()
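
# Quick local smoke test (hypothetical question; needs HF_TOKEN and the FAISS index on disk):
#   print(rag_chain.invoke("What does the provided context say about X?"))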

@app.get("/")
async def redirect_root_to_docs():
    """Redirect the bare root URL to the interactive API docs."""
    return RedirectResponse("/docs")


# Expose the RAG chain as a LangServe route; add_routes also serves invoke/batch/stream endpoints.
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
    # uvicorn.run(app, host="0.0.0.0", port=7860)  # e.g. the default port on Hugging Face Spaces
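
# Example client usage once the server is running (host/port assumed from above):
#
#   from langserve import RemoteRunnable
#   rag = RemoteRunnable("http://localhost:8000/rag")
#   print(rag.invoke("What does the provided context say about X?"))
#
# Or through the REST endpoints that add_routes exposes, e.g.:
#
#   curl -X POST http://localhost:8000/rag/invoke \
#        -H "Content-Type: application/json" \
#        -d '{"input": "What does the provided context say about X?"}'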