File size: 4,508 Bytes
f787272
 
 
 
 
 
c3784e5
 
cace7f9
c3784e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05761be
e72c238
05761be
e72c238
05761be
e72c238
 
 
 
 
05761be
e72c238
 
c3784e5
f787272
c3784e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e19492
c3784e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f787272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import FileResponse
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from typing import Tuple
import cv2
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
import urllib.request
import PIL.Image
from io import BytesIO
import torchvision.transforms as T
from PIL import Image
import requests
from io import BytesIO
import fastai
from fastai.vision import *
from fastai.utils.mem import *
from fastai.vision import open_image, load_learner, image, torch
import numpy as np
import urllib.request
import PIL.Image
from PIL import Image
from io import BytesIO
import torchvision.transforms as T
import requests
import model_loader

# Application object; the route decorators and the static mount further down
# in this file attach to it.
app = FastAPI()

# Download the model if not already downloaded
# MODEL_FILENAME is a relative path, so it is looked up in (and downloaded to)
# the process working directory.
# NOTE(review): `os` is not imported by name in this file's import block; it is
# presumably re-exported by one of the `from fastai... import *` lines - confirm.
MODEL_URL = "https://www.dropbox.com/s/04suaimdpru76h3/ArtLine_920.pkl?dl=1"
MODEL_FILENAME = "ArtLine_920.pkl"
if not os.path.exists(MODEL_FILENAME):
    model_loader.download_model(MODEL_URL, MODEL_FILENAME)

# Load the model
# This runs at import time; `learn` is the module-level object the predict
# functions below call.
# NOTE(review): FeatureLoss is only defined after this statement. If loading the
# pickle needs that class, it has to come from model_loader - confirm.
learn = model_loader.load_model(MODEL_FILENAME)


class FeatureLoss(nn.Module):
    """Loss that adds feature-space and Gram-matrix terms to a pixel loss.

    Outputs of selected layers of ``m_feat`` are captured through hooks every
    time ``m_feat`` is run, and are compared between prediction and target.

    NOTE(review): ``hook_outputs``, ``base_loss`` and ``gram_matrix`` are not
    defined in the visible part of this file; they are presumably supplied by
    the star imports or must be defined before ``forward`` is called - confirm.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        """Hook the chosen layers of ``m_feat``.

        Args:
            m_feat: Indexable feature extractor; ``m_feat[i]`` selects a layer.
            layer_ids: Indices into ``m_feat`` of the layers to hook.
            layer_wgts: One weight per hooked layer, applied in ``forward``.
        """
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        self.hooks = hook_outputs(self.loss_features, detach=False)
        self.wgts = layer_wgts
        # One metric name per loss term, in the order ``forward`` builds them.
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'gram_{i}' for i in range(len(layer_ids))]
        )

    def make_features(self, x, clone=False):
        """Run ``m_feat`` on ``x`` and return the hooked layer outputs.

        Args:
            x: Input passed to ``m_feat``.
            clone: When true, return clones so a later call that refreshes the
                hooks does not overwrite the returned values.

        Returns:
            List with one stored output per hooked layer.
        """
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        """Return the summed pixel, feature and Gram-matrix losses.

        The individual terms are kept in ``self.feat_losses`` and, keyed by
        ``self.metric_names``, in ``self.metrics``.
        """
        # Target features are cloned because the next make_features call
        # reuses the same hooks.
        out_feat = self.make_features(target, clone=True)
        in_feat = self.make_features(input)
        self.feat_losses = [base_loss(input, target)]
        self.feat_losses += [base_loss(f_in, f_out) * w
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [base_loss(gram_matrix(f_in), gram_matrix(f_out)) * w**2 * 5e3
                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        """Remove the hooks, if any were registered."""
        # __del__ also runs for instances whose __init__ never got as far as
        # creating the hooks; there is nothing to remove in that case.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

def add_margin(pil_img, top, right, bottom, left, color):
    """Return a new image with a solid border drawn around ``pil_img``.

    Args:
        pil_img: Source image; its ``size`` and ``mode`` are reused.
        top, right, bottom, left: Border thickness added on each side.
        color: Fill value for the new canvas.

    Returns:
        A new image, larger by the requested margins, with ``pil_img`` pasted
        at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas





import gradio as gr
import cv2


