Update app.py
app.py CHANGED
@@ -29,54 +29,38 @@ if openai_api_key:
 else:
     print("API Key not found.")
 
-
+
 app = FastAPI()
 templates = Jinja2Templates(directory="templates")
 
 # Configure CORS
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"],
+    allow_methods=["*"],
     allow_headers=["*"],
 )
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Initialize OpenAI embeddings
-embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
-
-# Load FAISS index with error handling
-app = FastAPI()
-templates = Jinja2Templates(directory="templates")
 
-# Load FAISS index with error handling
 db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
 
-
-
 # Define the prompt template
 prompt_template = """
-You are an expert in skin cancer
-Answer the question based only on the
-
-Context:
+You are an expert in skin cancer, etc.
+Answer the question based only on the following context, which can include text, images, and tables:
 {context}
 
 Question: {question}
 
-
+Don't answer if you are not sure and decline to answer and say "Sorry, I don't have much information about it."
+Just return the helpful answer in as much detail as possible.
 
 Answer:
 """
 
-qa_chain = LLMChain(
-
-    prompt=PromptTemplate.from_template(prompt_template),
-)
+qa_chain = LLMChain(llm=ChatOpenAI(model="gpt-4", openai_api_key = openai_api_key, max_tokens=1024),
+                    prompt=PromptTemplate.from_template(prompt_template))
 
 @app.get("/", response_class=HTMLResponse)
 async def index(request: Request):
@@ -84,35 +68,16 @@ async def index(request: Request):
 
 @app.post("/get_answer")
 async def get_answer(question: str = Form(...)):
-
-
-
-
-
-
-
-
-
-
-
-
-
-            if doc_type == 'text':
-                context += f"[text] {original_content}\n"
-            elif doc_type == 'table':
-                context += f"[table] {original_content}\n"
-            elif doc_type == 'image':
-                context += f"[image] {d.page_content}\n"
-                relevant_images.append(original_content)
-
-        # Run the question-answering chain
-        result = qa_chain.run({'context': context, 'question': question})
-
-        # Handle cases where no relevant images are found
-        return JSONResponse({
-            "relevant_images": relevant_images[0] if relevant_images else None,
-            "result": result,
-        })
-    except Exception as e:
-        logger.error(f"Error processing request: {e}")
-        return JSONResponse({"error": "Internal server error."}, status_code=500)
+    relevant_docs = db.similarity_search(question)
+    context = ""
+    relevant_images = []
+    for d in relevant_docs:
+        if d.metadata['type'] == 'text':
+            context += '[text]' + d.metadata['original_content']
+        elif d.metadata['type'] == 'table':
+            context += '[table]' + d.metadata['original_content']
+        elif d.metadata['type'] == 'image':
+            context += '[image]' + d.page_content
+            relevant_images.append(d.metadata['original_content'])
+    result = qa_chain.run({'context': context, 'question': question})
+    return JSONResponse({"relevant_images": relevant_images[0], "result": result})
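Note on the second hunk: the removed handler returned relevant_images[0] if relevant_images else None inside a try/except, while the updated handler indexes relevant_images[0] unconditionally, so a query whose retrieved documents contain no images would raise an IndexError. A minimal sketch of the same response with the earlier guard kept (the helper name is hypothetical and not part of this commit):

# Sketch only: same JSON shape as the updated endpoint, but tolerant of an
# empty relevant_images list, mirroring the guard the previous version used.
from fastapi.responses import JSONResponse

def build_answer_response(result, relevant_images):
    # Hypothetical helper: return None for the image when nothing matched
    # instead of raising IndexError on relevant_images[0].
    first_image = relevant_images[0] if relevant_images else None
    return JSONResponse({"relevant_images": first_image, "result": result})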
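For reference, the updated route still takes the question as form data and returns a JSON body with "result" and "relevant_images", so a client call could look like the sketch below (the base URL and the example question are placeholders, not taken from this commit):

import requests

BASE_URL = "http://localhost:8000"  # placeholder; substitute the deployed Space URL

resp = requests.post(f"{BASE_URL}/get_answer",
                     data={"question": "What are the early signs of melanoma?"})
resp.raise_for_status()
payload = resp.json()
print(payload["result"])           # answer text produced by qa_chain
print(payload["relevant_images"])  # original content of the top matching image document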