Spaces:
Sleeping
Sleeping
import base64
import html
import time
from io import BytesIO

import cv2
import matplotlib.cm as cm
import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline
st.set_page_config(layout="wide")

# Inject the app's custom stylesheet. The markup below uses classes such as
# `title` and `text-center`, presumably defined there (styles.css is not
# visible here). An explicit encoding avoids depending on the platform default.
with open("styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

st.markdown("<h1 class='title'>Segformer Semantic Segmentation</h1>", unsafe_allow_html=True)
# Introductory blurb shown under the page title.
# Note: `font-weight` takes a keyword or a number, not a length; the old
# `40px` value was invalid CSS and silently ignored by browsers.
st.markdown("""
<div class='text-center'>
This app uses the Segformer deep learning model to perform semantic segmentation on <b style='color: red; font-weight: bold;'>road images</b>. The Transformer-based model is
trained on the CityScapes dataset which contains images of urban road scenes. Upload a
road scene and the app will return the image with semantic segmentation applied.
</div>
""", unsafe_allow_html=True)
# Authors listed in the footer at the bottom of the page.
group_members = [
    "Ang Ngo Ching, Josh Darren W.",
    "Bautista, Ryan Matthew M.",
    "Lacuesta, Angelo Giuseppe M.",
    "Reyes, Kenwin Hans",
    "Ting, Sidney Mitchell O.",
]

# TODO: let the user pick the Segformer size ("b1".."b5") with an
# st.selectbox and build the checkpoint name from the selection.
# Pointer to a shared Google Drive folder of sample road-scene images.
# The URL drops the Facebook `fbclid` click-tracking parameter, which the
# folder link does not need. `rel='noopener noreferrer'` stops the new tab
# from getting a handle to this page via window.opener.
st.markdown("""
<h3 class='text-center' style='margin-top: 0.5rem;'>
ℹ️ You can get sample images of road scenes in this <a href='https://drive.google.com/drive/folders/1202EMeXAHnN18NuhJKWWme34vg0V-svY' target='_blank' rel='noopener noreferrer'>link</a>.
</h3>""", unsafe_allow_html=True)
SEGFORMER_CHECKPOINT = "nvidia/segformer-b1-finetuned-cityscapes-1024-1024"


@st.cache_resource
def load_segmentation_pipeline(checkpoint=SEGFORMER_CHECKPOINT):
    """Build the Hugging Face ``image-segmentation`` pipeline for *checkpoint*.

    Streamlit re-executes this whole script on every user interaction.
    ``st.cache_resource`` (Streamlit >= 1.18) keeps one loaded model per
    checkpoint across those reruns instead of rebuilding it every time.
    """
    return pipeline("image-segmentation", checkpoint)


semantic_segmentation = load_segmentation_pipeline()
new_file_uploaded = False
# A non-empty label avoids Streamlit's accessibility warning; it stays hidden
# to keep the original look. "jpeg" is listed so .jpeg files are accepted too.
uploaded_file = st.file_uploader(
    "Upload a road scene image",
    type=["jpg", "jpeg", "png"],
    label_visibility="collapsed",
)
# Filled by draw_masks_fromDict: maps each RGB mask colour to its class label.
label_colors = {}
def draw_masks_fromDict(image, results, legend=None):
    """Blend one colour per segmentation result over *image*.

    Args:
        image: array of shape (H, W, 3). The caller builds it with
            ``np.array(pil_image)``, presumably ``uint8`` (confirm).
        results: sequence of dicts, each with a ``'mask'`` (anything
            ``np.asarray`` turns into an (H, W) array; truthy pixels count
            as inside the segment) and a ``'label'``. This matches the
            output of the ``image-segmentation`` pipeline used by the caller.
        legend: dict that receives ``{(r, g, b): label}`` for every result.
            Defaults to the module-level ``label_colors``, which the page
            reads to draw its colour legend.

    Returns:
        An array with the same shape as *image*: 30% *image* plus 70% of the
        mask-coloured copy. Where masks overlap, the later result wins.
    """
    if legend is None:
        legend = label_colors

    # `cm.get_cmap` is deprecated since matplotlib 3.7 and removed in 3.9.
    # Prefer the colormap registry (3.5+) and fall back for old versions.
    try:
        from matplotlib import colormaps
        colormap = colormaps["nipy_spectral"]
    except ImportError:
        colormap = cm.get_cmap("nipy_spectral")

    masked_image = image.copy()
    count = len(results)
    for i, result in enumerate(results):
        # (H, W, 1) broadcasts against the 3 colour channels in np.where, so
        # there is no need to materialise a 3-channel copy of the mask.
        mask = np.asarray(result['mask'])[:, :, np.newaxis]
        # Spread the results evenly across the colormap; drop the alpha channel.
        color = tuple(int(c * 255) for c in colormap(i / count)[:3])
        masked_image = np.where(mask, color, masked_image)
        legend[color] = result['label']

    # np.where promotes to a wider int type; bring it back to 8-bit.
    masked_image = masked_image.astype(np.uint8)
    return cv2.addWeighted(image, 0.3, masked_image, 0.7, 0)
# Remember the latest upload in session state so it stays selected across
# Streamlit reruns.
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file

# Streamlit re-executes this script from the top on every interaction, such
# as a new upload, so each run just renders the current file once. Do NOT add
# a polling/sleep loop here: it never returns, it blocks the script thread,
# and it stops everything after this section (the footer) from rendering.
if st.session_state.uploaded_file is not None:
    source_file = st.session_state.uploaded_file
    # Force 3-channel RGB. Greyscale ("L") or palette ("P") images would
    # otherwise give a 2-D array, and RGBA would hand the model 4 channels.
    image = Image.open(source_file).convert("RGB")

    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption='Uploaded Image.', use_column_width=True)

    with st.spinner('Processing...'):
        segmentation_results = semantic_segmentation(image)
        image_with_masks = draw_masks_fromDict(np.array(image), segmentation_results)
        image_with_masks_pil = Image.fromarray(image_with_masks)

    with col2:
        st.image(image_with_masks_pil, caption='Segmented Image.', use_column_width=True)
        st.markdown("**Labels:**")
        # label_colors (RGB tuple -> label) is filled by draw_masks_fromDict.
        for color, label in label_colors.items():
            st.markdown(
                "<div style='display: flex; align-items: center; margin-bottom: 0.5rem;'>"
                f"<span style='display: inline-block; width: 20px; height: 20px; background-color: rgb{color}; margin-right: 1rem; border-radius: 10px;'></span>"
                f"<p style='margin: 0;'>{html.escape(str(label))}</p></div>",
                unsafe_allow_html=True,
            )

        buffered = BytesIO()
        image_with_masks_pil.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        # The payload is always PNG, so name it *.png whatever the upload's
        # extension was. The filename is user-controlled, so escape it before
        # placing it inside raw HTML.
        stem = source_file.name.rsplit(".", 1)[0]
        download_name = html.escape(f"segmented_{stem}.png")
        href = f'<a href="data:image/png;base64,{img_str}" download="{download_name}">Download Segmented Image</a>'
        st.markdown(href, unsafe_allow_html=True)
pdf_url = "https://arxiv.org/pdf/2105.15203.pdf"

# Heading for the embedded paper. The opening and closing tags must match;
# the old markup opened an <h3> but closed it with </h5>.
st.markdown("""
<h3 style='text-align: center; margin-top: 2rem;'>
Read more about the paper below👇
</h3>
""", unsafe_allow_html=True)
# Embed the PDF at pdf_url inline. The attribute value is quoted so the HTML
# stays well-formed. The "pdf" class is presumably styled in styles.css (not
# visible here).
st.markdown(f'<iframe class="pdf" src="{pdf_url}"></iframe>', unsafe_allow_html=True)

st.markdown("Group Members:")
# One markdown call renders a single bulleted list instead of a series of
# separate one-item lists.
st.markdown("\n".join(f"- {member}" for member in group_members))