def get_filename(prefix="sketch"):
  """Build a timestamped JPEG file name.

  Args:
    prefix: Text placed in front of the ``__<timestamp>.jpg`` suffix.

  Returns:
    ``"<prefix>__YYYY-MM-DD HH:MM:SS.jpg"``, stamped with the current time in
    Korea Standard Time.
  """
  from datetime import datetime, timedelta, timezone

  # Asia/Seoul is a fixed UTC+09:00 offset with no daylight saving, so the
  # standard library covers it without needing pytz.
  kst = timezone(timedelta(hours=9))
  stamp = datetime.now(kst).strftime('%Y-%m-%d %H:%M:%S')
  # The prefix is joined outside strftime so a '%' in it is kept literally.
  return f'{prefix}__{stamp}.jpg'

def predict(img):
  """Turn an image array into a sketch and save it as a JPEG.

  The input is padded with a 250-pixel white border, written to ``test.jpg``,
  run through the module-level ``learn`` model, and the first channel of the
  normalised prediction is resized back to the padded size.

  NOTE(review): a second module-level ``predict`` (the POST handler) is
  defined later in this file and rebinds the name, so this version cannot be
  reached by name once the module has finished loading.

  Args:
    img: Array accepted by ``PIL.Image.fromarray``.

  Returns:
    Tuple ``(sketch, path)``: the resized prediction array and the name of the
    JPEG file it was written to.
  """
  padded = np.array(
      add_margin(PIL.Image.fromarray(img), 250, 250, 250, 250, (255, 255, 255))
  )
  height, width = padded.shape[:-1]

  # The model input is read back from disk through fastai's open_image.
  cv2.imwrite("test.jpg", padded)
  _, prediction, _ = learn.predict(open_image("test.jpg"))

  normalised = (prediction / prediction.max()).numpy()
  # take only first channel as result
  sketch = cv2.resize(normalised[0], (width, height))

  out_path = get_filename()
  cv2.imwrite(out_path, (sketch * 255).astype(np.uint8), [cv2.IMWRITE_JPEG_QUALITY, 50])

  return sketch, out_path

@app.post("/predict/")
async def predict(file: UploadFile = File(...)) -> Tuple[str, bytes]:
    """Convert an uploaded image into a sketch and save it as a JPEG.

    The upload is decoded, padded with a 250-pixel white border, written to
    ``test.jpg``, run through the module-level ``learn`` model, and the first
    channel of the normalised prediction is resized back to the padded size.

    Args:
        file: Uploaded image file.

    Returns:
        Tuple of the saved JPEG's file name and the raw bytes of the sketch
        array.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as an
            image.
    """
    # Imported here so the handler does not rely on the file-level import list.
    from fastapi import HTTPException

    contents = await file.read()
    if not contents:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # imdecode signals an undecodable buffer by returning None.
        raise HTTPException(status_code=400, detail="Uploaded file is not a decodable image.")

    img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    img = add_margin(img, 250, 250, 250, 250, (255, 255, 255))
    img = np.array(img)

    h, w = img.shape[:-1]
    # NOTE(review): img is RGB at this point while cv2.imwrite treats arrays as
    # BGR, so the file presumably has red and blue swapped - confirm this
    # matches what the model expects before changing it.
    cv2.imwrite("test.jpg", img)
    img_test = open_image("test.jpg")

    _, img_hr, _ = learn.predict(img_test)

    res = (img_hr / img_hr.max()).numpy()
    res = res[0]  # take only first channel as result
    res = cv2.resize(res, (w, h))

    output_file = get_filename()

    cv2.imwrite(output_file, (res * 255).astype(np.uint8), [cv2.IMWRITE_JPEG_QUALITY, 50])

    # NOTE(review): the raw array buffer is generally not valid UTF-8, so
    # encoding this return value as a JSON response will probably fail -
    # confirm what the client expects (e.g. the saved JPEG via FileResponse).
    return output_file, res.tobytes()

# Serve the "static" directory at the site root.
# NOTE(review): this mount is registered before the GET "/" route that follows.
# If routes are matched in registration order, the mount answers "/" first and
# that handler is never reached - confirm before relying on either one.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    """Return the landing page from its fixed absolute path, served as HTML."""
    landing_page = "/app/static/index.html"
    return FileResponse(path=landing_page, media_type="text/html")