File size: 3,468 Bytes
8b04a03
10cb35a
2aa7829
7899b33
 
 
78e99f5
 
 
8b04a03
78e99f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10cb35a
7899b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78e99f5
 
 
 
7899b33
10cb35a
78e99f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2aa7829
78e99f5
1c6b417
78e99f5
 
 
 
 
 
1c6b417
78e99f5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import base64
import html
import time
from io import BytesIO

import cv2
import matplotlib
import matplotlib.cm as cm
import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline

st.set_page_config(layout="wide")

# Inject the app's custom stylesheet. Decode it explicitly as UTF-8 so the
# result does not depend on the host platform's default text encoding.
with open("styles.css", encoding="utf-8") as css_file:
    st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)


# Page title and short description of what the app does.
st.markdown("<h1 class='title'>Segformer Semantic Segmentation</h1>", unsafe_allow_html=True)
st.markdown("""
<div class='text-center'>
This app uses the Segformer deep learning model to perform semantic segmentation on road images. The Transformer-based model is 
trained on the CityScapes dataset which contains images of urban road scenes. Upload a 
road scene and the app will return the image with semantic segmentation applied.
</div>
""", unsafe_allow_html=True)

# Names listed in the page footer.
group_members = [
    "Ang Ngo Ching, Josh Darren W.",
    "Bautista, Ryan Matthew M.",
    "Lacuesta, Angelo Giuseppe M.",
    "Reyes, Kenwin Hans",
    "Ting, Sidney Mitchell O.",
]

# Hugging Face model id; "{}" is filled with the Segformer backbone size.
# An earlier draft of this app listed "b1" through "b5" as selectable options.
SEGFORMER_MODEL_ID_TEMPLATE = "nvidia/segformer-{}-finetuned-cityscapes-1024-1024"


@st.cache_resource
def load_segmentation_pipeline(model_version="b1"):
    """Return the Segformer ``image-segmentation`` pipeline for ``model_version``.

    Streamlit re-executes this whole script on every interaction. Caching the
    pipeline with ``st.cache_resource`` loads the model weights once per
    process (per distinct ``model_version``) instead of on every rerun.
    """
    return pipeline("image-segmentation", SEGFORMER_MODEL_ID_TEMPLATE.format(model_version))


semantic_segmentation = load_segmentation_pipeline()

# Give the widget a real label for accessibility (Streamlit warns on an empty
# one) but keep it visually hidden, as before. ".jpeg" is accepted alongside
# ".jpg" since both are common JPEG extensions.
uploaded_file = st.file_uploader(
    "Upload a road scene image",
    type=["jpg", "jpeg", "png"],
    label_visibility="collapsed",
)


def draw_masks_fromDict(image, results):
    """Overlay colour-coded segmentation masks on an image.

    Entry ``i`` of ``results`` is painted in a solid colour taken from
    matplotlib's ``nipy_spectral`` colormap at position ``i / len(results)``.
    Colours therefore depend on the entry's position in ``results``, not on
    its label. Later masks overwrite earlier ones where they overlap. The
    painted copy is then blended over the original with weights 0.3
    (original) and 0.7 (painted).

    Args:
        image: Image as a NumPy array, expected to be ``uint8`` RGB of shape
            (H, W, 3). Arrays of shape (H, W), (H, W, 1) or (H, W, 4) are
            first converted to 3 channels; for 4 channels the alpha channel
            is dropped.
        results: Sequence of dicts, each holding a ``'mask'`` that
            ``np.asarray`` can convert (e.g. a PIL image) to shape (H, W).
            Non-zero pixels belong to the segment.

    Returns:
        The blended image as an (H, W, 3) array (``uint8`` for ``uint8`` input).
    """
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]
    # Hand cv2 a contiguous buffer; the alpha-stripped slice above is a strided view.
    image = np.ascontiguousarray(image)

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. The colormap
    # registry replaces it and is available since 3.5.
    try:
        colormap = matplotlib.colormaps['nipy_spectral']
    except AttributeError:  # Matplotlib < 3.5: no registry attribute yet
        colormap = cm.get_cmap('nipy_spectral')

    masked_image = image.copy()
    n_results = len(results)
    for i, result in enumerate(results):
        segment = np.asarray(result['mask']).astype(bool)
        rgba = colormap(i / n_results)
        # Broadcast one RGB triple onto every pixel of the segment.
        masked_image[segment] = [int(c * 255) for c in rgba[:3]]

    return cv2.addWeighted(image, 0.3, masked_image, 0.7, 0)

if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None

# Remember the most recent upload in session state.
if uploaded_file is not None:
    st.session_state.uploaded_file = uploaded_file

if st.session_state.uploaded_file is not None:
    source_file = st.session_state.uploaded_file
    image = Image.open(source_file)
    # The overlay step works on 3-channel arrays, so RGBA / palette /
    # grayscale uploads are converted to RGB for processing. The original
    # upload is still what gets displayed.
    rgb_image = image.convert("RGB")
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption='Uploaded Image.', use_column_width=True)

    # Streamlit re-executes this script whenever the upload changes, so the
    # work below runs once per execution. No polling loop is needed, and one
    # would block the rest of the page from rendering.
    with st.spinner('Processing...'):
        segmentation_results = semantic_segmentation(rgb_image)
        image_with_masks = draw_masks_fromDict(np.array(rgb_image), segmentation_results)
        # 3-channel uint8 array -> Pillow infers mode "RGB".
        image_with_masks_pil = Image.fromarray(image_with_masks)

    with col2:
        st.image(image_with_masks_pil, caption='Segmented Image.', use_column_width=True)

    buffered = BytesIO()
    image_with_masks_pil.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # The payload is always PNG, so name the download *.png whatever the
    # upload's extension was. The name comes from the user's file and is
    # interpolated into raw HTML, so it is escaped.
    stem = source_file.name.rsplit(".", 1)[0]
    download_name = html.escape(f"segmented_{stem}.png", quote=True)
    href = f'<a href="data:image/png;base64,{img_str}" download="{download_name}">Download Segmented Image</a>'
    st.markdown(href, unsafe_allow_html=True)


pdf_url = "https://arxiv.org/pdf/2105.15203.pdf"

# Heading for the embedded paper. The closing tag now matches the opening <h3>.
st.markdown("""
<h3 class='text-center'>
Read more about the paper below👇
</h3>
""", unsafe_allow_html=True)
# Quote the src attribute so the URL is parsed as a single attribute value.
st.markdown(f'<iframe class="pdf" src="{pdf_url}"></iframe>', unsafe_allow_html=True)

# Render the members as one markdown bullet list rather than one list per name.
st.markdown("Group Members:")
st.markdown("\n".join(f"- {member}" for member in group_members))