File size: 1,788 Bytes
0691d6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a768299
0691d6d
 
5ae48fc
f262293
 
 
0691d6d
f262293
 
 
 
0691d6d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.optim
import model
import numpy as np
from PIL import Image
import streamlit as st
from torchvision import transforms

# Passed to model.enhance_net_nopool() and used by fix_lowlight() to crop the
# input height/width down to a multiple of this value (a no-op while it is 1).
scale_factor = 1

# Pick a cache decorator that does not hash the returned model object.
# Newer Streamlit releases expose st.cache_resource for this; older ones only
# have st.cache, where allow_output_mutation=True disables output hashing.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model() -> torch.nn.Module:
    """Build the enhancement network and load its snapshot weights on the CPU.

    The result is cached by Streamlit, so the snapshot file is only read the
    first time this is called rather than on every call.

    Returns:
        The network created by ``model.enhance_net_nopool(scale_factor)`` with
        the weights from ``lowlight-dce-snapshot.pth`` loaded, in eval mode.
    """
    DCE_net = model.enhance_net_nopool(scale_factor)
    state_dict = torch.load(
        "lowlight-dce-snapshot.pth",
        map_location=torch.device('cpu'),
    )
    DCE_net.load_state_dict(state_dict)
    # The app only runs inference, so put the module in evaluation mode.
    DCE_net.eval()

    return DCE_net

def fix_lowlight(image: Image.Image) -> Image.Image:
    """Enhance a low-light image with the cached network from ``load_model``.

    Args:
        image: Input picture. It is converted to RGB first, so inputs in other
            modes (grayscale, palette, RGBA, ...) are accepted as well.

    Returns:
        The enhanced picture as an RGB PIL image. Its height and width are the
        input's, cropped down to the nearest multiple of ``scale_factor``.
    """
    DCE_net = load_model()

    # Force three channels: the slicing and permute below index a channel axis.
    rgb_image = image.convert("RGB")
    data_lowlight = np.asarray(rgb_image) / 255.0
    data_lowlight = torch.from_numpy(data_lowlight).float()

    # Crop so both spatial dimensions are multiples of scale_factor.
    h = (data_lowlight.shape[0] // scale_factor) * scale_factor
    w = (data_lowlight.shape[1] // scale_factor) * scale_factor
    data_lowlight = data_lowlight[0:h, 0:w, :]
    # (H, W, C) -> (C, H, W), then add a leading batch dimension.
    data_lowlight = data_lowlight.permute(2, 0, 1).unsqueeze(0)

    # Inference only: no_grad avoids recording an autograd graph for the pass.
    with torch.no_grad():
        # The network returns two values; only the first is used here.
        enhanced_image, _ = DCE_net(data_lowlight)

    # NOTE(review): assumes the network output is meant to lie in [0, 1];
    # clamping makes stray values saturate before the conversion to 8-bit.
    enhanced = enhanced_image[0].clamp(0.0, 1.0)
    im = transforms.ToPILImage()(enhanced).convert("RGB")

    return im

def main():
    """Render the page: upload a low-light image and show it beside its enhanced version.

    Nothing beyond the title, description and uploader is drawn until a file
    has been uploaded. If the upload cannot be opened as an image, an error
    message is shown instead of the two image columns.
    """
    st.title("Lowlight Enhancement")
    st.write("This is a simple lowlight enhancement app with great performance and does not require paired images to train.")
    st.write("The model runs at 1000/11 FPS on single GPU/CPU on images with a size of 1200*900*3")
    uploaded_file = st.file_uploader("Lowlight Image")
    if not uploaded_file:
        return

    try:
        # convert() forces the pixel data to be decoded, so truncated or
        # non-image uploads fail here rather than later in the pipeline.
        data_lowlight = Image.open(uploaded_file).convert('RGB')
    except OSError:
        # PIL reports unreadable/unidentified images as OSError subclasses.
        st.error("Could not read the uploaded file as an image.")
        return

    col1, col2 = st.columns(2)
    col1.write("Original (Lowlight)")
    col1.image(data_lowlight, caption="Lowlight Image", use_column_width=True)

    col2.write("Enhanced")
    with st.spinner('🧠 Enhancing...'):
        fixed_img = fix_lowlight(data_lowlight)
    col2.image(fixed_img, caption="Enhanced Image", use_column_width=True)

# Render the page whenever this script is executed.
main()