Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import base64
|
3 |
+
import zipfile
|
4 |
+
from pathlib import Path
|
5 |
+
import streamlit as st
|
6 |
+
from byaldi import RAGMultiModalModel
|
7 |
+
from openai import OpenAI
|
8 |
+
|
9 |
+
def unzip_folder_if_not_exist(zip_path, extract_to):
    """Extract *zip_path* into the *extract_to* directory, unless it already exists.

    When the target directory is already present the archive is not touched,
    so repeated calls (e.g. on every Streamlit rerun) are cheap no-ops.

    Args:
        zip_path: Path to the .zip archive to unpack.
        extract_to: Directory the archive contents are extracted into.
    """
    if os.path.exists(extract_to):
        return  # already unpacked on a previous run
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_to)
|
14 |
+
|
15 |
+
# Unpack the pre-built retrieval index shipped alongside the app.
# No-op once the directory exists from an earlier run.
INDEX_ARCHIVE = 'medical_index.zip'
INDEX_DIR = 'medical_index'
unzip_folder_if_not_exist(INDEX_ARCHIVE, INDEX_DIR)
|
19 |
+
|
20 |
+
# Preload the RAGMultiModalModel
@st.cache_resource
def load_model(index_path="medical_index"):
    """Load the byaldi multimodal RAG model from a saved index.

    Decorated with st.cache_resource so the (expensive) index load runs once
    per server process instead of on every Streamlit script rerun.

    Args:
        index_path: Directory containing the saved index. Defaults to the
            folder unpacked from medical_index.zip, matching the original
            hard-coded path so existing behavior is unchanged.

    Returns:
        A ready-to-query RAGMultiModalModel instance.
    """
    return RAGMultiModalModel.from_index(index_path)

RAG = load_model()
|
26 |
+
|
27 |
+
# Build the OpenAI client, authenticating with the key taken from the
# OPENAI_API_KEY environment variable (None if unset; the client will
# then fail on first use rather than here).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
30 |
+
|
31 |
+
# Page header and intro text.
st.title("Medical Diagnostic Assistant")
st.write(
    "Enter a medical query and get diagnostic recommendations along with visual references."
)

# Free-text medical question, pre-filled with a sample query.
query = st.text_input(
    "Query",
    "What should be the appropriate diagnostic test for peptic ulcer?",
)
|
37 |
+
|
38 |
+
# Handle a submitted query: retrieve the most relevant page image from the
# index, display it, then ask the vision model to answer using that image.
if st.button("Submit"):
    if not query:
        st.warning("Please enter a query.")
    else:
        with st.spinner('Retrieving information...'):
            try:
                # Top-1 retrieval. Guard against an empty result list —
                # previously `RAG.search(query, k=1)[0]` raised an uncaught
                # IndexError that surfaced as a generic error message.
                results = RAG.search(query, k=1)
                if not results:
                    st.warning("No matching pages were found in the index.")
                else:
                    returned_page = results[0].base64

                    # Decode the retrieved page and render it straight from
                    # memory. The old code wrote 'retrieved_image.jpg' into the
                    # working directory and passed the filename to st.image,
                    # which leaked the file and raced between concurrent
                    # sessions; st.image accepts raw bytes directly.
                    image_bytes = base64.b64decode(returned_page)
                    st.image(image_bytes, caption="Reference Image", use_column_width=True)

                    # Ask the vision model, constraining it to the retrieved
                    # image via the system prompt + inline data-URL attachment.
                    response = client.chat.completions.create(
                        model="gpt-4o-mini-2024-07-18",
                        messages=[
                            {"role": "system", "content": "You are a helpful assistant. You only answer the question based on the provided image"},
                            {
                                "role": "user",
                                "content": [
                                    {"type": "text", "text": query},
                                    {
                                        "type": "image_url",
                                        "image_url": {"url": f"data:image/jpeg;base64,{returned_page}"},
                                    },
                                ],
                            },
                        ],
                        max_tokens=300,
                    )

                    # Display the model's answer.
                    st.success("Model Response:")
                    st.write(response.choices[0].message.content)
            except Exception as e:
                # Top-level UI boundary: surface any retrieval/API failure to
                # the user instead of crashing the Streamlit session.
                st.error(f"An error occurred: {e}")
|