Ani14 commited on
Commit
d451e55
·
verified ·
1 Parent(s): a31f615

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +138 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,140 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ # wound_agent_streamlit.py
2
+
 
3
  import streamlit as st
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import tempfile
8
+ from PIL import Image
9
+ from tensorflow.keras.models import load_model
10
+ from transformers import pipeline
11
+
12
+ st.set_page_config(page_title="SmartHeal Wound Agent", layout="wide")
13
+
14
+ # ------------------------------
15
+ # Load All Models Once
16
+ # ------------------------------
17
+ @st.cache_resource
18
+ def load_all_models():
19
+ # YOLOv5 detection
20
+ detection_model = torch.hub.load("ultralytics/yolov5", "custom", path="best.pt", force_reload=False)
21
+
22
+ # Segmentation
23
+ segmentation_model = load_model("segmentation model.h5", compile=False)
24
+
25
+ # Classification
26
+ classification_pipe = pipeline("image-classification", model="Hemg/Wound-classification")
27
+
28
+ # Med-Gemma
29
+ medgemma_pipe = pipeline(
30
+ "image-text-to-text",
31
+ model="google/medgemma-4b-it",
32
+ torch_dtype=torch.bfloat16,
33
+ device="cuda"
34
+ )
35
+
36
+ return detection_model, segmentation_model, classification_pipe, medgemma_pipe
37
+
38
+ yolo_model, seg_model, classify_pipe, medgemma = load_all_models()
39
+
40
+ # ------------------------------
41
+ # Area Estimation
42
+ # ------------------------------
43
+ def estimate_area(mask, px_per_cm=20):
44
+ pixel_area = np.sum(mask > 0)
45
+ area_cm2 = pixel_area / (px_per_cm ** 2)
46
+ return round(area_cm2, 2)
47
+
48
+ # ------------------------------
49
+ # Main UI
50
+ # ------------------------------
51
+ st.title("🩹 SmartHeal: Real-Time Wound Care Agent")
52
+
53
+ uploaded_file = st.file_uploader("📤 Upload a wound image", type=["jpg", "jpeg", "png"])
54
+ with st.form("patient_form"):
55
+ age = st.number_input("Patient Age", min_value=1, max_value=120)
56
+ diabetic = st.radio("Is the patient diabetic?", ["Yes", "No"])
57
+ infection = st.radio("Signs of infection present?", ["Yes", "No"])
58
+ submitted = st.form_submit_button("🔍 Analyze")
59
+
60
+ if uploaded_file and submitted:
61
+ image = Image.open(uploaded_file).convert("RGB")
62
+ st.image(image, caption="Uploaded Image", use_column_width=True)
63
+
64
+ # Convert to OpenCV format
65
+ image_cv = np.array(image)
66
+
67
+ # ------------------ DETECTION ------------------
68
+ st.subheader("🧠 Detection")
69
+ results = yolo_model(image_cv)
70
+ boxes = results.xyxy[0].cpu().numpy()
71
+
72
+ if len(boxes) == 0:
73
+ st.error("No wound detected.")
74
+ st.stop()
75
+
76
+ x1, y1, x2, y2 = map(int, boxes[0][:4])
77
+ detected_region = image_cv[y1:y2, x1:x2]
78
+
79
+ # ------------------ SEGMENTATION ------------------
80
+ st.subheader("🧠 Segmentation")
81
+ resized = cv2.resize(detected_region, (256, 256)) / 255.0
82
+ input_tensor = np.expand_dims(resized, axis=0)
83
+ pred_mask = seg_model.predict(input_tensor)[0]
84
+ binary_mask = (pred_mask[:, :, 0] > 0.5).astype(np.uint8)
85
+ mask_resized = cv2.resize(binary_mask, (x2 - x1, y2 - y1))
86
+ full_mask = np.zeros(image_cv.shape[:2], dtype=np.uint8)
87
+ full_mask[y1:y2, x1:x2] = mask_resized
88
+
89
+ area_cm2 = estimate_area(full_mask)
90
+ st.markdown(f"📏 **Estimated Wound Area:** `{area_cm2} cm²`")
91
+
92
+ overlay = image_cv.copy()
93
+ overlay[full_mask > 0] = [255, 0, 0]
94
+ st.image(overlay, caption="Wound Segmentation Overlay", use_column_width=True)
95
+
96
+ # ------------------ CLASSIFICATION ------------------
97
+ st.subheader("🧠 Wound Classification")
98
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
99
+ Image.fromarray(detected_region).save(tmp.name)
100
+ wound_type = classify_pipe(tmp.name)[0]["label"]
101
+ st.success(f"✅ Wound Type Classified: **{wound_type}**")
102
+
103
+ # ------------------ MED-GEMMA ------------------
104
+ st.subheader("🧠 Med-Gemma Diagnosis + Treatment Plan")
105
+
106
+ messages = [
107
+ {
108
+ "role": "system",
109
+ "content": [{"type": "text", "text": "You are a wound care expert."}]
110
+ },
111
+ {
112
+ "role": "user",
113
+ "content": [
114
+ {"type": "text", "text": f"""Patient Info:
115
+ - Age: {age}
116
+ - Diabetic: {diabetic}
117
+ - Wound Type: {wound_type}
118
+ - Area: {area_cm2} cm²
119
+ - Signs of infection: {infection}
120
+
121
+ Please provide:
122
+ 1. Wound assessment
123
+ 2. Recommended treatment
124
+ 3. Cleaning & dressing method
125
+ 4. Red flags to monitor
126
+ 5. Follow-up schedule"""},
127
+ {"type": "image", "image": image}
128
+ ]
129
+ }
130
+ ]
131
+
132
+ with st.spinner("Generating treatment plan..."):
133
+ output = medgemma(text=messages, max_new_tokens=300)
134
+ response = output[0]["generated_text"][-1]["content"]
135
+
136
+ st.markdown("### 📝 Recommendation")
137
+ st.info(response)
138
+
139
+ st.download_button("📄 Download Report", response, file_name="treatment_plan.txt")
140