Spaces:
Runtime error
Runtime error
File size: 1,788 Bytes
0691d6d a768299 0691d6d 5ae48fc f262293 0691d6d f262293 0691d6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
import torch.optim
import model
import numpy as np
from PIL import Image
import streamlit as st
from torchvision import transforms
# Downscale factor handed to the enhancement network; fix_lowlight also crops the
# input height/width down to a multiple of it (a no-op while it is 1).
scale_factor = 1

# Prefer `st.cache_resource` (Streamlit >= 1.18), which is meant for shared,
# unhashable objects such as models. The deprecated `st.cache` tries to hash and
# compare the returned object on every cache hit unless output mutation is
# explicitly allowed, which is fragile for a torch.nn.Module.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model() -> torch.nn.Module:
    """Build the enhancement network and load its weights onto the CPU.

    The weights are read from ``lowlight-dce-snapshot.pth`` in the working
    directory. Streamlit caches the result so the checkpoint is not re-read on
    every script rerun.

    Returns:
        The network, with the snapshot's state dict loaded and switched to
        evaluation mode.
    """
    dce_net = model.enhance_net_nopool(scale_factor)
    state_dict = torch.load(
        "lowlight-dce-snapshot.pth", map_location=torch.device("cpu")
    )
    dce_net.load_state_dict(state_dict)
    # Inference only: disable training-mode behaviour (dropout / batch-norm
    # statistics updates), if the architecture has any.
    dce_net.eval()
    return dce_net
def fix_lowlight(image: Image.Image) -> Image.Image:
    """Enhance a low-light image with the cached enhancement network.

    Args:
        image: Input picture. It is converted to RGB first, so grayscale,
            palette or RGBA inputs still yield the 3-channel array the
            network is fed.

    Returns:
        The enhanced image as an RGB PIL image. Height and width are cropped
        down to a multiple of ``scale_factor``.
    """
    dce_net = load_model()

    # H x W x 3 uint8 -> float32 in [0, 1].
    pixels = np.asarray(image.convert("RGB")) / 255.0
    tensor = torch.from_numpy(pixels).float()

    h = (tensor.shape[0] // scale_factor) * scale_factor
    w = (tensor.shape[1] // scale_factor) * scale_factor
    tensor = tensor[:h, :w, :]

    # HWC -> CHW, then add a batch dimension of one.
    batch = tensor.permute(2, 0, 1).unsqueeze(0)

    # Pure inference: skip building an autograd graph (saves memory on large
    # uploads and yields a tensor that can be converted straight to an image).
    with torch.no_grad():
        enhanced_image, _ = dce_net(batch)

    # ToPILImage multiplies float tensors by 255 and casts to uint8, so any
    # value slightly outside [0, 1] would wrap around; clamp first.
    enhanced = enhanced_image[0].clamp(0.0, 1.0)
    return transforms.ToPILImage()(enhanced).convert("RGB")
def main():
    """Render the app: upload a low-light image and show it beside its enhanced version."""
    st.title("Lowlight Enhancement")
    st.write("This is a simple lowlight enhancement app with great performance and does not require paired images to train.")
    st.write("The model runs at 1000/11 FPS on single GPU/CPU on images with a size of 1200*900*3")

    uploaded_file = st.file_uploader("Lowlight Image")
    if uploaded_file is None:
        return

    # The upload is untrusted: it may not be an image at all, be truncated, or
    # exceed PIL's decompression-bomb limit. Report that instead of crashing.
    # (UnidentifiedImageError is a subclass of OSError.)
    try:
        data_lowlight = Image.open(uploaded_file).convert('RGB')
    except (OSError, Image.DecompressionBombError):
        st.error("The uploaded file could not be read as an image.")
        return

    col1, col2 = st.columns(2)
    col1.write("Original (Lowlight)")
    col1.image(data_lowlight, caption="Lowlight Image", use_column_width=True)
    col2.write("Enhanced")
    with st.spinner('🧠 Enhancing...'):
        fixed_img = fix_lowlight(data_lowlight)
    col2.image(fixed_img, caption="Enhanced Image", use_column_width=True)
# Streamlit executes the script as `__main__`, so the app still renders, while a
# plain import of this module no longer builds the UI as a side effect.
if __name__ == "__main__":
    main()