from flask import Flask, request, jsonify, send_file
import gradio as gr
from all_models import models
from externalmod import gr_Interface_load
import asyncio
import os
from threading import RLock

app = Flask(__name__)

lock = RLock()
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load each model once; fall back to a stub interface if loading fails
def load_fn(models):
    global models_load
    models_load = {}
    for model in models:
        if model not in models_load:
            try:
                m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            except Exception as error:
                print(error)
                # Fallback: stub interface that accepts a prompt and returns no image
                m = gr.Interface(lambda prompt: None, ['text'], ['image'])
            models_load[model] = m

load_fn(models)

num_models = 6
MAX_SEED = 3999999999
default_models = models[:num_models]
inference_timeout = 600

# Run one model's inference in a worker thread, enforcing a timeout
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    kwargs = {"seed": seed}
    task = asyncio.create_task(
        asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN)
    )
    await asyncio.sleep(0)  # yield control so the task gets scheduled
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"Task timed out: {model_str}")
        if not task.done():
            task.cancel()
        result = None
    except Exception as e:
        print(f"Inference failed for {model_str}: {e}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        # Serialize writes to the shared output file
        with lock:
            png_path = "image.png"
            result.save(png_path)
        return png_path
    return None

# Flask API endpoint: accepts a JSON body and returns the generated PNG
@app.route('/async_infer', methods=['POST'])
def async_infer():
    data = request.get_json(silent=True) or {}
    model_str = data.get('model_str')
    prompt = data.get('prompt')
    seed = data.get('seed', 1)

    if model_str not in models_load:
        return jsonify({"error": f"Unknown model: {model_str}"}), 400
    if not prompt:
        return jsonify({"error": "Missing 'prompt'."}), 400

    # Run the inference
    try:
        image_path = asyncio.run(infer(model_str, prompt, seed))
        if image_path:
            return send_file(image_path, mimetype='image/png')
        else:
            return jsonify({"error": "Image generation failed."}), 500
    except Exception as e:
        return jsonify({"error": str(e)}), 500

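# Example request (a sketch, not part of the app itself; it assumes the server is
# reachable at localhost:5000 -- Flask's default port -- and that "some-model" is
# one of the entries in all_models.models):
#
#   curl -X POST http://localhost:5000/async_infer \
#        -H "Content-Type: application/json" \
#        -d '{"model_str": "some-model", "prompt": "a watercolor fox", "seed": 42}' \
#        --output output.png
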
# Run the Flask app
if __name__ == '__main__':
    app.run(debug=True)