Spaces:
Runtime error
Runtime error
import logging | |
import os | |
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.responses import HTMLResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from langchain.vectorstores import Chroma | |
from langchain.llms import OpenAI | |
from langchain.vectorstores.cassandra import Cassandra | |
from langchain.indexes.vectorstore import VectorStoreIndexWrapper | |
from langchain.chains import RetrievalQA | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.vectorstores.base import VectorStoreRetriever | |
from langchain.text_splitter import CharacterTextSplitter | |
from azure.core.credentials import AzureKeyCredential | |
from azure.ai.inference import EmbeddingsClient | |
import cassio | |
from pydantic import BaseModel | |
import shutil | |
from config import settings | |
# --- Web application --------------------------------------------------------
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): a wildcard origin combined with credentials is unusually
# permissive -- confirm this is intended outside of a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Logging ----------------------------------------------------------------
# Only ERROR and above is emitted; the route handlers report failures
# through this module-level logger.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# --- Credentials and endpoints, all read from `config.settings` -------------
# NOTE(review): HUGGINGFACE_API_KEY and OPENAI_API_KEY are not referenced in
# the visible part of this file.
HUGGINGFACE_API_KEY = settings.huggingface_key
ASTRA_DB_APPLICATION_TOKEN = settings.astra_db_application_token
ASTRA_DB_ID = settings.astra_db_id
OPENAI_API_KEY = settings.openai_api_key
GITHUB_TOKEN = settings.github_token
AZURE_OPENAI_ENDPOINT = settings.azure_openai_endpoint
AZURE_OPENAI_MODELNAME = settings.azure_openai_modelname
AZURE_OPENAI_EMBEDMODELNAME = settings.azure_openai_embedmodelname

# NOTE(review): this constant is not referenced in the visible code; the
# upload route writes to the relative "uploads" directory instead.
UPLOAD_FOLDER = '/uploads'

# --- Mutable module state ---------------------------------------------------
# Populated lazily: `init_llm` fills `llm` / `embedding`, `process_document`
# builds the chain, and `process_prompt` appends to the history.
conversation_retrieval_chain = None
chat_history = []
llm = None
embedding = None

# Connect to Astra DB at import time so the Cassandra vector store can be
# created later with `session=None` / `keyspace=None`.
cassio.init(token=ASTRA_DB_APPLICATION_TOKEN, database_id=ASTRA_DB_ID)
# Request body of the chat endpoint.  The single field is named in camelCase
# because the page's JavaScript posts `JSON.stringify({ userMessage })`.
class MessageRequest(BaseModel):
    # Text the user typed into the chat box.
    userMessage: str
class AzureOpenAIEmbeddings:
    """Adapter exposing an embeddings client as ``embed_query`` / ``embed_documents``.

    The wrapped client must provide ``embed(input=..., model=...)`` and return
    an object whose ``data`` items each carry an ``embedding`` attribute.
    """

    def __init__(self, client, model_name=None):
        """Remember the client and the embedding model to request.

        Args:
            client: Object offering ``embed(input=..., model=...)``.
            model_name: Embedding model identifier.  When omitted, the
                module-level ``AZURE_OPENAI_EMBEDMODELNAME`` is used, which is
                what the original single-argument constructor always did.
        """
        self.client = client
        self.model_name = (
            AZURE_OPENAI_EMBEDMODELNAME if model_name is None else model_name
        )

    def _embed(self, texts):
        """Send one embedding request and return the vectors in response order."""
        response = self.client.embed(input=texts, model=self.model_name)
        return [item.embedding for item in response.data]

    def embed_query(self, text: str):
        """Embed a single query string and return its vector."""
        return self._embed([text])[0]

    def embed_documents(self, texts: list):
        """Embed a list of documents and return one vector per document.

        An empty input returns ``[]`` without contacting the service, so no
        request with an empty ``input`` is ever sent.
        """
        texts = list(texts)
        if not texts:
            return []
        return self._embed(texts)
