pcr_rag_v2 / app /server.py
leofan's picture
Upload server.py
a7ebef3 verified
raw
history blame
1.78 kB
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
import os
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.schema import StrOutputParser
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# ASGI application object; the routes below are registered on it.
app = FastAPI()

# Optional override kept for reference (disabled):
# os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

# Text-generation LLM addressed by its Hugging Face Hub repo id. The API token
# is read from the HF_TOKEN environment variable; if that variable is unset,
# the lookup raises KeyError while this module is being imported.
hf_llm = HuggingFaceHub(
    task="text-generation",
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_kwargs=dict(temperature=0.01, max_new_tokens=250),
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)
# Embedding model handed to FAISS when the saved index is loaded.
# NOTE(review): presumably this must be the same model the index was built
# with -- confirm against the indexing script.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# The index directory is given as a relative path, so it is resolved against
# the process working directory rather than this file's location.
# NOTE(review): newer langchain_community releases may require
# allow_dangerous_deserialization=True here -- confirm for the pinned version.
faiss_index = FAISS.load_local(
    folder_path="../langserve_index",
    embeddings=embeddings_model,
)

# Retriever with default search settings.
retriever = faiss_index.as_retriever()
# Alternative kept for reference: limit retrieval to the top 2 hits.
# retriever = faiss_index.as_retriever(search_kwargs={"k": 2})
# Prompt text sent to the model: one instruction line, then the retrieved
# context and the user's question, substituted through the {context} and
# {question} placeholders.
prompt_template = (
    "Use the provided context to answer the user's question. "
    "If you don't know the answer, say you don't know.\n"
    "Context:\n"
    "{context}\n"
    "Question:\n"
    "{question}"
)

rag_prompt = ChatPromptTemplate.from_template(prompt_template)
def _format_docs(docs):
    """Join the text of the retrieved documents into one context string.

    Args:
        docs: Documents returned by the retriever; each is expected to expose
            its text as ``page_content``.

    Returns:
        The documents' text, separated by blank lines.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the incoming question out into the two prompt variables: "context" is
# the text of the retrieved documents and "question" is the input passed
# through unchanged. The documents are formatted here so the prompt receives
# their text rather than the repr of a list of Document objects (which would
# include metadata and escaped newlines).
entry_point_chain = RunnableParallel(
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
)

# Full pipeline: gather inputs -> fill the prompt -> call the LLM -> plain str.
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()
# Requests for the site root are answered with a redirect to "/docs".
# (Documented with comments rather than a docstring so the route's published
# description is left unchanged.)
@app.get("/")
async def redirect_root_to_docs():
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
# Register the RAG chain on the app under the "/rag" path. To serve a
# different chain, change the runnable passed here.
add_routes(
    app,
    rag_chain,
    path="/rag",
)
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")
    # Alternative port kept for reference:
    # uvicorn.run(app, host="0.0.0.0", port=7860)