Spaces:
Sleeping
Sleeping
import streamlit as st | |
from ultralytics import YOLO | |
from PIL import Image, ImageDraw, ImageFont | |
import numpy as np | |
import io | |
import json | |
# Page setup: browser-tab title and wide layout for the side-by-side images.
st.set_page_config(
    page_title="Skin Condition Classifier",
    layout="wide",
)
# Initialize the model
@st.cache_resource
def load_model():
    """Load the YOLO detection model from the ``best.pt`` weights file.

    Streamlit re-executes the whole script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the model is built once and
    the same instance is handed back on later reruns instead of re-reading
    the weights each time.

    Returns:
        The ``YOLO`` instance constructed from ``'best.pt'``.
    """
    return YOLO('best.pt')
# Module-wide model instance; process_image() runs inference through it.
model = load_model()

# Human-readable label for each class id; process_image() indexes this list
# with the integer class id of every detection.
class_names = [
    'Acne',
    'Dark circles',
    'blackheads',
    'eczema',
    'rosacea',
    'whiteheads',
    'wrinkles',
]
def process_image(image):
    """Run the detection model on a PIL image and return its detections.

    Args:
        image: PIL image to analyse. Its ``width`` and ``height`` are used to
            clamp the returned boxes to the image bounds.

    Returns:
        A list with one dict per detection, holding the keys ``"x1"``,
        ``"y1"``, ``"x2"``, ``"y2"`` (pixel corners as floats),
        ``"confidence"`` (float) and ``"class"`` (label taken from
        ``class_names``). The list is empty when nothing is detected.

    Raises:
        IndexError: If the model reports a class id that has no entry in
            ``class_names``.
    """
    # Hand the PIL image to the model directly. Ultralytics treats PIL input
    # as RGB but treats raw numpy arrays as BGR, so converting with
    # np.array(image) first would feed the model swapped red/blue channels.
    results = model(image)[0]

    detected = results.boxes
    if detected is None:
        return []

    # Move the tensors to host memory as numpy arrays before iterating.
    boxes = detected.xyxy.cpu().numpy()
    confidences = detected.conf.cpu().numpy()
    classes = detected.cls.cpu().numpy()

    predictions = []
    for box, confidence, class_id in zip(boxes, confidences, classes):
        predictions.append({
            # Clamp the corners so the box never extends past the image.
            "x1": float(max(0, box[0])),
            "y1": float(max(0, box[1])),
            "x2": float(min(image.width, box[2])),
            "y2": float(min(image.height, box[3])),
            "confidence": float(confidence),
            "class": class_names[int(class_id)],
        })
    return predictions
def draw_boxes(image, predictions):
    """Return a copy of ``image`` with a labelled box drawn per prediction.

    Args:
        image: PIL image to annotate. It is never modified; drawing happens
            on an RGB copy.
        predictions: Iterable of dicts with the keys ``"x1"``, ``"y1"``,
            ``"x2"``, ``"y2"``, ``"confidence"`` and ``"class"``.

    Returns:
        A new RGB PIL image with each box outlined in its class colour and
        captioned ``"<class> (<confidence>)"``.
    """
    # Always draw on an RGB canvas: the colours below are RGB tuples, which
    # only make sense for an RGB image (e.g. not greyscale uploads).
    if image.mode == "RGB":
        image_draw = image.copy()
    else:
        image_draw = image.convert("RGB")
    draw = ImageDraw.Draw(image_draw)

    # Prefer Arial; fall back to Pillow's built-in font when the font file or
    # FreeType support is missing. Only those failures are swallowed, so
    # unrelated errors (and KeyboardInterrupt) still propagate.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    # Color map for different classes
    colors = {
        'Acne': (255, 0, 0),          # Red
        'Dark circles': (0, 255, 0),  # Green
        'blackheads': (0, 0, 255),    # Blue
        'eczema': (255, 255, 0),      # Yellow
        'rosacea': (255, 0, 255),     # Magenta
        'whiteheads': (0, 255, 255),  # Cyan
        'wrinkles': (128, 0, 128),    # Purple
    }
    text_height = 16  # Approximate height for the font size

    for pred in predictions:
        x1, y1, x2, y2 = (float(pred[key]) for key in ("x1", "y1", "x2", "y2"))

        # Unknown classes fall back to red.
        color = colors.get(pred['class'], (255, 0, 0))
        draw.rectangle([x1, y1, x2, y2], outline=color, width=2)

        label = f"{pred['class']} ({pred['confidence']:.2f})"
        text_width = draw.textlength(label, font=font)

        # Put the label just above the box; when the box touches the top of
        # the image there is no room there, so draw it inside the box instead
        # of off-canvas where it would be invisible.
        label_top = y1 - text_height if y1 >= text_height else y1

        draw.rectangle(
            [x1, label_top, x1 + text_width, label_top + text_height],
            fill=color,
        )
        draw.text(
            (x1, label_top),
            label,
            fill=(255, 255, 255),  # White text
            font=font,
        )

    return image_draw
def main():
    """Render the app: upload an image, run detection, and show the results.

    Shows the original image next to an annotated copy, lists every
    detection with its confidence, and offers the predictions as a JSON
    download. Any exception raised while handling the upload is reported to
    the user through ``st.error`` instead of propagating.
    """
    st.title("Skin Condition Classifier")

    # File uploader
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # Create columns for layout
        col1, col2 = st.columns(2)

        # Read the upload and normalise it to RGB. Every non-RGB mode is
        # converted (RGBA, palette, greyscale, ...), not just RGBA, so the
        # detection and drawing steps always receive a 3-channel image.
        image = Image.open(uploaded_file)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        col1.subheader("Original Image")
        col1.image(image, use_container_width=True)

        # Process image and get predictions
        with st.spinner('Processing image...'):
            predictions = process_image(image)

        if not predictions:
            st.warning("No conditions detected in the image.")
            return

        # Display annotated image
        image_with_boxes = draw_boxes(image, predictions)
        col2.subheader("Detected Conditions")
        col2.image(image_with_boxes, use_container_width=True)

        # Display predictions in a readable list
        st.subheader("Detailed Results")
        for pred in predictions:
            st.write(f"- Detected {pred['class']} with {pred['confidence']:.2f} confidence")

        # Add download button for JSON results
        json_results = json.dumps({"predictions": predictions}, indent=2)
        st.download_button(
            label="Download Results as JSON",
            data=json_results,
            file_name="predictions.json",
            mime="application/json",
        )
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user rather than
        # crashing the script run.
        st.error(f"An error occurred: {str(e)}")
# Script entry point: build the UI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()