def init_llm():
    """Create the shared LLM and embeddings clients if they do not exist yet.

    Both clients talk to ``AZURE_OPENAI_ENDPOINT`` and authenticate with
    ``GITHUB_TOKEN``.  The results are stored in the module-level ``llm`` and
    ``embedding`` globals.

    The function is idempotent: it is called on every document upload and on
    every prompt, so already-built clients are reused instead of being
    replaced (and the previous ones abandoned without being closed).
    """
    global llm, embedding

    if llm is None:
        llm = OpenAI(
            base_url=AZURE_OPENAI_ENDPOINT,
            api_key=GITHUB_TOKEN,
            model=AZURE_OPENAI_MODELNAME,
        )

    if embedding is None:
        embedding = EmbeddingsClient(
            endpoint=AZURE_OPENAI_ENDPOINT,
            credential=AzureKeyCredential(GITHUB_TOKEN),
            model=AZURE_OPENAI_EMBEDMODELNAME,
        )
def process_document(document_path, max_chunks=500):
    """Index a PDF in the vector store and build the question-answering chain.

    The PDF is loaded page by page, split into overlapping chunks, embedded
    and written to the ``qa_mini_demo`` table.  A retriever over that table is
    then wired into a ``RetrievalQA`` chain, which is published through the
    module-level ``conversation_retrieval_chain``.

    Args:
        document_path: Path of the PDF file to index.
        max_chunks: Upper bound on the number of chunks written to the store
            (default 500, the previously hard-coded limit).  Chunks beyond the
            limit are dropped and a warning is logged; ``None`` disables the
            cap.

    Raises:
        ValueError: If no text could be extracted from the document.
    """
    global conversation_retrieval_chain
    init_llm()

    pages = PyPDFLoader(document_path).load()
    # Separate pages with a paragraph break.  Joining with "" fused the last
    # word of a page with the first word of the next one and gave the
    # splitter no break to cut on at page boundaries.
    raw_text = "\n\n".join(page.page_content for page in pages)

    splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=200)
    chunks = splitter.split_text(raw_text)
    if not chunks:
        # Fail early with a clear message instead of sending an empty
        # embedding request and reporting success on a document that
        # contributed nothing (e.g. a scanned PDF without a text layer).
        raise ValueError(f"No text could be extracted from {document_path!r}.")

    if max_chunks is not None and len(chunks) > max_chunks:
        logger.warning(
            "Document produced %d chunks; only the first %d are indexed.",
            len(chunks),
            max_chunks,
        )
        chunks = chunks[:max_chunks]

    # session/keyspace are left as None -- presumably resolved from the
    # connection registered by the module-level cassio.init(...) call.
    astra_vector_store = Cassandra(
        embedding=AzureOpenAIEmbeddings(embedding),
        table_name="qa_mini_demo",
        session=None,
        keyspace=None,
    )
    astra_vector_store.add_texts(chunks)

    # Maximal-marginal-relevance search returning a single chunk per question.
    retriever = VectorStoreRetriever(
        vectorstore=astra_vector_store,
        search_type="mmr",
        search_kwargs={'k': 1, 'lambda_mult': 0.25},
    )
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        input_key="question",
    )
def process_prompt(prompt):
    """Answer ``prompt`` with the retrieval chain and record the exchange.

    Args:
        prompt: The user's question.

    Returns:
        The chain's answer (the ``"result"`` entry of its output).

    Raises:
        RuntimeError: If no document has been processed yet, i.e. the
            retrieval chain has not been built.
    """
    global chat_history
    global conversation_retrieval_chain

    if conversation_retrieval_chain is None:
        # Previously this surfaced as an opaque "'NoneType' object is not
        # callable" TypeError.
        raise RuntimeError(
            "No document has been processed yet; upload a PDF before asking questions."
        )
    init_llm()

    # The instruction is separated from the question by a space; it used to
    # be glued directly onto the last character of the prompt.
    question = (
        f"{prompt} "
        "you should only give answer to the question ,do not give any other information"
    )
    output = conversation_retrieval_chain(
        {"question": question, "chat_history": chat_history}
    )
    answer = output["result"]
    # NOTE(review): the history grows without bound for the life of the
    # process and is never reset when a new document is uploaded.
    chat_history.append((prompt, answer))
    return answer
