Spaces:
Sleeping
Sleeping
File size: 2,150 Bytes
99ef4f7 fe666f0 99ef4f7 a7ebef3 edb6021 99ef4f7 e8037fe 99ef4f7 f443e49 99ef4f7 5c5dd07 99ef4f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
import os
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.schema import StrOutputParser
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
# ASGI application object; the routes defined further down are registered on it.
app = FastAPI()

# Left disabled: would point the TRANSFORMERS_CACHE environment variable at a
# custom directory.
# os.environ['TRANSFORMERS_CACHE'] = '/blabla/cache/'

# Text-generation LLM accessed through the Hugging Face Hub wrapper.
# The API token is read from the HF_TOKEN environment variable; indexing
# os.environ raises KeyError when the variable is not set.
# Generation settings: temperature 0.01 and at most 250 new tokens.
hf_llm = HuggingFaceHub(
    task="text-generation",
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_kwargs=dict(temperature=0.01, max_new_tokens=250),
    huggingfacehub_api_token=os.environ["HF_TOKEN"],
)
# Embedding model handed to FAISS when the saved index is loaded.
# NOTE(review): presumably the same model the index was built with - confirm.
embedding_model_id = "WhereIsAI/UAE-Large-V1"
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# NOTE(security): allow_dangerous_deserialization=True permits loading the
# pickled part of the saved index. Only point this at an index directory that
# you created yourself or otherwise trust.
faiss_index = FAISS.load_local(
    folder_path="langserve_index",
    embeddings=embeddings_model,
    allow_dangerous_deserialization=True,
)

# Retriever with default search settings. Disabled alternative that passes
# k=2 through search_kwargs:
# retriever = faiss_index.as_retriever(search_kwargs={"k": 2})
retriever = faiss_index.as_retriever()
# Prompt text for the RAG chain. It carries two placeholders, {context} and
# {question}, which are filled in when the prompt is formatted.
# Built from adjacent string literals; the resulting text is five lines long:
# the instruction sentence block, "Context:", "{context}", "Question:",
# "{question}" (no trailing newline).
prompt_template = (
    "Use the provided context to answer the user's question. "
    "If the context is not relevant or the question is general, respond as you would in a normal conversation. "
    "If you don't know the answer, try to provide related information that might help the user understand the topic better. "
    "The answer you provided should be concise but comprehensive. "
    "Your goal is to help the user learn and understand how to perform the PCR-test for COVID-19 better.\n"
    "Context:\n"
    "{context}\n"
    "Question:\n"
    "{question}"
)
def _format_docs(docs):
    """Join the text of the retrieved documents into one context string.

    Args:
        docs: Iterable of retrieved documents, each exposing ``page_content``.

    Returns:
        The ``page_content`` of every document, separated by blank lines.
        An empty iterable yields an empty string.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Chat prompt built from the template with the {context} and {question} slots.
rag_prompt = ChatPromptTemplate.from_template(prompt_template)

# Fan the incoming question out into the two prompt variables: one branch
# retrieves documents and flattens them to plain text, the other forwards the
# question unchanged. Without the formatting step the prompt would receive the
# repr of a list of Document objects (metadata included) instead of readable
# context text.
entry_point_chain = RunnableParallel(
    {
        "context": retriever | _format_docs,
        "question": RunnablePassthrough(),
    }
)

# Retrieval -> prompt -> LLM -> plain string output.
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()
# GET / answers with a redirect to the /docs page.
# (Plain comments are used instead of a docstring so the route's published
# OpenAPI description stays as it was.)
@app.get("/")
async def redirect_root_to_docs():
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
# Register the RAG chain on the app under the "/rag" path.
# Swap in a different runnable here to serve another chain.
add_routes(app=app, runnable=rag_chain, path="/rag")
if __name__ == "__main__":
    import uvicorn

    # The port defaults to 8000 and can be overridden with the PORT
    # environment variable (e.g. PORT=7860), so switching ports no longer
    # requires editing this file. A non-numeric PORT raises ValueError.
    listen_port = int(os.environ.get("PORT", "8000"))

    # Host 0.0.0.0 binds on all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|