Eric Guan committed · Commit 248b0d6 · Parent(s): de080b7

Remove binary files from repository
Files changed:
- .gitignore +2 -1
- __pycache__/rag_model.cpython-311.pyc +0 -0
- __pycache__/yolo_model.cpython-311.pyc +0 -0
- app.py +185 -37
- images/bandaid.jpg +0 -0
- images/ctp_app_example.png +0 -0
- images/example2.png +0 -0
- images/example3.png +0 -0
- models/best.pt +3 -0
- rag_model.py +40 -46
- requirements.txt +29 -0
- yolo_model.py +35 -0

.gitignore CHANGED
@@ -1,2 +1,3 @@
 venv
-.env
+.env
+test_model.py

__pycache__/rag_model.cpython-311.pyc CHANGED
Binary files a/__pycache__/rag_model.cpython-311.pyc and b/__pycache__/rag_model.cpython-311.pyc differ

__pycache__/yolo_model.cpython-311.pyc ADDED
Binary file (1.86 kB)

app.py CHANGED
@@ -1,7 +1,10 @@
 import streamlit as st
 from transformers import pipeline
 from PIL import Image
+import numpy as np
+import cv2
 from rag_model import *
+from yolo_model import *
 
 
 @st.cache_resource
@@ -9,64 +12,209 @@ def load_image_model():
     return pipeline("image-classification", model="Heem2/wound-image-classification")
 
 pipeline = load_image_model()
+yolo_model = load_yolo_model()
 
+# Add custom CSS
+css = """
+<style>
+body {
+    font-family: 'Arial', sans-serif;
+    background-color: #f5f5f5;
+}
+.main {
+    background-color: #ffffff;
+    padding: 20px;
+    border-radius: 10px;
+    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
+}
+.stButton button {
+    background-color: #4CAF50;
+    color: white;
+    border: none;
+    padding: 10px 20px;
+    text-align: center;
+    text-decoration: none;
+    display: inline-block;
+    font-size: 16px;
+    margin: 4px 2px;
+    cursor: pointer;
+    border-radius: 5px;
+}
+.stButton button:hover {
+    background-color: #45a049;
+}
+.stApp > header {
+    background-color: transparent;
+}
+.stApp {
+    margin: auto;
+    background-color: #D9AFD9;
+    background-image: linear-gradient(0deg, #D9AFD9 0%, #97D9E1 100%);
+}
+[data-testid='stFileUploader'] {
+    width: max-content;
+}
+[data-testid='stFileUploader'] section {
+    padding: 0;
+    float: left;
+}
+[data-testid='stFileUploader'] section > input + div {
+    display: none;
+}
+[data-testid='stFileUploader'] section + div {
+    float: right;
+    padding-top: 0;
+}
+</style>
+"""
+
+st.markdown(css, unsafe_allow_html=True)
+
+st.title("**FirstAid-AI**")
+
+# Add a description at the top
+st.markdown("""
+### Welcome to FirstAid-AI
+This application provides medical advice based on images of wounds and medical equipment.
+Upload an image of your wound or medical equipment, and the AI will classify the image and provide relevant advice.
+""")
+
+st.markdown("## How to Use FirstAid-AI")
+st.markdown("### 1. Upload an image of a wound and a piece of equipment (if applicable)")
+st.image("images/example3.png", use_container_width=True)
+st.caption("The AI model will detect the wound or equipment in the image and provide confidence levels. The AI assistant will then provide treatment or usage advice.")
+st.markdown("### 2. Ask follow-up questions and continue the conversation with the AI assistant!")
 
 # Initialize chat history
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
+# Dropdown to select the type of images to provide
+option = st.selectbox(
+    "Select the type of images you want to provide:",
+    ("Provide just wound image", "Provide both wound and equipment")
+)
+
+# Upload images based on the selected option
+file_wound = None
+file_equipment = None
+
+if option == "Provide just wound image":
+    file_wound = st.file_uploader("Upload an image of your wound")
+elif option == "Provide both wound and equipment":
+    file_wound = st.file_uploader("Upload an image of your wound")
+    file_equipment = st.file_uploader("Upload an image of your equipment")
 
 # Reset chat history if no file is uploaded
+if file_wound is None and file_equipment is None:
     st.session_state.messages = []
 
-    # Display the image and predictions
+if file_wound is not None and option == "Provide just wound image":
+    # Display the wound image and predictions
     col1, col2 = st.columns(2)
+    image = Image.open(file_wound)
     col1.image(image, use_container_width=True)
 
-    # Classify the image
+    # Classify the wound image
    predictions = pipeline(image)
     detected_wound = predictions[0]['label']
     col2.header("Detected Wound")
     for p in predictions:
         col2.subheader(f"{p['label']}: {round(p['score'] * 100, 1)}%")
 
     # Initial advice for wound
     if not st.session_state.messages:
         initial_query = f"Provide treatment advice for a {detected_wound} wound"
         initial_response = rag_chain.invoke(initial_query)
         st.session_state.messages.append({"role": "assistant", "content": initial_response})
 
+    # Display chat messages from history
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Accept user input if an image is uploaded
+    if (file_wound is not None or file_equipment is not None) and (prompt := st.chat_input("Ask a follow-up question or continue the conversation:")):
+        # Display user message in chat
+        with st.chat_message("user"):
+            st.markdown(prompt)
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        # Prepare the conversation history for rag_chain
+        conversation_history = "\n".join(
+            f"{message['role']}: {message['content']}" for message in st.session_state.messages
+        )
+
+        # Generate response from rag_chain
+        query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
+        response = rag_chain.invoke(query)
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant"):
+            st.markdown(response)
+
+        # Add assistant response to chat history
+        st.session_state.messages.append({"role": "assistant", "content": response})
+
+if file_wound is not None and file_equipment is not None and option == "Provide both wound and equipment":
+    # Display the wound image and predictions
+    col1, col2 = st.columns(2)
+    image = Image.open(file_wound)
+    col1.image(image, use_container_width=True)
+
+    # Classify the wound image
+    predictions = pipeline(image)
+    detected_wound = predictions[0]['label']
+    col2.header("Detected Wound")
+    for p in predictions:
+        col2.subheader(f"{p['label']}: {round(p['score'] * 100, 1)}%")
+
+    # Display the equipment image and predictions
+    col3, col4 = st.columns(2)
+    image = Image.open(file_equipment)
+    col3.image(image, use_container_width=True)
+
+    # Convert the image to a format supported by YOLO
+    image_np = np.array(image)
+    image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+
+    # Classify the equipment image using YOLO model
+    detected_equipment = get_detected_objects(yolo_model, image_cv)
+    col4.header("Detected Equipment")
+    col4.subheader(detected_equipment)
+
+    # Initial advice for equipment
+    if not st.session_state.messages:
+        initial_query = f"Provide usage advice for {detected_equipment} when treating a {detected_wound} wound"
+        initial_response = rag_chain.invoke(initial_query)
+        st.session_state.messages.append({"role": "assistant", "content": initial_response})
+
+    # Display chat messages from history
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Accept user input if an image is uploaded
+    if (file_wound is not None or file_equipment is not None) and (prompt := st.chat_input("Ask a follow-up question or continue the conversation:")):
+        # Display user message in chat
+        with st.chat_message("user"):
+            st.markdown(prompt)
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        # Prepare the conversation history for rag_chain
+        conversation_history = "\n".join(
+            f"{message['role']}: {message['content']}" for message in st.session_state.messages
+        )
+
+        # Generate response from rag_chain
+        query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
+        response = rag_chain.invoke(query)
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant"):
+            st.markdown(response)
+
+        # Add assistant response to chat history
+        st.session_state.messages.append({"role": "assistant", "content": response})
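
For reference, a minimal sketch of the classification step app.py performs, run outside Streamlit (assumptions: the transformers and Pillow pins from requirements.txt are installed, and the image path is only an example taken from the images/ folder added in this commit):

# Same Hugging Face image-classification pipeline that app.py caches with st.cache_resource.
from transformers import pipeline
from PIL import Image

clf = pipeline("image-classification", model="Heem2/wound-image-classification")
predictions = clf(Image.open("images/bandaid.jpg"))  # list of {"label": ..., "score": ...}, highest score first

detected_wound = predictions[0]["label"]
for p in predictions:
    print(f"{p['label']}: {round(p['score'] * 100, 1)}%")
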
images/bandaid.jpg ADDED
images/ctp_app_example.png ADDED
images/example2.png ADDED
images/example3.png ADDED

models/best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:54c49363a8d6f9a503178ab1f0594b4ece9ea35f30cc8c2721d4564c63bbb48e
+size 22528035

rag_model.py CHANGED
@@ -1,13 +1,11 @@
-# File loading and environment variables.
+# File loading and environment variables
 import os
 from dotenv import load_dotenv
 
+# Gemini Library
+import google.generativeai as genai
 
 # Langchain
 from langchain.document_loaders import TextLoader
 from langchain.prompts import PromptTemplate
 from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
@@ -15,34 +13,39 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import MongoDBAtlasVectorSearch
 from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
 
 # MongoDB
 from pymongo import MongoClient
 
 # Function type hints
 from typing import Dict, Any
 
 # Streamlit
 import streamlit as st
 
+# Load environment variables
 load_dotenv()
 
+# Retrieve environment variables
 MONGO_URI = os.getenv("MONGO_URI")
 HF_TOKEN = os.getenv("HF_TOKEN")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+# Configure Gemini
+genai.configure(api_key=GEMINI_API_KEY)
+model = genai.GenerativeModel("gemini-1.5-flash")
 
 # Setup Vector Store and MongoDB Atlas connection
 
 # Connect to MongoDB Atlas cluster using the connection string
 DB_NAME = "ericguan04"
 COLLECTION_NAME = "first_aid_intents"
 vector_search_index = "vector_index"
 
 @st.cache_resource
 def get_mongodb_collection():
     # Connect to MongoDB Atlas cluster using the connection string
     cluster = MongoClient(MONGO_URI)
     # Connect to the specific collection in the database
     return cluster[DB_NAME][COLLECTION_NAME]
 
 MONGODB_COLLECTION = get_mongodb_collection()
@@ -62,25 +65,23 @@ vector_search = MongoDBAtlasVectorSearch.from_connection_string(
     index_name=vector_search_index,
 )
 
 # k to search for only the X most relevant documents
 k = 10
 
 # score_threshold to use only documents with a relevance score above 0.80
 score_threshold = 0.80
 
 # Build your retriever
 retriever_1 = vector_search.as_retriever(
-    search_type
-    search_kwargs
+    search_type="similarity", # similarity, mmr, similarity_score_threshold. https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
+    search_kwargs={"k": k, "score_threshold": score_threshold},
 )
 
-# Initialize Hugging Face client
-hf_client = InferenceClient(api_key=HF_TOKEN)
-
 # Define the prompt template
 prompt = PromptTemplate.from_template(
-    """
+    """You are playing the role of a medical assistant. A patient has come to you with a minor medical issue.
+    Use the following pieces of context to answer the question at the end.
+    To be more natural, do not mention you are referring to the context.
 
     START OF CONTEXT:
     {context}
@@ -92,41 +93,34 @@ prompt = PromptTemplate.from_template(
 
     If you do not know the answer, just say that you do not know.
     NEVER assume things.
+    If the question is not relevant to the context, just say that it is not relevant.
     """
 )
 
-# Formatting the retrieved documents before inserting them in the system prompt template.
+# Formatting the retrieved documents before inserting them in the system prompt template
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
 
 @st.cache_resource
 def generate_response(input_dict: Dict[str, Any]) -> str:
+    """
+    Generate a response using the Gemini model.
+
+    Parameters:
+    input_dict (Dict[str, Any]): Dictionary with formatted context and question.
+
+    Returns:
+    str: Generated response from the Gemini model.
+    """
     formatted_prompt = prompt.format(**input_dict)
-
-    ## THIS IS YOUR LLM
-    response = hf_client.chat.completions.create(
-        model="Qwen/Qwen2.5-1.5B-Instruct",
-        messages=[{
-            "role": "system",
-            "content": formatted_prompt
-        },{
-            "role": "user",
-            "content": input_dict["question"]
-        }],
-        max_tokens=1000,
-        temperature=0.2,
-    )
-
-    return response.choices[0].message.content
+    response = model.generate_content(formatted_prompt)
+    return response.text  # Adjust based on actual response structure
 
-# Build the chain with retriever_1.
+# Build the chain with retriever_1
 rag_chain = (
     {
         "context": retriever_1 | RunnableLambda(format_docs),
-        "question": RunnablePassthrough()
+        "question": RunnablePassthrough(),
     }
     | RunnableLambda(generate_response)
 )
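
As a usage sketch (not part of the commit): app.py only ever calls rag_chain.invoke with a plain string. That string is sent to retriever_1 as the search query (then formatted by format_docs into {context}) and, through RunnablePassthrough, also fills the {question} slot, so generate_response receives a dict with "context" and "question" keys, formats the prompt, and calls the Gemini model. Assuming the environment variables read above (MONGO_URI, HF_TOKEN, GEMINI_API_KEY) are set in .env:

# Usage sketch; the query string here is only an example.
from rag_model import rag_chain

query = "Provide treatment advice for a minor burn"
advice = rag_chain.invoke(query)  # retrieval -> prompt formatting -> Gemini
print(advice)
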
requirements.txt CHANGED
@@ -10,15 +10,28 @@ cachetools==5.5.0
 certifi==2024.8.30
 charset-normalizer==3.4.0
 click==8.1.7
+contourpy==1.3.1
+cycler==0.12.1
 dataclasses-json==0.6.7
 dnspython==2.7.0
 filelock==3.16.1
+fonttools==4.55.2
 frozenlist==1.5.0
 fsspec==2024.10.0
 gitdb==4.0.11
 GitPython==3.1.43
+google-ai-generativelanguage==0.6.10
+google-api-core==2.23.0
+google-api-python-client==2.154.0
+google-auth==2.36.0
+google-auth-httplib2==0.2.0
+google-generativeai==0.8.3
+googleapis-common-protos==1.66.0
+grpcio==1.68.1
+grpcio-status==1.68.1
 h11==0.14.0
 httpcore==1.0.7
+httplib2==0.22.0
 httpx==0.27.2
 httpx-sse==0.4.0
 huggingface-hub==0.26.2
@@ -28,6 +41,7 @@ jsonpatch==1.33
 jsonpointer==3.0.0
 jsonschema==4.23.0
 jsonschema-specifications==2024.10.1
+kiwisolver==1.4.7
 langchain==0.3.7
 langchain-community==0.3.7
 langchain-core==0.3.19
@@ -36,6 +50,7 @@ langsmith==0.1.143
 markdown-it-py==3.0.0
 MarkupSafe==3.0.2
 marshmallow==3.23.1
+matplotlib==3.9.3
 mdurl==0.1.2
 mpmath==1.3.0
 multidict==6.1.0
@@ -43,19 +58,26 @@ mypy-extensions==1.0.0
 narwhals==1.14.1
 networkx==3.4.2
 numpy==1.26.4
+opencv-python==4.10.0.84
 orjson==3.10.11
 packaging==24.2
 pandas==2.2.3
 pillow==11.0.0
 propcache==0.2.0
+proto-plus==1.25.0
 protobuf==5.28.3
+psutil==6.1.0
+py-cpuinfo==9.0.0
 pyarrow==18.0.0
+pyasn1==0.6.1
+pyasn1_modules==0.4.1
 pydantic==2.9.2
 pydantic-settings==2.6.1
 pydantic_core==2.23.4
 pydeck==0.9.1
 Pygments==2.18.0
 pymongo==4.10.1
+pyparsing==3.2.0
 python-dateutil==2.9.0.post0
 python-dotenv==1.0.1
 pytz==2024.2
@@ -66,7 +88,10 @@ requests==2.32.3
 requests-toolbelt==1.0.0
 rich==13.9.4
 rpds-py==0.21.0
+rsa==4.9
 safetensors==0.4.5
+scipy==1.14.1
+seaborn==0.13.2
 six==1.16.0
 smmap==5.0.1
 sniffio==1.3.1
@@ -77,11 +102,15 @@ tenacity==9.0.0
 tokenizers==0.20.3
 toml==0.10.2
 torch==2.5.1
+torchvision==0.20.1
 tornado==6.4.1
 tqdm==4.67.0
 transformers==4.46.3
 typing-inspect==0.9.0
 typing_extensions==4.12.2
 tzdata==2024.2
+ultralytics==8.3.47
+ultralytics-thop==2.0.12
+uritemplate==4.1.1
 urllib3==2.2.3
 yarl==1.17.2

yolo_model.py ADDED
@@ -0,0 +1,35 @@
+# YOLO model
+from ultralytics import YOLO
+# Streamlit
+import streamlit as st
+
+@st.cache_resource
+def load_yolo_model():
+    return YOLO("models/best.pt")
+
+def get_detected_objects(yolo_model, image_path, conf_threshold=0.5):
+    """
+    Run YOLO prediction on an image and return detected objects as a string.
+
+    Parameters:
+    model_path (str): Path to the YOLO model file.
+    image_path (str): Path to the input image.
+    conf_threshold (float): Confidence threshold for detections.
+
+    Returns:
+    str: A comma-separated string of detected object names.
+    """
+    # Load the YOLO model
+    model = yolo_model
+
+    # Run prediction
+    results = model.predict(source=image_path, conf=conf_threshold)
+
+    # Extract detected objects as a list
+    detected_objects = [box.cls for box in results[0].boxes]  # Access the first image's detections
+
+    # Convert class indices to class names
+    detected_class_names = [model.names[int(cls)] for cls in detected_objects]
+
+    # Join detected class names into a single string
+    return ", ".join(detected_class_names)
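
A short usage sketch of the new helper, mirroring the equipment path in app.py. Despite the image_path parameter name, app.py passes an in-memory BGR array, which Ultralytics' predict(source=...) also accepts. Assumptions: the ultralytics, opencv-python, and numpy pins from requirements.txt are installed, models/best.pt has been fetched via Git LFS, and the image path below is only an example:

# Usage sketch: load the YOLO model and list detected equipment in one image.
import cv2
import numpy as np
from PIL import Image
from yolo_model import load_yolo_model, get_detected_objects

yolo_model = load_yolo_model()  # YOLO("models/best.pt"), cached by Streamlit when run in-app
image = Image.open("images/bandaid.jpg")  # example path; any equipment photo works
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # PIL RGB -> OpenCV BGR

# Prints a comma-separated string of class names above the confidence threshold;
# the possible names depend on the custom model's training classes.
print(get_detected_objects(yolo_model, image_cv, conf_threshold=0.5))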