# Define the route for the index page
@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the single-page UI: a PDF upload form and a small chat box.

    The upload form posts to ``/process-document``; the chat form sends
    ``{"userMessage": ...}`` as JSON to ``/process-message`` and displays the
    ``botResponse`` field of the reply.  Error replies from FastAPI carry
    their message under ``detail``, so that field is shown as a fallback.

    Returns:
        The HTML document as a string (rendered through ``HTMLResponse``).
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <title>File Upload</title>
</head>
<body>
    <h2>Upload a PDF Document</h2>
    <form action="/process-document" method="post" enctype="multipart/form-data">
        <input type="file" name="file" required>
        <button type="submit">Upload</button>
    </form>
    <h2>Chat with the Bot</h2>
    <form id="chat-form">
        <input type="text" id="userMessage" placeholder="Type your message here..." required>
        <button type="submit">Send
        </button>
    </form>
    <div id="chat-response"></div>
    <script>
        document.getElementById("chat-form").onsubmit = async (e) => {
            e.preventDefault();
            const userMessage = document.getElementById("userMessage").value;
            const response = await fetch("/process-message", {
                method: "POST",
                headers: {
                    "Content-Type": "application/json",
                },
                body: JSON.stringify({ userMessage }),
            });
            const data = await response.json();
            document.getElementById("chat-response").innerText = data.botResponse || data.detail || data.error;
            document.getElementById("userMessage").value = ""; // Clear input
        };
    </script>
</body>
</html>
"""
# Define the route for processing messages
@app.post("/process-message")
async def process_message_route(message: MessageRequest):
    """Answer one chat message.

    Args:
        message: Request body carrying the user's text in ``userMessage``.

    Returns:
        ``JSONResponse`` of the form ``{"botResponse": <first line of the answer>}``.

    Raises:
        HTTPException: 400 when the message is empty or whitespace-only;
            500 when generating the answer fails.
    """
    user_message = message.userMessage  # Extract the user's message from the request
    # Validate outside the try block: raised inside it, this 400 was caught
    # by the generic handler below and returned to the client as a 500.
    if not user_message or not user_message.strip():
        raise HTTPException(status_code=400, detail="User message is required.")

    try:
        bot_response = process_prompt(user_message)  # Process the user's message
        # Remove everything after <|fim_suffix|> and trim
        bot_response = bot_response.split("<|fim_suffix|>")[0].strip()
        # Keep only the first line of what remains
        bot_response = bot_response.split("\n")[0].strip()
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error processing message: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing the message.",
        ) from e

    # Return the bot's response as JSON
    return JSONResponse(content={"botResponse": bot_response})
# Define the route for processing documents
@app.post("/process-document")
async def process_document_route(file: UploadFile = File(...)):
    """Store an uploaded PDF under ``uploads/`` and index it.

    Args:
        file: The uploaded document (multipart form field ``file``).

    Returns:
        ``JSONResponse`` with a confirmation message under ``botResponse``.

    Raises:
        HTTPException: 400 when no file or no usable file name was supplied;
            500 when saving or processing the document fails.
    """
    # Validate outside the try block: raised inside it, these 400s would be
    # caught by the generic handler below and returned as 500s.
    if not file or not file.filename:
        raise HTTPException(status_code=400, detail="File not uploaded.")

    # The file name comes from the client.  Keep only its final component so
    # a name such as "../../etc/passwd" cannot escape the uploads directory.
    safe_name = os.path.basename(file.filename.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid file name.")

    try:
        os.makedirs("uploads", exist_ok=True)  # Create the uploads directory if it doesn't exist
        file_path = os.path.join("uploads", safe_name)  # Where the file will be saved
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)  # Save the file
        process_document(file_path)  # Process the document
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception("Error processing document: %s", e)
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing the document.",
        ) from e

    # Return a success message as JSON
    return JSONResponse(content={
        "botResponse": "Thank you for providing your PDF document. I have analyzed it, so now you can ask me any questions regarding it!"
    })