File size: 4,038 Bytes
2aa4435
 
 
 
 
 
 
d661241
 
2aa4435
 
d661241
2aa4435
 
 
 
 
a836c6a
 
 
 
 
 
2aa4435
 
a836c6a
 
 
 
 
2aa4435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3cfbe7
34cfcd2
c3cfbe7
2aa4435
c3cfbe7
 
2aa4435
c3cfbe7
2aa4435
 
c3cfbe7
2aa4435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import logging
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from langchain_community.chat_models import ChatOpenAI  # Correct import
from langchain_community.embeddings import OpenAIEmbeddings  # Correct import
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS  # Correct import

# Load variables from a local .env file (if one exists) into the process
# environment so OPENAI_API_KEY can be supplied without exporting it.
load_dotenv()

# Read the OpenAI API key exactly once and fail fast at import time when it
# is missing: every component below (embeddings, chat model) needs it, so
# there is no useful degraded mode without a key.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("Missing OpenAI API key. Set OPENAI_API_KEY in your environment variables.")

# Reaching this line means the key is present; the "not found" case has
# already raised above, so no else-branch is needed.
print("API Key loaded successfully!")

# Initialize the FastAPI app and the Jinja2 template loader (templates are
# read from the local "templates" directory).
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Configure CORS.
# Any origin, method and header is accepted. Credentials (cookies /
# Authorization headers) are deliberately NOT allowed: the FastAPI CORS
# documentation states that wildcard origins/methods/headers must not be
# combined with allow_credentials=True, since that would let any website make
# credentialed cross-origin requests. If credentialed cross-origin access is
# ever required, list the trusted origins explicitly and re-enable it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI embeddings (must be the same embedding model family that
# was used when the index was built — TODO confirm against the indexing job).
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Folder that holds the saved FAISS index. FAISS.load_local() takes the
# *folder* and appends "<index_name>.faiss" / "<index_name>.pkl" itself
# (index_name defaults to "index"), so the path must not point at the
# index.faiss file. The location can be overridden per deployment with the
# FAISS_INDEX_PATH environment variable; the default is the original local
# development folder.
faiss_index_path = os.getenv(
    "FAISS_INDEX_PATH", "/Users/sasi/Downloads/Multimodal/faiss_index"
)
if faiss_index_path.endswith(".faiss"):
    # Tolerate a path to the index file itself by falling back to its folder.
    faiss_index_path = os.path.dirname(faiss_index_path)

# Load the FAISS index with error handling: a failure is logged and leaves
# `db` as None so the app still starts and the endpoint can report the outage.
# SECURITY NOTE: allow_dangerous_deserialization=True opts in to loading the
# pickled side-car file; only point this at index folders you created and
# trust, never at user-supplied data.
try:
    db = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
    logger.info("FAISS index loaded successfully from %s.", faiss_index_path)
except Exception as e:
    logger.error("Error loading FAISS index from %s: %s", faiss_index_path, e)
    db = None  # Avoid crashing


# Prompt used by the question-answering chain. It has two placeholders,
# {context} and {question}, and instructs the model to answer only from the
# supplied context or to reply with a fixed apology when that is insufficient.
prompt_template = """
You are an expert in skin cancer research.
Answer the question based only on the provided context, which may include text, images, or tables.

Context:
{context}

Question: {question}

If the context does not contain sufficient information, say: "Sorry, I don't have much information about it."

Answer:
"""

# Assemble the chain from its two parts: the GPT-4 chat model (answers capped
# at 1024 tokens) and the prompt built from the template above.
_chat_model = ChatOpenAI(model="gpt-4", openai_api_key=openai_api_key, max_tokens=1024)
_qa_prompt = PromptTemplate.from_template(prompt_template)
qa_chain = LLMChain(llm=_chat_model, prompt=_qa_prompt)

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the landing page by rendering the ``index.html`` template.

    Args:
        request: The incoming request; it is passed to the template under the
            ``"request"`` key.

    Returns:
        The rendered template response.
    """
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)

def _build_context(docs):
    """Turn retrieved documents into prompt context plus image payloads.

    Each document contributes one tagged line to the context depending on its
    ``metadata['type']`` (treated as ``'text'`` when absent):

    * ``'text'`` / ``'table'``: ``metadata['original_content']``
    * ``'image'``: ``page_content``; its ``metadata['original_content']`` is
      collected separately as the image payload.

    Documents of any other type are ignored.

    Args:
        docs: Iterable of objects exposing ``metadata`` (dict) and
            ``page_content``.

    Returns:
        A ``(context, images)`` tuple: the concatenated context string and the
        list of image payloads in retrieval order.
    """
    context_lines = []
    images = []
    for doc in docs:
        doc_type = doc.metadata.get('type', 'text')
        original_content = doc.metadata.get('original_content', '')

        if doc_type in ('text', 'table'):
            context_lines.append(f"[{doc_type}] {original_content}\n")
        elif doc_type == 'image':
            context_lines.append(f"[image] {doc.page_content}\n")
            images.append(original_content)

    return "".join(context_lines), images


@app.post("/get_answer")
async def get_answer(question: str = Form(...)):
    """Answer a question using documents retrieved from the FAISS index.

    Args:
        question: The user's question, submitted as a form field.

    Returns:
        A JSON response with:

        * ``"result"``: the answer produced by the QA chain.
        * ``"relevant_images"``: the payload of the first retrieved image
          document, or ``None`` when no image was retrieved.

        On failure the body is ``{"error": ...}`` with status 400 (blank
        question) or 500 (index unavailable or processing error).
    """
    # Local import: only this handler needs it, and it comes from the already
    # installed fastapi package.
    from fastapi.concurrency import run_in_threadpool

    if db is None:
        return JSONResponse({"error": "FAISS database is unavailable."}, status_code=500)

    # A whitespace-only question cannot be answered; reject it before paying
    # for an embedding request and a chat completion.
    if not question.strip():
        return JSONResponse({"error": "Question must not be empty."}, status_code=400)

    try:
        # Both calls below are synchronous and perform network I/O (embedding
        # the query, then the chat completion). Running them in the threadpool
        # keeps the event loop free to serve other requests meanwhile.
        relevant_docs = await run_in_threadpool(db.similarity_search, question)
        context, relevant_images = _build_context(relevant_docs)

        # Run the question-answering chain
        result = await run_in_threadpool(
            qa_chain.run, {'context': context, 'question': question}
        )
    except Exception as e:
        # Top-level request boundary: log with traceback, return a generic
        # error so internal details are not leaked to the client.
        logger.exception("Error processing request: %s", e)
        return JSONResponse({"error": "Internal server error."}, status_code=500)

    # Only the first image is returned; None signals that no image matched.
    return JSONResponse({
        "relevant_images": relevant_images[0] if relevant_images else None,
        "result": result,
    })