# Tattoo / app.py
# Copied from a web file viewer: author avatar "Vijish's picture",
# commit faaa6c1 ("Create new file"), raw / history / blame links, 2.69 kB.
import gradio as gr
from PIL import Image
import requests
import numpy as np
import urllib.request
from urllib.request import urlretrieve
import PIL.Image
import torchvision.transforms as T
import fastai
from fastai.vision import *
from fastai.utils.mem import *
class FeatureLoss(nn.Module):
    """Loss combining a pixel term, per-layer feature terms and Gram-matrix terms.

    Activations are captured with forward hooks on selected children of
    ``m_feat`` and compared between the prediction and the target.

    NOTE(review): this class is not referenced elsewhere in this file; it is
    presumably defined so that ``load_learner`` can unpickle ``t10T.pkl``
    (which likely stores an instance as the learner's loss function) — confirm.
    Keep attribute names unchanged so previously pickled state still matches.

    NOTE(review): ``base_loss`` and ``gram_matrix`` are not defined anywhere in
    this file, so ``forward`` would raise ``NameError`` if it were ever called
    (inference via ``learn.predict`` does not appear to need it).
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        """Hook the layers of ``m_feat`` selected by ``layer_ids``.

        Args:
            m_feat: Feature-extractor module; it is indexed with ``m_feat[i]``,
                so it must support integer indexing (e.g. ``nn.Sequential``).
            layer_ids: Indices into ``m_feat`` whose outputs are compared.
            layer_wgts: Per-layer weights, paired positionally with
                ``layer_ids`` (``zip`` truncates to the shorter of the two).
        """
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False keeps the hooked activations in the autograd graph so
        # the feature terms can be back-propagated.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # Names line up with the order of ``self.feat_losses`` built in forward():
        # one pixel term, then one feature term and one Gram term per layer.
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(len(layer_ids))]
                             + [f'gram_{i}' for i in range(len(layer_ids))])

    def make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked activations.

        The module's own output is discarded; only the activations captured by
        the hooks are returned, optionally cloned.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed loss between ``input`` and ``target``.

        Side effects: stores the individual terms in ``self.feat_losses`` and a
        name -> term mapping in ``self.metrics``.
        """
        # Target activations are cloned before the second pass through m_feat,
        # which refreshes ``hooks.stored`` with the input's activations.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # Gram terms use the squared layer weight and a fixed 5e3 scale factor.
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        """Remove the forward hooks when the loss object is finalized."""
        # ``hooks`` is missing if __init__ raised before assigning it (e.g. an
        # out-of-range layer id); don't raise a second error from the finalizer.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()
# Pretrained learner exported as a pickle and hosted on Dropbox; it is fetched
# into the working directory every time this module runs.
MODEL_URL = "https://www.dropbox.com/s/rz9nt35um1agf5y/t10T.pkl?dl=1"
urlretrieve(MODEL_URL, "t10T.pkl")

# SECURITY NOTE(review): load_learner deserializes t10T.pkl with pickle, which
# can execute arbitrary code. The file comes from a remote URL, so whoever
# controls what that URL serves can run code here — consider pinning a checksum.
path = Path(".")
learn = load_learner(path, 't10T.pkl')

# Example inputs for the Gradio interface, as (source URL, local filename)
# pairs. They are downloaded in this order and then listed one per row.
_SAMPLE_SOURCES = (
    ("https://s.hdnux.com/photos/01/07/33/71/18726490/5/1200x0.jpg", "soccer1.jpg"),
    ("https://media.okmagazine.com/brand-img/IEPXUdkY7/0x0/2015/06/celebrity-tattoos-16-splash.jpg", "soccer2.jpg"),
    ("https://newsmeter.in/wp-content/uploads/2020/06/Ajay-Devgn-Tattoo.jpg", "baseball.jpg"),
    ("https://www.allkpop.com/upload/2022/08/content/071400/1659895247-tattoozico.jpg", "baseball2.jpeg"),
)
for _url, _filename in _SAMPLE_SOURCES:
    urlretrieve(_url, _filename)

# One single-element row per example, since the interface has one input.
sample_images = [[filename] for _, filename in _SAMPLE_SOURCES]
def predict(input):
    """Run the loaded learner on one PIL image and return a PIL image.

    Args:
        input: PIL image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user submits without an image.

    Returns:
        The model output as a PIL image resized to the input's (width, height),
        or ``None`` when no image was supplied.
    """
    # An empty image component arrives as None; return an empty output
    # instead of failing on ``None.size``.
    if input is None:
        return None
    # Force 3 channels so RGBA / grayscale / palette uploads don't produce a
    # tensor with the wrong channel count.
    # NOTE(review): assumes the model expects RGB input — confirm.
    rgb = input.convert("RGB")
    size = rgb.size  # (width, height), restored on the output below
    # ``Image`` here is presumably fastai's Image class, which the
    # ``from fastai.vision import *`` line rebinds over the PIL import (a PIL
    # module is not callable). ToTensor yields a CHW float tensor in [0, 1].
    img_fast = Image(T.ToTensor()(rgb))
    # learn.predict returns a 3-tuple; only the middle element is used.
    _, img_hr, _ = learn.predict(img_fast)
    # Scale to 0-255, clamp, and convert to uint8 for PIL.
    pixels = np.clip(image2np(img_hr.data * 255), 0, 255).astype(np.uint8)
    return PIL.Image.fromarray(pixels).resize(size)
# Build the Gradio UI around ``predict`` and start serving it.
# The Interface is created and bound first, then launched in a separate
# statement: ``launch()`` does not return the Interface object, so assigning
# its result would leave ``gr_interface`` holding server details instead.
gr_interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title='Skin-Deep',
    examples=sample_images,
)
gr_interface.launch()