File size: 4,060 Bytes
2aa4435
 
 
 
 
 
 
82c5dd6
 
2aa4435
 
82c5dd6
c23fbde
 
 
2aa4435
 
 
 
 
a836c6a
 
 
 
 
 
2aa4435
 
a836c6a
 
 
 
 
2aa4435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3cfbe7
c23fbde
c3cfbe7
2aa4435
c3cfbe7
 
2aa4435
c3cfbe7
2aa4435
 
c3cfbe7
2aa4435
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import logging
from dotenv import load_dotenv
from fastapi import FastAPI, Request, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.runnables import RunnableSequence

# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# Retrieve the OpenAI API key exactly once and fail fast when it is missing,
# so the app never starts with non-functional LLM/embedding clients.
# (Previously the key was fetched three times with a duplicate `import os`
# and an unreachable "not found" branch after the raise.)
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("Missing OpenAI API key. Set OPENAI_API_KEY in your environment variables.")
print("API Key loaded successfully!")

# Initialize FastAPI app and point Jinja2 at the on-disk "templates" directory
# (used by the "/" route to render index.html).
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Configure CORS: wide open — any origin, method, and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive; confirm this is intended outside of a demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure module-level logging; the logger is used by the FAISS loader and
# the /get_answer error handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize OpenAI embeddings used to embed queries against the FAISS index.
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

# Load FAISS index with error handling.
faiss_index_path = "/app/faiss_files/index.faiss"  # Path on Hugging Face Space

try:
    # allow_dangerous_deserialization=True is required because load_local
    # unpickles the docstore — only acceptable for an index this project
    # produced itself, never for untrusted files.
    db = FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)
    logger.info(f"FAISS index loaded successfully from {faiss_index_path}.")
except Exception as e:
    logger.error(f"Error loading FAISS index from {faiss_index_path}: {e}")
    db = None  # Avoid crashing at import time; /get_answer returns a 500 when db is None


# Prompt template: constrains the model to answer strictly from the retrieved
# context (which may mix [text]/[table]/[image] snippets assembled in
# /get_answer) and to admit ignorance otherwise.
prompt_template = """
You are an expert in skin cancer research.
Answer the question based only on the provided context, which may include text, images, or tables.

Context:
{context}

Question: {question}

If the context does not contain sufficient information, say: "Sorry, I don't have much information about it."

Answer:
"""

# QA chain pairing GPT-4 with the prompt above; invoked via qa_chain.run(...)
# in /get_answer.
# NOTE(review): LLMChain is deprecated in recent langchain releases, and the
# file already imports langchain_openai / runnables — presumably a migration
# to the `prompt | llm` runnable style is underway; update the .run call site
# in the same change.
qa_chain = LLMChain(
    llm=ChatOpenAI(model="gpt-4", openai_api_key=openai_api_key, max_tokens=1024),
    prompt=PromptTemplate.from_template(prompt_template),
)

@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat UI (templates/index.html) for the root path."""
    template_context = {"request": request}
    return templates.TemplateResponse("index.html", template_context)

def _build_context(relevant_docs):
    """Assemble the prompt context string and collect image payloads.

    Each document's metadata 'type' (defaulting to 'text') decides how its
    content is tagged in the context; for images the page_content (a textual
    summary) goes into the context while the original_content (the image
    payload) is returned separately for the frontend.

    Returns a (context, images) tuple.
    """
    parts = []
    images = []
    for doc in relevant_docs:
        doc_type = doc.metadata.get('type', 'text')
        original_content = doc.metadata.get('original_content', '')

        if doc_type == 'text':
            parts.append(f"[text] {original_content}\n")
        elif doc_type == 'table':
            parts.append(f"[table] {original_content}\n")
        elif doc_type == 'image':
            parts.append(f"[image] {doc.page_content}\n")
            images.append(original_content)
    # Join once instead of repeated `context +=` (avoids quadratic rebuilds).
    return "".join(parts), images


@app.post("/get_answer")
async def get_answer(question: str = Form(...)):
    """Answer a question using FAISS retrieval + the GPT-4 QA chain.

    Returns JSON with 'result' (the model's answer) and 'relevant_images'
    (the first retrieved image payload, or None — the frontend expects a
    single image, not a list). Responds 500 if the FAISS index failed to
    load at startup or any step raises.
    """
    if db is None:
        return JSONResponse({"error": "FAISS database is unavailable."}, status_code=500)

    try:
        # Retrieve relevant documents from FAISS and fold them into a context.
        relevant_docs = db.similarity_search(question)
        context, relevant_images = _build_context(relevant_docs)

        # Run the question-answering chain.
        result = qa_chain.run({'context': context, 'question': question})

        # Handle cases where no relevant images are found.
        return JSONResponse({
            "relevant_images": relevant_images[0] if relevant_images else None,
            "result": result,
        })
    except Exception as e:
        logger.error(f"Error processing request: {e}")
        return JSONResponse({"error": "Internal server error."}, status_code=500)