"""LangServe application exposing a retrieval-augmented generation chain.

At import time the module builds a Hugging Face Hub LLM client, loads a
local FAISS index, and wires both into a chain that answers a question
from retrieved context. The chain is served under ``/rag``; ``/``
redirects to the interactive API docs.
"""

import os

from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langserve import add_routes

app = FastAPI()

# os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

# HF_TOKEN must be set in the environment; the lookup raises KeyError at
# startup otherwise, so a misconfigured deployment fails before serving.
hf_llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    task="text-generation",
    model_kwargs={"temperature": 0.01, "max_new_tokens": 250},
)

embedding_model_id = 'WhereIsAI/UAE-Large-V1'
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# SECURITY: allow_dangerous_deserialization=True opts in to loading a
# serialized index from disk. Only point this at an index directory that
# was produced by a trusted process; never load an index supplied by users.
faiss_index = FAISS.load_local(
    "langserve_index",
    embeddings_model,
    allow_dangerous_deserialization=True,
)
retriever = faiss_index.as_retriever()
# retriever = faiss_index.as_retriever(search_kwargs={"k": 2})

prompt_template = """\
Given the context, your task is to answer the user's question. If the answer is unknown, admit it. If the answer is uncertain, try to provide related information that could help the user better understand the topic. Your response should be concise yet comprehensive, limited to 3 sentences. Your primary goal is to assist the user in learning and understanding how to perform a PCR test for COVID-19 more effectively. If you're unsure about the information, suggest that the user consult their instructor for further clarification.
Context: {context}
Question: {question}"""

rag_prompt = ChatPromptTemplate.from_template(prompt_template)


def _format_docs(docs) -> str:
    """Join the text of the retrieved documents into one prompt-ready string.

    Without this step the prompt's ``{context}`` slot would be filled with
    the repr of a list of document objects instead of their text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the incoming question out: one branch retrieves and flattens the
# supporting passages, the other forwards the question unchanged.
entry_point_chain = RunnableParallel(
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
)
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()


@app.get("/")
async def redirect_root_to_docs():
    """Redirect the bare root URL to the API documentation page."""
    return RedirectResponse("/docs")


# Edit this to add the chain you want to add
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn

    # Binds to every network interface so the server is reachable from
    # outside the host/container.
    uvicorn.run(app, host="0.0.0.0", port=8000)
    # uvicorn.run(app, host="0.0.0.0", port=7860)