# NOTE(review): the lines "Spaces:" / "Sleeping" / "Sleeping" below were
# Hugging Face Spaces UI status text accidentally captured with the source;
# commented out so the module parses. Safe to delete.
# Spaces: Sleeping
import os | |
import logging | |
from dotenv import load_dotenv | |
from fastapi import FastAPI, Request, Form | |
from fastapi.responses import JSONResponse, HTMLResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.templating import Jinja2Templates | |
from langchain_community.chat_models import ChatOpenAI # Correct import | |
from langchain_community.embeddings import OpenAIEmbeddings # Correct import | |
from langchain.chains import LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain_community.vectorstores import FAISS # Correct import | |
# Load variables from a local .env file into the process environment.
load_dotenv()

# Securely retrieve the OpenAI API key once and fail fast if it is absent,
# so the app never starts half-configured. (The original fetched the key
# three times, re-imported os, and had an unreachable "not found" branch --
# the ValueError always fired first.)
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("Missing OpenAI API key. Set OPENAI_API_KEY in your environment variables.")
print("API Key loaded successfully!")
# Initialize FastAPI app; Jinja2 templates are served from ./templates.
app = FastAPI()
templates = Jinja2Templates(directory="templates")
# Configure CORS
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers -- the CORS spec forbids the wildcard origin when
# credentials are allowed. List explicit origins if cookies/auth headers are
# actually needed; confirm against the frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configure logging: INFO and above to the root handler, module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize OpenAI embeddings used to encode queries for FAISS retrieval.
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Location of the persisted FAISS index. Overridable via FAISS_INDEX_PATH so
# the same code runs locally and on a deployed Space; the default preserves
# the previous hard-coded value (which, despite its old comment, was a
# developer-machine path, not a Space path).
# NOTE(review): FAISS.load_local expects the index *directory*, not the
# .faiss file itself -- confirm this should point at the folder.
faiss_index_path = os.getenv(
    "FAISS_INDEX_PATH",
    "/Users/sasi/Downloads/Multimodal/faiss_index/index.faiss",
)
try:
    # allow_dangerous_deserialization unpickles data; acceptable only because
    # the index is produced by this project, never loaded from untrusted input.
    db = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
    logger.info(f"FAISS index loaded successfully from {faiss_index_path}.")
except Exception as e:
    logger.error(f"Error loading FAISS index from {faiss_index_path}: {e}")
    db = None  # Keep the app alive; request handlers report the failure instead of crashing.
# Prompt for retrieval-augmented QA over skin-cancer literature: the model
# must answer strictly from the retrieved context and admit when the context
# is insufficient.
prompt_template = """
You are an expert in skin cancer research.
Answer the question based only on the provided context, which may include text, images, or tables.
Context:
{context}
Question: {question}
If the context does not contain sufficient information, say: "Sorry, I don't have much information about it."
Answer:
"""

# Assemble the QA chain: GPT-4 with a 1024-token completion budget, driven
# by the template above.
_qa_llm = ChatOpenAI(model="gpt-4", openai_api_key=openai_api_key, max_tokens=1024)
_qa_prompt = PromptTemplate.from_template(prompt_template)
qa_chain = LLMChain(llm=_qa_llm, prompt=_qa_prompt)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat UI (templates/index.html).

    NOTE(review): the original function carried no route decorator, so it was
    never registered as an endpoint. '/' is assumed from its role as the
    landing page -- confirm against the original deployment.
    """
    return templates.TemplateResponse("index.html", {"request": request})
@app.post("/get_answer")
async def get_answer(question: str = Form(...)):
    """Answer a form-posted question via FAISS retrieval + the GPT-4 QA chain.

    NOTE(review): the original function carried no route decorator and was
    never registered; the '/get_answer' path is assumed -- confirm against the
    frontend's form action.

    Returns JSON with:
      result          -- the chain's answer string
      relevant_images -- the first retrieved image's original content, or None
    """
    if db is None:
        return JSONResponse({"error": "FAISS database is unavailable."}, status_code=500)
    try:
        # Retrieve the most similar documents for the question (default top-k).
        relevant_docs = db.similarity_search(question)
        context = ""
        relevant_images = []
        for d in relevant_docs:
            doc_type = d.metadata.get('type', 'text')
            original_content = d.metadata.get('original_content', '')
            if doc_type == 'text':
                context += f"[text] {original_content}\n"
            elif doc_type == 'table':
                context += f"[table] {original_content}\n"
            elif doc_type == 'image':
                # For images, page_content is the text description; the raw
                # payload (presumably base64 -- confirm) sits in metadata.
                context += f"[image] {d.page_content}\n"
                relevant_images.append(original_content)
        # Run the question-answering chain over the assembled context.
        result = qa_chain.run({'context': context, 'question': question})
        # Only the first image is returned even though the key is plural --
        # preserved as-is for frontend compatibility.
        return JSONResponse({
            "relevant_images": relevant_images[0] if relevant_images else None,
            "result": result,
        })
    except Exception:
        # logger.exception records the full traceback; the original
        # logger.error(f"... {e}") discarded it.
        logger.exception("Error processing request")
        return JSONResponse({"error": "Internal server error."}, status_code=500)