File size: 2,090 Bytes
ef75ac1
 
 
 
 
 
 
 
 
 
 
 
 
b84be05
 
ef75ac1
b84be05
ef75ac1
 
 
b84be05
 
 
 
 
 
 
 
 
ef75ac1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import flask
import os

# Flask application object; Jinja templates are looked up in ./templates/.
app = flask.Flask(import_name=__name__, template_folder="./templates/")

import numpy as np
import torch

from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import PIL_INTERPOLATION, load_image, torch_device

from pkg.util import img_binary_data_to_pil, resizePilToMaxSide, pil_to_base64

# Force CPU inference without editing the file (this used to be a hard-coded
# ``if False:`` toggle). Any non-empty value of FORCE_CPU enables it; when it is
# unset, the device imported from diffusers.utils is used unchanged.
if os.environ.get('FORCE_CPU'):
    torch_device = 'cpu'

print(f'Running inference on {torch_device}')
# Keep CUDA matmuls in full fp32 precision (TF32 disabled).
torch.backends.cuda.matmul.allow_tf32 = False
# NOTE(review): device_map="auto" hands placement to diffusers/accelerate, and
# the pipeline is then moved explicitly with .to(torch_device). Confirm both are
# needed and are compatible with the pinned diffusers version.
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
ldm.to(torch_device)

# Component/device diagnostics. These were previously behind a hard-coded
# ``if False:``; set LDM_DEBUG to any non-empty value to print them.
if os.environ.get('LDM_DEBUG'):
    print(f"{ldm.device=}")
    print(f"{type(ldm.components)=}")
    print(f"{ldm.components.keys()=}")
    print(f"{ldm.components['vqvae'].device=}")
    print(f"{ldm.components['unet'].device=}")
    print(f"{ldm.components['scheduler'].config=}")

ldm.set_progress_bar_config(disable=None)
# One generator, seeded once at startup and shared by every /superres request.
# Its state therefore advances from call to call; presumably only the first
# request after startup reproduces the seed-0 result.
generator = torch.Generator(device=torch_device).manual_seed(0)

@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    print('Route: /')
    page = flask.render_template('index.html')
    return page

def _superres_error(message):
    """Build the JSON error response used by ``/superres``, with HTTP status 400."""
    return flask.jsonify(isError=True, message=message, statusCode=400), 400


@app.route('/superres', methods=['POST'])
def superres():
    """Upscale the image sent as the raw request body.

    The body holds encoded image bytes, which are decoded by
    ``img_binary_data_to_pil``. Optional query parameters:

    * ``maxSideLength`` (int, default 100): the input is first resized with
      ``resizePilToMaxSide`` using this value.
    * ``numIterations`` (int, default 20): passed to the LDM pipeline as
      ``num_inference_steps``.

    Returns a JSON object ``{isError, message, statusCode, data}``, where
    ``data`` is the upscaled image as produced by ``pil_to_base64``.
    Invalid input yields ``isError=True`` with HTTP status 400.
    """
    print('Route: /superres')
    # No method check is needed: ``methods=['POST']`` makes Flask answer any
    # other method with 405 before this function runs.

    imgBinary = flask.request.data
    if not imgBinary:
        return _superres_error('Request body is empty; expected image bytes.')

    # With ``type=int``, a value that is not a valid integer falls back to the
    # default, so only the sign still needs validating.
    maxSideLength = flask.request.args.get('maxSideLength', default=100, type=int)
    numIterations = flask.request.args.get('numIterations', default=20, type=int)
    if maxSideLength <= 0 or numIterations <= 0:
        return _superres_error('maxSideLength and numIterations must be positive integers.')
    # NOTE(review): neither value has an upper bound, so a client can request
    # arbitrarily expensive work on a publicly bound server. Consider capping them.

    try:
        img = img_binary_data_to_pil(imgBinary)
    except Exception as err:  # the decoder's exception types are not visible here
        print(f'Failed to decode image: {err!r}')
        return _superres_error('Request body could not be decoded as an image.')

    img = resizePilToMaxSide(img, maxSideLength=maxSideLength)

    # ``generator`` is the module-level generator seeded once at startup, so
    # its state carries over between requests.
    result = ldm(image=img, generator=generator, num_inference_steps=numIterations, output_type="pil").images[0]
    resultBase64 = pil_to_base64(result)

    return flask.jsonify(
        isError=False,
        message='Success',
        statusCode=200,
        data=resultBase64
    )

if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT and defaults to 7860.
    listen_port = int(os.getenv('PORT', '7860'))
    app.run(host='0.0.0.0', port=listen_port)