# FirstAid-AI — Streamlit front end for wound classification (HF image-classification
# pipeline), medical-equipment detection (YOLO), and RAG-based first-aid advice.
import streamlit as st | |
from transformers import pipeline | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
from rag_model import * | |
from yolo_model import * | |
@st.cache_resource
def load_image_model():
    """Load and cache the wound image-classification pipeline.

    Streamlit re-executes the whole script on every user interaction; without
    ``st.cache_resource`` the Hugging Face model would be re-instantiated (and
    potentially re-downloaded) on each rerun. The decorator keeps one shared
    instance alive across reruns and sessions.

    Returns:
        A ``transformers`` image-classification pipeline for the
        ``Heem2/wound-image-classification`` model.
    """
    # `pipeline` here is transformers.pipeline (imported at the top of the file).
    # NOTE(review): the module later rebinds the name `pipeline` to this
    # function's result; caching also shields us from that shadowing, since the
    # body only runs once.
    return pipeline("image-classification", model="Heem2/wound-image-classification")
# Global handle to the wound classifier.
# NOTE(review): this rebinds the name `pipeline`, shadowing the
# `transformers.pipeline` factory imported above — it works only because
# load_image_model() has already captured the factory; consider renaming
# this variable (e.g. `image_classifier`) to avoid confusion.
pipeline = load_image_model()
# YOLO object detector for medical-equipment recognition (from yolo_model module).
yolo_model = load_yolo_model()
# Inject page-wide styling: base typography, card/button looks, gradient
# background, and layout tweaks for the file-uploader widget.
_CUSTOM_CSS = """
<style>
body {
    font-family: 'Arial', sans-serif;
    background-color: #f5f5f5;
}
.main {
    background-color: #ffffff;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.stButton button {
    background-color: #4CAF50;
    color: white;
    border: none;
    padding: 10px 20px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 16px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 5px;
}
.stButton button:hover {
    background-color: #45a049;
}
.stApp > header {
    background-color: transparent;
}
.stApp {
    margin: auto;
    background-color: #D9AFD9;
    background-image: linear-gradient(0deg, #D9AFD9 0%, #97D9E1 100%);
}
[data-testid='stFileUploader'] {
    width: max-content;
}
[data-testid='stFileUploader'] section {
    padding: 0;
    float: left;
}
[data-testid='stFileUploader'] section > input + div {
    display: none;
}
[data-testid='stFileUploader'] section + div {
    float: right;
    padding-top: 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page title, welcome blurb, and step-by-step usage instructions.
st.title("**FirstAid-AI**")

_WELCOME_MD = """
### Welcome to FirstAid-AI
This application provides medical advice based on images of wounds and medical equipment.
Upload an image of your wound or medical equipment, and the AI will classify the image and provide relevant advice.
"""
st.markdown(_WELCOME_MD)

st.markdown("## How to Use FirstAid-AI")
st.markdown("### 1. Upload an image of a wound and a piece of equipment (if applicable)")
st.image("images/example3.png", use_container_width=True)
st.caption("The AI model will detect the wound or equipment in the image and provide confidence levels. The AI assistant will then provide treatment or usage advice.")
st.markdown("### 2. Ask follow-up questions and continue the conversation with the AI assistant!")
# Conversation history lives in session state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Let the user choose which kinds of images they will supply.
option = st.selectbox(
    "Select the type of images you want to provide:",
    ("Provide just wound image", "Provide both wound and equipment")
)

# Every mode needs a wound image; the equipment uploader only appears when the
# user chose to provide both images.
file_wound = st.file_uploader("Upload an image of your wound")
file_equipment = None
if option == "Provide both wound and equipment":
    file_equipment = st.file_uploader("Upload an image of your equipment")

# No image at all means no active case — start the conversation over.
if file_wound is None and file_equipment is None:
    st.session_state.messages = []
# --- Mode 1: wound image only -------------------------------------------------
if file_wound is not None and option == "Provide just wound image":
    # Show the uploaded image alongside the classifier's predictions.
    img_col, pred_col = st.columns(2)
    wound_image = Image.open(file_wound)
    img_col.image(wound_image, use_container_width=True)

    # Classify the wound; the top-scoring label drives the initial advice query.
    predictions = pipeline(wound_image)
    detected_wound = predictions[0]['label']
    pred_col.header("Detected Wound")
    for pred in predictions:
        pred_col.subheader(f"{pred['label']}: {round(pred['score'] * 100, 1)}%")

    # On the first run of a new case, seed the chat with treatment advice.
    if not st.session_state.messages:
        initial_query = f"Provide treatment advice for a {detected_wound} wound"
        initial_response = rag_chain.invoke(initial_query)
        st.session_state.messages.append({"role": "assistant", "content": initial_response})

    # Replay the stored conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Accept a follow-up question, if the user typed one this rerun.
    if (file_wound is not None or file_equipment is not None) and (prompt := st.chat_input("Ask a follow-up question or continue the conversation:")):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Hand the RAG chain the whole conversation as context for its answer.
        conversation_history = "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages
        )
        query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
        response = rag_chain.invoke(query)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
# --- Mode 2: wound + equipment images -----------------------------------------
if file_wound is not None and file_equipment is not None and option == "Provide both wound and equipment":
    # Wound image and classifier predictions, side by side.
    wound_col, wound_pred_col = st.columns(2)
    wound_img = Image.open(file_wound)
    wound_col.image(wound_img, use_container_width=True)

    predictions = pipeline(wound_img)
    detected_wound = predictions[0]['label']
    wound_pred_col.header("Detected Wound")
    for pred in predictions:
        wound_pred_col.subheader(f"{pred['label']}: {round(pred['score'] * 100, 1)}%")

    # Equipment image and YOLO detections, side by side.
    equip_col, equip_pred_col = st.columns(2)
    equip_img = Image.open(file_equipment)
    equip_col.image(equip_img, use_container_width=True)

    # PIL yields RGB but the YOLO pipeline consumes an OpenCV BGR array,
    # so convert before running detection.
    equip_bgr = cv2.cvtColor(np.array(equip_img), cv2.COLOR_RGB2BGR)
    detected_equipment = get_detected_objects(yolo_model, equip_bgr)
    equip_pred_col.header("Detected Equipment")
    equip_pred_col.subheader(detected_equipment)

    # On the first run of a new case, seed the chat with combined usage advice.
    if not st.session_state.messages:
        initial_query = f"Provide usage advice for {detected_equipment} when treating a {detected_wound} wound"
        initial_response = rag_chain.invoke(initial_query)
        st.session_state.messages.append({"role": "assistant", "content": initial_response})

    # Replay the stored conversation so far.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Accept a follow-up question, if the user typed one this rerun.
    if (file_wound is not None or file_equipment is not None) and (prompt := st.chat_input("Ask a follow-up question or continue the conversation:")):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Hand the RAG chain the whole conversation as context for its answer.
        conversation_history = "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages
        )
        query = f"Context:\n{conversation_history}\n\nAssistant, respond to the user's latest query: {prompt}"
        response = rag_chain.invoke(query)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})