File size: 2,202 Bytes
3c7030b e089372 ac39d6f fbed9e3 e089372 3c7030b 5bc3f97 3c7030b 7151510 ac39d6f 3c7030b ac39d6f 3c7030b ac39d6f fbed9e3 e089372 fbed9e3 3c7030b 5bc3f97 3c7030b 5bc3f97 3c7030b 5bc3f97 3c7030b 5bc3f97 3c7030b 5bc3f97 3c7030b e089372 3c7030b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from flask import Flask, request, jsonify, send_file
import gradio as gr
from random import randint
from all_models import models
from externalmod import gr_Interface_load
import asyncio
import os
from threading import RLock
import io
# Optional token handed to model loading and inference calls; None if unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# WSGI application object that serves the HTTP endpoint(s) below.
app = Flask(__name__)
# Module-wide re-entrant lock.
lock = RLock()
# Load the models
def load_fn(models):
    """Populate the global ``models_load`` dict: model name -> gradio Interface.

    Each name in *models* is loaded with ``gr_Interface_load`` from
    ``models/<name>``. If loading raises, the error is printed and a
    placeholder ``gr.Interface`` (whose function takes no arguments and
    returns None) is stored instead, so every requested name gets an entry.
    Names that appear more than once are loaded only once. Any previous
    contents of ``models_load`` are discarded.
    """
    global models_load
    models_load = {}
    for name in models:
        if name in models_load:
            continue  # duplicate entry in *models*; keep the first load
        try:
            interface = gr_Interface_load(f'models/{name}', hf_token=HF_TOKEN)
        except Exception as error:
            # Best effort: one failing model must not abort the whole load.
            print(error)
            interface = gr.Interface(lambda: None, ['text'], ['image'])  # Fallback
        models_load[name] = interface
# Populate models_load once, at import time.
load_fn(models)

# How many entries from the front of `models` make up the default selection.
num_models = 6
default_models = models[:num_models]
# Presumably the upper bound for generated seeds; not referenced in the code shown.
MAX_SEED = 3_999_999_999
# Default per-call wait, in seconds, used by infer().
inference_timeout = 600
# Async inference function
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run ``models_load[model_str]`` on *prompt* in a worker thread.

    Args:
        model_str: Key into ``models_load``. An unknown key raises KeyError
            (it is not caught here).
        prompt: Text prompt passed to the model's ``fn`` as ``prompt=``.
        seed: Passed to the model's ``fn`` as ``seed=``.
        timeout: Seconds to wait for the model before giving up.

    Returns:
        Path to a newly created ``.png`` file holding the result, or None if
        the call timed out, raised, or returned None. Every call writes its
        own file, so concurrent requests cannot overwrite each other's
        output. The caller owns the returned file.

    Raises:
        KeyError: *model_str* is not a loaded model.
        Exception: anything raised while saving the result to disk.
    """
    import tempfile  # only needed here

    fn = models_load[model_str].fn
    task = asyncio.create_task(
        asyncio.to_thread(fn, prompt=prompt, seed=seed, token=HF_TOKEN)
    )
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        # wait_for cancels the awaiting task, but the worker thread itself
        # cannot be interrupted and keeps running in the background.
        print(f"Task timed out: {model_str}")
        return None
    except Exception as e:
        print(e)
        print(f"Task failed: {model_str}")
        return None
    if result is None:
        return None

    # One unique file per call: a shared fixed filename would let a
    # concurrent request overwrite this image before the caller sends it.
    fd, png_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        # assumes result exposes a PIL-style .save(path) — TODO confirm
        # against what the loaded interface's fn returns
        result.save(png_path)
    except Exception:
        os.remove(png_path)  # don't leave an empty temp file behind
        raise
    return png_path
# Flask API endpoint
@app.route('/async_infer', methods=['POST'])
def async_infer():
    """Generate an image from a JSON request and return it as PNG.

    Expects a JSON object with ``model_str`` (required; must name a loaded
    model), ``prompt`` (required) and optional ``seed`` (default 1).

    Returns:
        The PNG file on success; ``{"error": ...}`` with status 400 for a
        malformed request or unknown model, or status 500 when generation
        fails.
    """
    # silent=True: a missing or invalid JSON body yields None instead of
    # raising, so the client gets a JSON 400 rather than a bare error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    model_str = data.get('model_str')
    prompt = data.get('prompt')
    seed = data.get('seed', 1)
    # isinstance check first: an unhashable JSON value (list/object) would
    # make the dict membership test itself raise TypeError.
    if not isinstance(model_str, str) or model_str not in models_load:
        return jsonify({"error": f"Unknown model: {model_str}"}), 400
    if prompt is None:
        return jsonify({"error": "Missing 'prompt'."}), 400
    # Run the inference
    try:
        image_path = asyncio.run(infer(model_str, prompt, seed))
        if image_path:
            return send_file(image_path, mimetype='image/png')
        return jsonify({"error": "Image generation failed."}), 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Run the Flask app
if __name__ == '__main__':
    # debug=True turns on Flask's reloader and interactive debugger; the
    # debugger can execute arbitrary code, so keep this server on localhost.
    app.run(debug=True)