# Spaces: Running  (status banner captured from the hosting page; not part of the program)
from fastapi import FastAPI | |
from fastapi.responses import RedirectResponse | |
from langserve import add_routes | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.llms import HuggingFaceHub | |
import os | |
from langchain.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough, RunnableParallel | |
from langchain.schema import StrOutputParser | |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
# ASGI application object; the LangServe route and the uvicorn entry point
# later in this file both use it.
app = FastAPI()

# Optional override of the transformers cache directory (left disabled).
# os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

# Text-generation LLM accessed through the Hugging Face Hub wrapper.
# The API token is read from the HF_TOKEN environment variable at import
# time, so a missing variable surfaces as a KeyError right here.
hf_llm = HuggingFaceHub(
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    task="text-generation",
    model_kwargs={
        "temperature": 0.01,
        "max_new_tokens": 250,
    },
)
# Embedding model used to load and query the vector index.
# NOTE(review): presumably this must match the model the saved index was
# built with -- confirm against the indexing script.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# Load the FAISS index stored in the local "langserve_index" directory.
# NOTE(review): allow_dangerous_deserialization=True opts in to unpickling
# the stored index data; only safe when that directory comes from a trusted
# source -- confirm where the index files originate.
faiss_index = FAISS.load_local(
    folder_path="langserve_index",
    embeddings=embeddings_model,
    allow_dangerous_deserialization=True,
)

# Retriever with the vector store's default search settings.
# Alternative kept for reference (passes k=2 as a search kwarg):
# retriever = faiss_index.as_retriever(search_kwargs={"k": 2})
retriever = faiss_index.as_retriever()
# Prompt text for the RAG chain. It contains two placeholders, {context} and
# {question}, each on its own line after a label line. Written as adjacent
# string literals so the long instruction sentence stays readable; the
# resulting text is one instruction line followed by four short lines.
prompt_template = (
    "Use the provided context to answer the user's question. "
    "If the context is not relevant or the question is general, "
    "respond as you would in a normal conversation. "
    "If you don't know the answer, try to provide related information "
    "that might help the user understand the topic better. "
    "The answer you provided should be concise but comprehensive. "
    "Your goal is to help the user learn and understand how to perform "
    "the PCR-test for COVID-19 better.\n"
    "Context:\n"
    "{context}\n"
    "Question:\n"
    "{question}"
)
# Chat prompt built from the template string; it expects the "context" and
# "question" variables named in that template.
rag_prompt = ChatPromptTemplate.from_template(prompt_template)

# First stage: send the incoming input to the retriever (-> "context") and
# pass the same input through unchanged (-> "question").
entry_point_chain = RunnableParallel(
    context=retriever,
    question=RunnablePassthrough(),
)

# Full pipeline: retrieval -> prompt -> LLM -> plain-string output.
rag_chain = (
    entry_point_chain
    | rag_prompt
    | hf_llm
    | StrOutputParser()
)
# The route decorator is what registers this handler with the app; without
# it the function is defined but never reachable, and "/" does not redirect.
@app.get("/")
async def redirect_root_to_docs() -> RedirectResponse:
    """Redirect requests for the site root to the ``/docs`` page.

    Returns:
        A ``RedirectResponse`` pointing at ``/docs``.
    """
    return RedirectResponse("/docs")
# Expose the RAG chain on the app under the "/rag" path prefix.
# Swap `rag_chain` here to serve a different chain.
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file is run as a
    # script rather than imported as a module.
    import uvicorn

    # Listen on all interfaces, port 8000.
    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
    # Alternative port kept for reference (presumably the one used by the
    # hosted deployment -- confirm):
    # uvicorn.run(app, host="0.0.0.0", port=7860)