File size: 5,927 Bytes
0b48895
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b43b8b4
 
 
 
 
 
 
 
 
 
 
9231abf
b43b8b4
 
 
 
 
 
 
 
 
 
 
0b48895
 
 
 
 
b43b8b4
 
 
 
 
 
 
0b48895
 
 
 
 
 
 
 
 
b43b8b4
 
 
0b48895
b43b8b4
2bdeb9c
b43b8b4
 
 
 
0b48895
b43b8b4
 
 
 
 
0b48895
b43b8b4
 
 
0b48895
 
bef58e1
9231abf
 
 
 
b43b8b4
 
 
2bdeb9c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# import os
# import base64
# import zipfile
# from pathlib import Path
# import streamlit as st
# from byaldi import RAGMultiModalModel
# from openai import OpenAI

# import os
# st.write("Current Working Directory:", os.getcwd())

# # Function to unzip a folder if it does not exist
# # def unzip_folder_if_not_exist(zip_path, extract_to):
# #     if not os.path.exists(extract_to):
# #         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
# #             zip_ref.extractall(extract_to)

# # # Example usage
# # zip_path = 'medical_index.zip'
# # extract_to = 'medical_index'
# # unzip_folder_if_not_exist(zip_path, extract_to)

# # Preload the RAGMultiModalModel
# @st.cache_resource
# def load_model():
#     return RAGMultiModalModel.from_index("./medical_index")

# RAG = load_model()

# # OpenAI API key from environment
# api_key = os.getenv("OPENAI_API_KEY")
# client = OpenAI(api_key=api_key)

# # Streamlit UI
# st.title("Medical Diagnostic Assistant")
# st.write("Enter a medical query and get diagnostic recommendations along with visual references.")

# # User input
# query = st.text_input("Query", "What should be the appropriate diagnostic test for peptic ulcer?")

# if st.button("Submit"):
#     if query:
#         # Search using RAG model
#         with st.spinner('Retrieving information...'):
#             try:
#                 returned_page = RAG.search(query, k=1)[0].base64

#                 # Decode and display the retrieved image
#                 image_bytes = base64.b64decode(returned_page)
#                 filename = 'retrieved_image.jpg'
#                 with open(filename, 'wb') as f:
#                     f.write(image_bytes)

#                 # Display image in Streamlit
#                 st.image(filename, caption="Reference Image", use_column_width=True)

#                 # Get model response
#                 response = client.chat.completions.create(
#                     model="gpt-4o-mini-2024-07-18",
#                     messages=[
#                         {"role": "system", "content": "You are a helpful assistant. You only answer the question based on the provided image"},
#                         {
#                             "role": "user",
#                             "content": [
#                                 {"type": "text", "text": query},
#                                 {
#                                     "type": "image_url",
#                                     "image_url": {"url": f"data:image/jpeg;base64,{returned_page}"},
#                                 },
#                             ],
#                         },
#                     ],
#                     max_tokens=300,
#                 )
                
#                 # Display the response
#                 st.success("Model Response:")
#                 st.write(response.choices[0].message.content)
#             except Exception as e:
#                 st.error(f"An error occurred: {e}")
#     else:
#         st.warning("Please enter a query.")


import os
import base64
import zipfile
from pathlib import Path
import streamlit as st
from byaldi import RAGMultiModalModel
from openai import OpenAI

# Preload the RAGMultiModalModel
@st.cache_resource
def load_model(index_path="/home/user/app/medical_index1"):
    """Load the prebuilt byaldi index and return the ready-to-query model.

    ``st.cache_resource`` caches the returned object keyed on the arguments,
    so the index is loaded once per distinct *index_path* and the same model
    object is reused across script reruns and user sessions.

    Args:
        index_path: Directory of the saved index. Defaults to the deployment
            path this app was built for; pass another location (for example
            ``"./medical_index"``) to run the app elsewhere.

    Returns:
        The ``RAGMultiModalModel`` restored via ``RAGMultiModalModel.from_index``.
    """
    return RAGMultiModalModel.from_index(index_path)

RAG = load_model()

# Number of retrieved pages attached to each model request as visual context.
TOP_K = 10

# OpenAI API key from environment. OpenAI() raises at construction when no key
# is available, which would surface as a raw traceback; stop with a readable
# message instead.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    st.error("OPENAI_API_KEY is not set; configure it to use this app.")
    st.stop()
client = OpenAI(api_key=api_key)

# Streamlit UI
st.title("Medical Diagnostic Assistant")
st.write("Enter a medical query and get diagnostic recommendations along with visual references.")

# User input for selecting the model. Only vision-capable chat models are
# offered: every request below attaches page images, a system message and
# `max_tokens`, none of which o1-preview / o1-mini accept.
model_options = ["gpt-4o", "gpt-4o-mini"]
selected_model = st.selectbox("Choose a GPT model", model_options)

# User input for query
query = st.text_input("Query", "What should be the appropriate diagnostic test for peptic ulcer?")


def _retrieve_page_images(question, k=TOP_K):
    """Return the base64 image payloads of the top-*k* index pages for *question*.

    Results whose ``base64`` attribute is empty are skipped.
    NOTE(review): byaldi presumably only fills ``base64`` when the index was
    built with the collection stored alongside it — confirm for this index.
    """
    return [page.base64 for page in RAG.search(question, k=k) if page.base64]


def _ask_model(model, question, image_b64s):
    """Send *question* plus the page images to *model* and return the reply text.

    Images are sent as ``data:image/jpeg;base64,...`` URLs, matching how the
    payloads were labelled before this refactor.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant. You only answer the question based on the provided images."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    *[
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}}
                        for b64 in image_b64s
                    ],
                ],
            },
        ],
        max_tokens=300,
    )
    return response.choices[0].message.content


if st.button("Submit"):
    # strip() so a whitespace-only query is treated as empty.
    if query.strip():
        # Search using RAG model
        with st.spinner('Retrieving information...'):
            try:
                page_images = _retrieve_page_images(query)
                # Decode up front: kept in memory for display (no shared files
                # on disk that concurrent sessions could overwrite) and a
                # malformed payload fails before the paid API call.
                page_bytes = [base64.b64decode(b64) for b64 in page_images]

                if not page_images:
                    st.warning("No reference pages were found for this query.")
                else:
                    answer = _ask_model(selected_model, query, page_images)

                    # Display the response
                    st.success("Model Response:")
                    st.write(answer)

                    # An expander rather than a nested st.button: clicking a
                    # second button reruns the script with "Submit" False, so
                    # anything gated on it would never be shown.
                    with st.expander("Show References"):
                        for i, image_bytes in enumerate(page_bytes):
                            st.image(image_bytes, caption=f"Reference Image {i+1}")
            except Exception as e:
                # Top-level UI boundary: report any retrieval/API failure to the user.
                st.error(f"An error occurred: {e}")
    else:
        st.warning("Please enter a query.")