import os
import logging

from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain_openai import ChatOpenAI

# Configure logging first so the start-up messages below go through it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables (e.g. from a .env file) into os.environ.
load_dotenv()

# Securely retrieve the OpenAI API key; fail fast at import time if it is missing.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("Missing OpenAI API key. Set OPENAI_API_KEY in your environment variables.")
logger.info("API Key loaded successfully!")

# Initialize FastAPI app
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Configure CORS
# NOTE(review): a literal "*" origin cannot be combined with credentials, so with
# allow_credentials=True Starlette reflects the caller's Origin instead. That
# effectively permits credentialed requests from any site. Restrict allow_origins
# if cookies/auth are ever introduced.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize OpenAI embeddings (used to embed queries against the FAISS index).
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Load FAISS index with error handling.
# FAISS.load_local() takes the *folder* containing "<index_name>.faiss" and
# "<index_name>.pkl", not the path of the .faiss file itself. Split the configured
# file path into (folder, index name) accordingly.
faiss_index_path = "/app/faiss_files/index.faiss"  # Path on Hugging Face Space
faiss_folder_path, faiss_file_name = os.path.split(faiss_index_path)
faiss_index_name = os.path.splitext(faiss_file_name)[0]
try:
    # allow_dangerous_deserialization: loading unpickles the index's .pkl side-file.
    # This is acceptable only if the index files are trusted (presumably shipped with
    # the app — confirm).
    db = FAISS.load_local(
        faiss_folder_path,
        embeddings,
        index_name=faiss_index_name,
        allow_dangerous_deserialization=True,
    )
    logger.info(f"FAISS index loaded successfully from {faiss_index_path}.")
except Exception as e:
    # logger.exception keeps the traceback, which is needed to diagnose load failures.
    logger.exception(f"Error loading FAISS index from {faiss_index_path}: {e}")
    db = None  # Avoid crashing; request handlers must check for None.
# Define the prompt template used by the question-answering chain.
prompt_template = """
You are an expert in skin cancer research.
Answer the question based only on the provided context, which may include text, images, or tables.
Context: {context}
Question: {question}
If the context does not contain sufficient information, say: "Sorry, I don't have much information about it."
Answer:
"""

qa_chain = LLMChain(
    llm=ChatOpenAI(model="gpt-4", openai_api_key=openai_api_key, max_tokens=1024),
    prompt=PromptTemplate.from_template(prompt_template),
)


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the front-end page from templates/index.html."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/get_answer")
def get_answer(question: str = Form(...)):
    """Answer a question using documents retrieved from the FAISS index.

    Retrieved documents are tagged by their ``metadata["type"]`` ("text", "table"
    or "image"; default "text") and concatenated into the prompt context. Documents
    of any other type are skipped.

    Declared with plain ``def`` rather than ``async def`` on purpose: both
    ``db.similarity_search`` and the chain call are synchronous. Inside an
    ``async def`` they would block the event loop for the whole request. As a
    plain ``def``, FastAPI runs the handler in its threadpool instead.

    Returns:
        JSON ``{"relevant_images": <first image's original_content or None>,
        "result": <LLM answer>}``. On failure, returns ``{"error": ...}`` with
        status 400 for an empty question and 500 for an unavailable index or an
        internal error.
    """
    # Reject blank input before spending an embedding call and an LLM call on it.
    if not question.strip():
        return JSONResponse({"error": "Question must not be empty."}, status_code=400)

    if db is None:
        return JSONResponse({"error": "FAISS database is unavailable."}, status_code=500)

    try:
        # Retrieve relevant documents from FAISS
        relevant_docs = db.similarity_search(question)

        context_parts = []
        relevant_images = []
        for d in relevant_docs:
            doc_type = d.metadata.get('type', 'text')
            original_content = d.metadata.get('original_content', '')
            if doc_type == 'text':
                context_parts.append(f"[text] {original_content}\n")
            elif doc_type == 'table':
                context_parts.append(f"[table] {original_content}\n")
            elif doc_type == 'image':
                # page_content holds the text that was indexed for the image. The
                # image payload itself (original_content) is returned to the client.
                context_parts.append(f"[image] {d.page_content}\n")
                relevant_images.append(original_content)
        context = "".join(context_parts)

        # Run the question-answering chain (invoke() replaces the deprecated run()).
        result = qa_chain.invoke({'context': context, 'question': question})[qa_chain.output_key]

        # Only the first relevant image (if any) is returned to the client.
        return JSONResponse({
            "relevant_images": relevant_images[0] if relevant_images else None,
            "result": result,
        })
    except Exception as e:
        # Top-level request boundary: log with traceback, return a generic error.
        logger.exception(f"Error processing request: {e}")
        return JSONResponse({"error": "Internal server error."}, status_code=500)