# Source page: Hugging Face Space (status shown on the page: "Runtime error").
# File size: 2,212 bytes. Per-line blame annotations referenced commits
# 035d31c, fe70fd4 and ca17fe1.
import imp
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import cv2
import numpy as np
import torch
import torch.nn as nnst
import torchvision.transforms.functional as TF
from torchvision import transforms
from model import DoubleConv,UNET
# Shared ToTensor transform used by predict(): HxWxC uint8 -> CxHxW float.
convert_tensor = transforms.ToTensor()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(path="Unet_acc_94.pth"):
    """Load the segmentation network from *path* onto the CPU, in eval mode.

    The checkpoint is first unpickled as-is. If it holds a complete
    ``nn.Module`` it is used directly (presumably why ``UNET``/``DoubleConv``
    are imported at the top of the file: unpickling a full module needs the
    class definitions importable). If it instead holds a ``state_dict`` (a
    dict), the weights are loaded into a freshly built ``UNET``.

    The model is kept on the CPU because ``predict`` feeds it CPU tensors.

    Args:
        path: filesystem path of the checkpoint file.

    Returns:
        The network, switched to eval mode.
    """
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle a full nn.Module. Opt out explicitly; this is a local,
        # trusted checkpoint (never point this at untrusted files: pickle).
        loaded = torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        loaded = torch.load(path, map_location="cpu")
    if isinstance(loaded, dict):
        net = UNET(in_channels=3, out_channels=1)
        net.load_state_dict(loaded)
    else:
        net = loaded
    # Inference only: eval() switches layers such as dropout/batch-norm (if
    # the network has any) to their deterministic inference behaviour.
    net.eval()
    return net


model = _load_model()
def predict(img):
    """Segment *img* with the module-level ``model`` and return a binary mask.

    Args:
        img: decoded image array (in this file it comes from
            ``cv2.imdecode(..., 1)``, i.e. BGR colour). It is resized to
            240 wide x 160 high before inference.

    Returns:
        A float32 NumPy array laid out height x width x channels, holding 1.0
        where the sigmoid of the model output exceeds 0.5 and 0.0 elsewhere.
        Presumably shape (160, 240, 1) given ``out_channels=1`` -- assumes the
        UNET preserves spatial size; TODO confirm.
    """
    resized = cv2.resize(img, (240, 160))  # cv2 sizes are (width, height)
    # ToTensor: HxWxC -> CxHxW float; unsqueeze adds the batch dimension.
    batch = convert_tensor(resized).unsqueeze(0).float()
    # Inference only: no_grad stops autograd from recording the forward pass,
    # which would otherwise keep every intermediate activation alive.
    with torch.no_grad():
        logits = model(batch)
    probs = torch.sigmoid(logits)
    mask = (probs > 0.5).float()
    # Drop the batch axis and move channels last for NumPy/OpenCV/Streamlit.
    return mask.squeeze(0).permute(1, 2, 0).cpu().numpy()
def blurr_image(input_image, preds, ksize=(25, 25)):
    """Blur the background of *input_image*, keeping the masked subject sharp.

    Args:
        input_image: colour image in BGR order (e.g. from ``cv2.imdecode``);
            it is resized to 240 wide x 160 high, matching ``predict``.
        preds: single-channel mask as returned by ``predict``; values > 0.1
            mark the subject. Normally it holds exactly 160 * 240 values
            (e.g. shape (160, 240, 1)); other sizes are rescaled spatially.
        ksize: Gaussian kernel size (width, height) for the background blur;
            OpenCV requires both to be positive and odd.

    Returns:
        ``(inp, blurred_img)``: the resized original and the composite with a
        blurred background, both converted to RGB channel order for display.
    """
    size = (240, 160)  # cv2 (width, height)
    height, width = size[1], size[0]
    image = cv2.resize(input_image, size)

    mask = np.asarray(preds, dtype=np.float32)
    if mask.size == height * width:
        # Same number of values: a plain row-major reshape.
        mask = mask.reshape(height, width)
    else:
        # np.resize would silently tile/truncate values; rescale spatially.
        mask = cv2.resize(mask.squeeze(), size, interpolation=cv2.INTER_NEAREST)

    # Boolean subject mask with a trailing axis so it broadcasts over BGR.
    subject = (mask > 0.1)[..., np.newaxis]
    background = cv2.GaussianBlur(image, ksize, 0)
    blurred_img = np.where(subject, image, background)

    blurred_img = cv2.cvtColor(blurred_img, cv2.COLOR_BGR2RGB)
    inp = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return inp, blurred_img
import streamlit as st
# Streamlit UI: upload an image, segment it, and show original / mask / blur.
st.title("AI Portrait Mode")
st.markdown("Creator: [Pranav Kushare](https://github.com/Pranav082001)")

file = st.file_uploader("Please upload the image", type=["jpg", "jpeg", "png"])
check = st.checkbox("Display Mask", value=False)

if file is None:
    st.text("Please Upload an image")
else:
    file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
    # Flag 1 == cv2.IMREAD_COLOR: always decodes to a 3-channel BGR image.
    opencv_image = cv2.imdecode(file_bytes, 1)
    if opencv_image is None:
        # imdecode returns None (rather than raising) on undecodable data.
        st.error("Could not decode the uploaded file as an image.")
    else:
        pred = predict(opencv_image)
        inp_img, blurred = blurr_image(opencv_image, pred)
        st.text("Original")
        st.image(inp_img)
        if check:
            st.text("Mask!!")
            st.image(pred)
        st.text("Blurred")
        st.image(blurred)