Spaces:
Running
Running
File size: 2,240 Bytes
99ef4f7 fe666f0 99ef4f7 a7ebef3 edb6021 99ef4f7 d291d49 99ef4f7 e8037fe 99ef4f7 f443e49 99ef4f7 4e84bce 99ef4f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from langserve import add_routes
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFaceHub
import os
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain.schema import StrOutputParser
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
app = FastAPI()

# Fail fast with an actionable message instead of a bare KeyError when the
# Hugging Face API token is not configured (an empty value is rejected too).
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    raise RuntimeError(
        "HF_TOKEN environment variable is not set; "
        "it must hold a Hugging Face API token."
    )

# Remote Mixtral-8x7B-Instruct model queried through the Hugging Face Hub.
# A near-zero temperature keeps answers close to deterministic, and
# max_new_tokens caps the length of each completion.
# NOTE(review): recent langchain_community releases mark HuggingFaceHub as
# deprecated in favour of HuggingFaceEndpoint — confirm against pinned version.
hf_llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    huggingfacehub_api_token=_hf_token,
    task="text-generation",
    model_kwargs={"temperature": 0.01, "max_new_tokens": 250},
)
# Sentence-embedding model used to embed incoming queries; presumably the same
# model that produced the vectors stored in "langserve_index" — TODO confirm.
embedding_model_id = 'WhereIsAI/UAE-Large-V1'
embeddings_model = HuggingFaceEmbeddings(model_name=embedding_model_id)

# Load the prebuilt FAISS index from the local "langserve_index" directory.
# allow_dangerous_deserialization opts in to loading pickled data saved with
# the index, so only point this at an index directory you built and trust.
faiss_index = FAISS.load_local(
    folder_path="langserve_index",
    embeddings=embeddings_model,
    allow_dangerous_deserialization=True,
)

# Retriever with default search settings; to fetch fewer chunks, pass
# search_kwargs={"k": 2} to as_retriever().
retriever = faiss_index.as_retriever()
# RAG prompt: an instruction paragraph followed by slots for the retrieved
# context and the user's question ({context} and {question} become the
# template's input variables).
prompt_template = (
    "Given the context, your task is to answer the user's question. "
    "If the answer is unknown, admit it. "
    "If the answer is uncertain, try to provide related information that "
    "could help the user better understand the topic. "
    "Your response should be concise yet comprehensive, limited to 3 sentences. "
    "Your primary goal is to assist the user in learning and understanding "
    "how to perform a PCR test for COVID-19 more effectively. "
    "If you're unsure about the information, suggest that the user consult "
    "their instructor for further clarification."
    "\n"
    "Context:\n"
    "{context}\n"
    "Question:\n"
    "{question}"
)
rag_prompt = ChatPromptTemplate.from_template(prompt_template)


def _format_docs(docs):
    """Join retrieved documents into one plain-text context string.

    Without this step the prompt's {context} slot would receive the repr of
    the retriever's list of document objects (including metadata) rather
    than the passage text itself.

    Args:
        docs: iterable of objects exposing a ``page_content`` attribute, as
            produced by the retriever.

    Returns:
        The ``page_content`` values separated by blank lines; ``""`` when
        no documents were retrieved.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Fan the incoming question out in parallel: the retriever fetches matching
# chunks (flattened to text by _format_docs) while RunnablePassthrough
# forwards the raw question unchanged.
entry_point_chain = RunnableParallel(
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
)
# question -> {context, question} -> prompt -> LLM completion -> plain string.
rag_chain = entry_point_chain | rag_prompt | hf_llm | StrOutputParser()
@app.get("/")
async def redirect_root_to_docs():
    """Send requests for the bare root URL to the interactive API docs page."""
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
# Expose rag_chain through LangServe under the /rag path prefix; edit or
# repeat this call to serve other chains.
add_routes(app, rag_chain, path="/rag")

if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 8000 for local runs and
    # can be overridden with the PORT environment variable (e.g. PORT=7860).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))
|