Spaces:
Runtime error
Runtime error
File size: 2,090 Bytes
ef75ac1 b84be05 ef75ac1 b84be05 ef75ac1 b84be05 ef75ac1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import os
import threading

import flask
import numpy as np
import torch
from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import PIL_INTERPOLATION, load_image, torch_device

from pkg.util import img_binary_data_to_pil, resizePilToMaxSide, pil_to_base64

app = flask.Flask(__name__, template_folder="./templates/")
# Set FORCE_CPU=1 to run inference on the CPU regardless of the device
# diffusers selected (replaces a hard-coded `if False:` toggle; off by default).
if os.environ.get('FORCE_CPU') == '1':
    torch_device = 'cpu'
print(f'Running inference on {torch_device}')

# Disallow TensorFloat-32 in CUDA matmuls so they run at full fp32 precision.
torch.backends.cuda.matmul.allow_tf32 = False

# Load the pipeline, then place all of its components with one explicit `.to()`.
# `device_map="auto"` is intentionally not passed: it conflicts with the explicit
# placement below (recent diffusers releases reject "auto" as a pipeline-level
# device map and refuse `.to()` on a device-mapped pipeline).
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
ldm.to(torch_device)

# Set DEBUG_PIPELINE=1 to dump component placement and scheduler config at startup
# (replaces a hard-coded `if False:` toggle; off by default).
if os.environ.get('DEBUG_PIPELINE') == '1':
    print(f"{ldm.device=}")
    print(f"{type(ldm.components)=}")
    print(f"{ldm.components.keys()=}")
    print(f"{ldm.components['vqvae'].device=}")
    print(f"{ldm.components['unet'].device=}")
    print(f"{ldm.components['scheduler'].config=}")
ldm.set_progress_bar_config(disable=None)

# One module-level RNG, seeded once at startup and shared by every request.
generator = torch.Generator(device=torch_device).manual_seed(0)
@app.route('/')
def index():
    """Serve the ``index.html`` template (from ``./templates/``) at the site root."""
    print('Route: /')
    page = flask.render_template('index.html')
    return page
# One inference at a time: Flask's built-in server (app.run) is threaded by
# default, while `ldm` and `generator` are single module-level objects shared by
# every request, so overlapping calls would interleave their internal state.
_INFERENCE_LOCK = threading.Lock()


@app.route('/superres', methods=['POST'])
def superres():
    """Super-resolve the image sent as the raw POST body.

    Query parameters:
        maxSideLength (int, default 100): forwarded to ``resizePilToMaxSide``
            before inference (presumably caps the longest side in pixels --
            confirm against pkg.util).
        numIterations (int, default 20): passed to the pipeline as
            ``num_inference_steps``.

    Returns:
        A JSON body ``{isError, message, statusCode, data}``; on success
        ``data`` holds ``pil_to_base64(result)``. An empty body or a
        non-positive parameter yields ``isError=True`` with HTTP 400.

    The route accepts only POST, and Flask itself answers other methods with
    405, so no in-handler method check is needed.
    """
    print('Route: /superres')
    img_binary = flask.request.data
    if not img_binary:
        return flask.jsonify(
            isError=True,
            message='Request body must contain the image bytes.',
            statusCode=400,
        ), 400

    # With `type=int`, a malformed value silently falls back to the default.
    max_side_length = flask.request.args.get('maxSideLength', default=100, type=int)
    num_iterations = flask.request.args.get('numIterations', default=20, type=int)
    if max_side_length <= 0 or num_iterations <= 0:
        return flask.jsonify(
            isError=True,
            message='maxSideLength and numIterations must be positive integers.',
            statusCode=400,
        ), 400

    img = img_binary_data_to_pil(img_binary)
    img = resizePilToMaxSide(img, maxSideLength=max_side_length)
    with _INFERENCE_LOCK:
        result = ldm(
            image=img,
            generator=generator,
            num_inference_steps=num_iterations,
            output_type="pil",
        ).images[0]
    result_binary = pil_to_base64(result)
    return flask.jsonify(
        isError=False,
        message='Success',
        statusCode=200,
        data=result_binary,
    )
if __name__ == '__main__':
    # Direct execution: listen on every interface, on $PORT when set, else 7860.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listen_port)
|