Spaces:
Sleeping
Sleeping
import os | |
import json | |
from langchain_groq import ChatGroq | |
from langchain_core.prompts import PromptTemplate | |
from qdrant_client import QdrantClient | |
from langchain.chains import RetrievalQA | |
from langchain_community.vectorstores import Qdrant | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.encoders import jsonable_encoder | |
from fastapi.templating import Jinja2Templates | |
from fastapi import FastAPI, Request, Form, Response | |
from langchain_community.embeddings import SentenceTransformerEmbeddings | |
# Ask Hugging Face transformers to stay on the CPU.
# NOTE(review): confirm transformers actually reads TRANSFORMERS_FORCE_CPU, and
# that setting it after the imports above is early enough to have an effect.
os.environ.update(TRANSFORMERS_FORCE_CPU="true")

# Jinja2 loader for the HTML front end (index.html is rendered by read_root),
# and the ASGI application the route handlers are attached to.
templates = Jinja2Templates(directory="templates")
app = FastAPI()
# Generation / runtime settings.
# NOTE(review): `config` is not referenced anywhere in the visible code (the
# ChatGroq client below is built without it); the keys resemble a local-model
# runner config (e.g. CTransformers) — presumably a leftover, TODO confirm.
config = {
    'max_new_tokens': 1024,
    'context_length': 2048,
    'repetition_penalty': 1.1,
    'temperature': 0.1,
    'top_k': 50,
    'top_p': 0.9,
    'stream': True,
    # os.cpu_count() returns None when the count cannot be determined, and
    # half of a single core rounds down to 0; always request at least 1 thread.
    'threads': max(1, (os.cpu_count() or 1) // 2),
}
# Groq-hosted chat model that generates the answers. The key comes from the
# API_KEY environment variable and is None when that variable is unset.
# NOTE(review): check that this model id is still served by Groq.
_GROQ_MODEL_NAME = "mixtral-8x7b-32768"
api_key = os.getenv("API_KEY")
llm = ChatGroq(api_key=api_key, model=_GROQ_MODEL_NAME)
print("LLM Initialized....")
# RAG prompt: {context} is filled with the retrieved text and {question} with
# the user's query. The model is told to admit when it does not know.
prompt_template = (
    "Use the following pieces of information to answer the user's question.\n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Only return the helpful answer below and nothing else.\n"
    "Helpful answer:\n"
)
# Retrieval side of the RAG pipeline: a sentence-transformer embedding model
# plus a Qdrant collection reached over HTTP (prefer_grpc=False), queried for
# the top-k most similar chunks.
_EMBEDDING_MODEL_NAME = "BAAI/bge-large-en"
_COLLECTION_NAME = "patent_database"
_TOP_K = 3

prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])

embeddings = SentenceTransformerEmbeddings(model_name=_EMBEDDING_MODEL_NAME)
# NOTE(review): `url` is None when INSTANCE_URL is unset — confirm which server
# QdrantClient connects to in that case.
url = os.getenv("INSTANCE_URL")
client = QdrantClient(url=url, prefer_grpc=False)
db = Qdrant(client=client, embeddings=embeddings, collection_name=_COLLECTION_NAME)
retriever = db.as_retriever(search_kwargs={"k": _TOP_K})
# The handler was defined but never attached to `app`, so no route served it.
# NOTE(review): if routes are registered elsewhere (e.g. app.add_api_route),
# drop this decorator to avoid a duplicate route.
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the front-end page ``templates/index.html``.

    The request object is passed into the template context, as Jinja2Templates
    requires for this call form.
    """
    return templates.TemplateResponse("index.html", {"request": request})
# The handler was defined but never attached to `app`, so no route served it.
# NOTE(review): path assumed to match the form/fetch target in
# templates/index.html — confirm; drop this decorator if routes are registered
# elsewhere.
@app.post("/get_response")
async def get_response(query: str = Form(...)):
    """Answer ``query`` with retrieval-augmented generation.

    Builds a "stuff" RetrievalQA chain over the module-level ``llm``,
    ``retriever`` and ``prompt``, runs it, and returns a JSON object with:

    - ``answer``: the chain's ``result`` text.
    - ``source_document``: page content of the first retrieved document, or
      ``""`` when nothing was retrieved.
    - ``doc``: that document's ``metadata["source"]``, or ``""`` when there is
      no document or no ``source`` key.
    """
    from fastapi.concurrency import run_in_threadpool

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )
    # The chain does blocking work (embedding the query, the Qdrant search,
    # the Groq HTTP call). Run it in a worker thread so this async handler
    # does not stall the event loop for every other request.
    response = await run_in_threadpool(qa.invoke, {"query": query})
    print(response)

    answer = response['result']
    sources = response.get('source_documents') or []
    if sources:
        top_document = sources[0]
        source_document = top_document.page_content
        doc = top_document.metadata.get('source', "")
    else:
        # Empty collection / no hits: previously an IndexError (HTTP 500).
        source_document = ""
        doc = ""

    payload = json.dumps({"answer": answer, "source_document": source_document, "doc": doc})
    # Send the JSON with an explicit content type. The body was already a JSON
    # string, so it did not need to go through jsonable_encoder.
    return Response(content=payload, media_type="application/json")