Spaces:
Sleeping
Sleeping
File size: 4,038 Bytes
2aa4435 d661241 2aa4435 d661241 2aa4435 a836c6a 2aa4435 a836c6a 2aa4435 c3cfbe7 34cfcd2 c3cfbe7 2aa4435 c3cfbe7 2aa4435 c3cfbe7 2aa4435 c3cfbe7 2aa4435 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import asyncio
import logging
import os

from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.templating import Jinja2Templates
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatOpenAI  # Correct import
from langchain_community.embeddings import OpenAIEmbeddings  # Correct import
from langchain_community.vectorstores import FAISS  # Correct import
# Load environment variables (from a local .env file, if one is present).
load_dotenv()

# Securely retrieve the OpenAI API key from the environment. Fail fast at
# import time: the embeddings and chat model created below both need it.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("Missing OpenAI API key. Set OPENAI_API_KEY in your environment variables.")
# Reaching this line means the key is present (the check above raised otherwise).
print("API Key loaded successfully!")
# Logging: INFO and above go to the root handler; this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application and its Jinja2 templates (loaded from ./templates).
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin calls. That is harmless only while the
# API has no cookie/session auth; restrict allow_origins if that changes.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Initialize OpenAI embeddings; they are used to embed queries at search time.
# NOTE(review): presumably the index was built with these same embeddings — confirm.
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Location of the saved FAISS index. FAISS.load_local expects the *folder*
# holding "<index_name>.faiss" and "<index_name>.pkl", not the .faiss file itself.
# Override with FAISS_INDEX_PATH. The default is relative to the working
# directory (like the "templates" directory), so it resolves when deployed.
faiss_index_path = os.getenv("FAISS_INDEX_PATH", "faiss_index")

# Also accept a path to the .faiss file itself (the previous convention):
# split it into the containing folder and the index name (file stem).
_faiss_root, _faiss_ext = os.path.splitext(faiss_index_path)
if _faiss_ext == ".faiss":
    faiss_folder = os.path.dirname(_faiss_root) or "."
    faiss_index_name = os.path.basename(_faiss_root)
else:
    faiss_folder = faiss_index_path
    faiss_index_name = "index"

# Load the FAISS index with error handling.
try:
    # allow_dangerous_deserialization=True: load_local unpickles the docstore
    # file. Only point this at an index you built yourself — never at
    # untrusted files.
    db = FAISS.load_local(
        faiss_folder,
        embeddings,
        index_name=faiss_index_name,
        allow_dangerous_deserialization=True,
    )
    logger.info("FAISS index loaded successfully from %s.", faiss_index_path)
except Exception as e:
    # logger.exception keeps the traceback, which is essential for diagnosing
    # path or format problems at startup.
    logger.exception("Error loading FAISS index from %s: %s", faiss_index_path, e)
    db = None  # Keep the app running; /get_answer reports the index as unavailable.
# Prompt for the question-answering chain. get_answer fills {context} with the
# retrieved documents and {question} with the user's question.
prompt_template = """
You are an expert in skin cancer research.
Answer the question based only on the provided context, which may include text, images, or tables.
Context:
{context}
Question: {question}
If the context does not contain sufficient information, say: "Sorry, I don't have much information about it."
Answer:
"""

# The chain pairs the prompt with a GPT-4 chat model capped at 1024 output tokens.
qa_chain = LLMChain(
    prompt=PromptTemplate.from_template(prompt_template),
    llm=ChatOpenAI(
        model="gpt-4",
        max_tokens=1024,
        openai_api_key=openai_api_key,
    ),
)
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render and return the ``index.html`` template for the root URL."""
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)
def _build_context(docs):
    """Render retrieved documents into the prompt context and collect images.

    Each document contributes one line, tagged with ``metadata["type"]``
    (defaulting to ``"text"``):

    * ``text`` / ``table``: the line carries ``metadata["original_content"]``.
    * ``image``: the line carries ``page_content``, and the image's
      ``original_content`` is collected separately.
    * Any other type is skipped.

    Returns:
        ``(context, images)``: the newline-terminated context string and the
        list of image ``original_content`` values, in retrieval order.
    """
    lines = []
    images = []
    for doc in docs:
        doc_type = doc.metadata.get("type", "text")
        original_content = doc.metadata.get("original_content", "")
        if doc_type in ("text", "table"):
            lines.append(f"[{doc_type}] {original_content}\n")
        elif doc_type == "image":
            lines.append(f"[image] {doc.page_content}\n")
            images.append(original_content)
    # join is linear; repeated += on a str can be quadratic.
    return "".join(lines), images


@app.post("/get_answer")
async def get_answer(question: str = Form(...)):
    """Answer *question* from documents retrieved from the FAISS index.

    Success response (JSON):
        ``result``: the chain's answer.
        ``relevant_images``: the ``original_content`` of the first retrieved
        image document, or ``None`` if no image document was retrieved.

    Error responses (JSON ``{"error": ...}``):
        400: the question is empty or whitespace only.
        500: the FAISS index failed to load at startup, or retrieval or
        answer generation raised.
    """
    if db is None:
        return JSONResponse({"error": "FAISS database is unavailable."}, status_code=500)
    # Don't spend an embedding call on an empty query.
    if not question.strip():
        return JSONResponse({"error": "Question must not be empty."}, status_code=400)
    try:
        # similarity_search embeds the query with the OpenAI embeddings, and
        # qa_chain.run calls the chat model. Both are blocking network calls,
        # so run them in a worker thread to keep the event loop free for
        # other requests.
        relevant_docs = await asyncio.to_thread(db.similarity_search, question)
        context, relevant_images = _build_context(relevant_docs)
        result = await asyncio.to_thread(
            qa_chain.run, {"context": context, "question": question}
        )
    except Exception as e:
        # Top-level request boundary: log with traceback, return a generic error.
        logger.exception("Error processing request: %s", e)
        return JSONResponse({"error": "Internal server error."}, status_code=500)
    # Only the first image is returned, despite the plural key; this keeps the
    # response shape unchanged (presumably the front-end expects one value).
    return JSONResponse({
        "relevant_images": relevant_images[0] if relevant_images else None,
        "result": result,
    })
|