Spaces:
Runtime error
Runtime error
File size: 4,365 Bytes
f787272 c3784e5 cace7f9 c3784e5 f787272 c3784e5 f787272 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from typing import Tuple
import cv2
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
import urllib.request
import PIL.Image
from io import BytesIO
import torchvision.transforms as T
from PIL import Image
import requests
from io import BytesIO
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
import urllib.request
import PIL.Image
from PIL import Image
from io import BytesIO
import torchvision.transforms as T
# ASGI application object; the /predict/ endpoint, the "/" route and the static
# file mount further down in this module are all registered on it.
app = FastAPI()
class FeatureLoss(nn.Module):
    """Loss combining a pixel term with per-layer feature and Gram-matrix terms.

    Activations of selected layers of ``m_feat`` are captured with fastai
    hooks and compared between the prediction and the target.

    NOTE(review): nothing in this module instantiates or calls this class.
    Presumably it is defined here so that ``load_learner`` can unpickle the
    saved learner, whose loss function would be an instance of it; verify
    against the training code. Unpickling restores instance attributes by
    name, so the attribute names below likely must stay as they are.

    NOTE(review): ``base_loss`` and ``gram_matrix`` are not defined in the
    visible code (unless one of the star imports provides them). ``forward``
    may therefore raise NameError if it is ever called from this module.
    """
    def __init__(self, m_feat, layer_ids, layer_wgts):
        """Hook the layers ``m_feat[i]`` for each ``i`` in *layer_ids*.

        Args:
            m_feat: Indexable network whose layers are hooked (presumably a
                pretrained feature extractor; TODO confirm).
            layer_ids: Indices into ``m_feat`` of the layers to compare.
            layer_wgts: Per-layer weights. They are paired with the hooked
                layers by ``zip``, so extra entries on either side are ignored.
        """
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: presumably keeps the stored activations attached to the
        # autograd graph so the feature terms can be back-propagated.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # One name per entry of self.feat_losses, in the same order:
        # 'pixel', then feat_0..feat_{k-1}, then gram_0..gram_{k-1}.
        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))
              ] + [f'gram_{i}' for i in range(len(layer_ids))]
    def make_features(self, x, clone=False):
        """Run *x* through ``m_feat`` and return the hooked activations.

        The forward pass is executed only for its side effect of filling
        ``self.hooks.stored``. With *clone* the activations are copied, so a
        later call (which runs ``m_feat`` again) cannot change them.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]
    def forward(self, input, target):
        """Return the summed loss between *input* and *target*.

        The loss has three groups of terms:
        - ``base_loss(input, target)``;
        - ``base_loss`` on each pair of hooked activations, times its weight ``w``;
        - ``base_loss`` on their Gram matrices, times ``w**2 * 5e3``.

        The individual terms are kept in ``self.feat_losses``. They are also
        stored, keyed by ``self.metric_names``, in ``self.metrics``.
        """
        # Target first and cloned: the next make_features(input) call re-runs
        # m_feat and refreshes the hooked outputs.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input,target)]
        self.feat_losses += [base_loss(f_in, f_out)*w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out))*w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
    # Remove the hooks from m_feat's layers when this object is finalized.
    def __del__(self): self.hooks.remove()
def add_margin(pil_img, top, right, bottom, left, color):
    """Pad *pil_img* with a solid-colour border and return the larger image.

    Args:
        pil_img: Source PIL image. Its mode is reused for the new canvas.
        top, right, bottom, left: Border thickness in pixels on each side.
        color: Fill colour of the border, in any form ``Image.new`` accepts
            for ``pil_img.mode``.

    Returns:
        A new image of size ``(left + width + right, top + height + bottom)``
        with *pil_img* pasted at offset ``(left, top)``. The input is not modified.
    """
    src_w, src_h = pil_img.size
    canvas_size = (left + src_w + right, top + src_h + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas
# Pretrained ArtLine learner. It is fetched from Dropbox the first time the app
# starts and then loaded from the working directory with fastai's load_learner.
MODEL_URL = "https://www.dropbox.com/s/04suaimdpru76h3/ArtLine_920.pkl?dl=1"
MODEL_FILE = "ArtLine_920.pkl"
path = Path(".")
_model_path = path / MODEL_FILE
if not _model_path.exists():
    # Download to a side file and rename it only once complete. An interrupted
    # download therefore never leaves a truncated .pkl that a later start
    # would mistake for the model.
    _partial_path = _model_path.with_name(MODEL_FILE + ".part")
    urllib.request.urlretrieve(MODEL_URL, str(_partial_path))
    _partial_path.replace(_model_path)
print(os.listdir('.'))
learn = load_learner(path, MODEL_FILE)
import gradio as gr
import cv2
def get_filename(prefix="sketch"):
    """Return a timestamped JPEG filename such as ``sketch__2024-01-31 13:05:09.jpg``.

    The timestamp is the current time in the Asia/Seoul time zone, at
    one-second resolution. Two calls within the same second therefore return
    the same name.

    Args:
        prefix: Leading part of the filename. Previously this argument was
            ignored and ``"sketch"`` was always used; the default keeps that
            output unchanged.
    """
    from datetime import datetime
    from pytz import timezone
    # Double any '%' in the prefix so strftime copies it literally.
    name_format = prefix.replace('%', '%%') + '__%Y-%m-%d %H:%M:%S.jpg'
    return datetime.now(timezone('Asia/Seoul')).strftime(name_format)
def predict(img):
    """Turn an RGB image array into an ArtLine sketch (Gradio-style handler).

    NOTE(review): the ``@app.post("/predict/")`` handler defined later in this
    module uses the same name. Once the module has loaded, ``predict`` refers
    to that handler, not to this function. ``gr`` is imported, but no Gradio
    interface is built in the visible code.

    Args:
        img: Presumably an H x W x 3 uint8 RGB array, as a Gradio image input
            delivers (TODO confirm). ``shape[:-1]`` below requires a channel axis.

    Returns:
        ``(res, output_file)``:
        - ``res`` is the float sketch: the first output channel divided by the
          output maximum, resized to the padded input size.
        - ``output_file`` is the JPEG the sketch was also saved to.
    """
    import os
    import tempfile

    padded = add_margin(PIL.Image.fromarray(img), 250, 250, 250, 250, (255, 255, 255))
    padded = np.array(padded)
    h, w = padded.shape[:-1]

    # open_image() takes a path, so round-trip through a private temporary
    # file rather than a shared "test.jpg" in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "test.jpg")
        # cv2.imwrite expects BGR. Writing the RGB array directly swapped the
        # red and blue channels seen by the model.
        cv2.imwrite(tmp_path, cv2.cvtColor(padded, cv2.COLOR_RGB2BGR))
        img_test = open_image(tmp_path)

    _, img_hr, _ = learn.predict(img_test)
    res = (img_hr / img_hr.max()).numpy()
    res = res[0]  # take only first channel as result
    res = cv2.resize(res, (w, h))
    output_file = get_filename()
    # Dividing by the maximum bounds values above by 1 but not below by 0.
    # Clip so a negative value cannot wrap around to a bright uint8 pixel.
    sketch_u8 = (np.clip(res, 0, 1) * 255).astype(np.uint8)
    cv2.imwrite(output_file, sketch_u8, [cv2.IMWRITE_JPEG_QUALITY, 50])
    return res, output_file
@app.post("/predict/")
async def predict(file: UploadFile = File(...)) -> Tuple[str, str]:
    """Convert an uploaded image into an ArtLine sketch.

    Processing steps:
    1. Decode the upload and pad it with a 250 px white border.
    2. Run it through ``learn``.
    3. Take the first output channel, normalise it by its maximum, and encode
       it as a JPEG (quality 50).
    4. Write the JPEG to the working directory under a name from
       :func:`get_filename`.

    Returns:
        ``(output_file, jpeg_base64)``: the written filename and the same JPEG
        bytes, base64-encoded. Previously the raw float32 buffer was returned
        as ``bytes``. FastAPI decodes ``bytes`` as UTF-8 when building the JSON
        response, which fails for that data, so the endpoint could never
        respond successfully.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an image.
    """
    import base64
    import os
    import tempfile
    from fastapi import HTTPException

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    bgr = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image.")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    padded = np.array(add_margin(PIL.Image.fromarray(rgb), 250, 250, 250, 250, (255, 255, 255)))
    h, w = padded.shape[:-1]

    # open_image() takes a path, so round-trip through a private temporary
    # file rather than a shared "test.jpg" in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "test.jpg")
        # cv2.imwrite expects BGR. Writing the RGB array directly swapped the
        # red and blue channels seen by the model.
        cv2.imwrite(tmp_path, cv2.cvtColor(padded, cv2.COLOR_RGB2BGR))
        img_test = open_image(tmp_path)

    _, img_hr, _ = learn.predict(img_test)
    res = (img_hr / img_hr.max()).numpy()
    res = res[0]  # take only first channel as result
    res = cv2.resize(res, (w, h))

    # Dividing by the maximum bounds values above by 1 but not below by 0.
    # Clip so a negative value cannot wrap around to a bright uint8 pixel.
    sketch_u8 = (np.clip(res, 0, 1) * 255).astype(np.uint8)
    ok, jpeg = cv2.imencode(".jpg", sketch_u8, [cv2.IMWRITE_JPEG_QUALITY, 50])
    if not ok:
        raise RuntimeError("JPEG encoding of the sketch failed")
    jpeg_bytes = jpeg.tobytes()

    output_file = get_filename()
    with open(output_file, "wb") as out:
        out.write(jpeg_bytes)
    return output_file, base64.b64encode(jpeg_bytes).decode("ascii")
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` at the site root."""
    # Use the same working-directory-relative "static" directory that the
    # StaticFiles mount below serves, instead of a hard-coded /app prefix.
    # The mount already refuses to start if "static" is missing, so this
    # path is at least as reliable.
    return FileResponse(path="static/index.html", media_type="text/html")


# Registered last on purpose. A Mount at "/" matches every path, and Starlette
# dispatches to the first matching route. When the mount came first, the
# index route above was unreachable.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|