abhirajeshbhai's picture
fixed title
031bff8 verified
raw
history blame
649 Bytes
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from model import model, image_transforms
def col_select(value):
    """Write ``value`` to standard output, followed by a newline.

    Nothing in this script references it; it appears to be a leftover
    debugging hook. NOTE(review): confirm it is still needed.
    """
    print(value, end="\n")
# Streamlit page: upload an image, run it through the colorization model and
# show the model input (left column) next to the model output (right column).
st.title("Banana Image Colorizer")
upload_file = st.file_uploader("Upload Image")

# st.file_uploader returns None until the user has uploaded a file.
if upload_file is not None:
    try:
        image = Image.open(upload_file)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot decode; report it on the page instead of crashing the app.
        st.error("Could not read the uploaded file as an image.")
    else:
        # image_transforms comes from model.py; the permute below implies it
        # returns a channel-first (C, H, W) tensor. TODO confirm.
        image_gs = image_transforms(image)
        image_gs_prev = image_gs.permute(1, 2, 0).detach().cpu().numpy()

        # Inference only: no_grad avoids building an autograd graph (and
        # keeping its activations alive) for the forward pass.
        with torch.no_grad():
            output = model(image_gs.unsqueeze(0))
        # Remove only the batch axis added by unsqueeze(0); a bare squeeze()
        # would also collapse any other size-1 axis (e.g. a 1-channel output)
        # and break the permute.
        image_color = output.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

        col1, col2 = st.columns(2)
        # clamp=True on both panels: st.image rejects float data outside
        # [0, 1] unless clamping is enabled.
        col1.image(image_gs_prev, clamp=True)
        col2.image(image_color, clamp=True, channels='RGB')