Ashrafb commited on
Commit
f787272
1 Parent(s): 009e200

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -7
app.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import fastai
2
  from fastai.vision import *
3
  from fastai.utils.mem import *
@@ -21,6 +27,8 @@ from PIL import Image
21
  from io import BytesIO
22
  import torchvision.transforms as T
23
 
 
 
24
  class FeatureLoss(nn.Module):
25
  def __init__(self, m_feat, layer_ids, layer_wgts):
26
  super().__init__()
@@ -94,10 +102,33 @@ def predict(img):
94
 
95
  return res, output_file
96
 
97
- gr.Interface(predict,
98
- inputs="image",
99
- outputs=[gr.Image(label="Sketch Image",show_share_button=False), gr.File(label="Result File")],
100
- title = "<span style='color: crimson;'>Aiconvert.online</span>",
101
- css="footer{display:none !important;}",
102
- theme=gr.themes.Base(),
103
- description="").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import FileResponse
3
+ from fastapi.responses import HTMLResponse, FileResponse
4
+ from fastapi.staticfiles import StaticFiles
5
+ from typing import Tuple
6
+ import cv2
7
  import fastai
8
  from fastai.vision import *
9
  from fastai.utils.mem import *
 
27
  from io import BytesIO
28
  import torchvision.transforms as T
29
 
30
+ app = FastAPI()
31
+
32
  class FeatureLoss(nn.Module):
33
  def __init__(self, m_feat, layer_ids, layer_wgts):
34
  super().__init__()
 
102
 
103
  return res, output_file
104
 
105
+ @app.post("/predict/")
106
+ async def predict(file: UploadFile = File(...)) -> Tuple[str, bytes]:
107
+ contents = await file.read()
108
+ img = cv2.imdecode(np.fromstring(contents, np.uint8), cv2.IMREAD_COLOR)
109
+ img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
110
+ img = add_margin(img, 250, 250, 250, 250, (255, 255, 255))
111
+ img = np.array(img)
112
+
113
+ h, w = img.shape[:-1]
114
+ cv2.imwrite("test.jpg", img)
115
+ img_test = open_image("test.jpg")
116
+
117
+ p,img_hr,b = learn.predict(img_test)
118
+
119
+ res = (img_hr / img_hr.max()).numpy()
120
+ res = res[0] # take only first channel as result
121
+ res = cv2.resize(res, (w,h))
122
+
123
+ output_file = get_filename()
124
+
125
+ cv2.imwrite(output_file, (res * 255).astype(np.uint8), [cv2.IMWRITE_JPEG_QUALITY, 50])
126
+
127
+ return output_file, res.tobytes()
128
+
129
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
130
+
131
+ @app.get("/")
132
+ def index() -> FileResponse:
133
+ return FileResponse(path="/app/static/index.html", media_type="text/html")
134
+