Geek7 commited on
Commit
f06461c
·
verified ·
1 Parent(s): 71bdf34

Create myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +69 -0
myapp.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS
3
+ import os
4
+ from random import randint
5
+ from all_models import models
6
+ from externalmod import gr_Interface_load
7
+ import asyncio
8
+ from threading import RLock
9
+
10
+ # Initialize Flask app and enable CORS
11
+ app = Flask(__name__)
12
+ CORS(app)
13
+
14
+ lock = RLock()
15
+ HF_TOKEN = os.environ.get("HF_TOKEN")
16
+
17
+ # Load models into a global dictionary
18
+ models_load = {}
19
+
20
+ def load_fn(models):
21
+ global models_load
22
+ models_load = {}
23
+
24
+ for model in models:
25
+ if model not in models_load.keys():
26
+ try:
27
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
28
+ models_load[model] = m
29
+ except Exception as error:
30
+ print(error)
31
+ models_load[model] = None # Handle model loading failures
32
+
33
+ load_fn(models)
34
+
35
+ inference_timeout = 600
36
+
37
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
38
+ kwargs = {"seed": seed}
39
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
40
+ await asyncio.sleep(0)
41
+
42
+ try:
43
+ result = await asyncio.wait_for(task, timeout=timeout)
44
+ if task.done() and result is not None:
45
+ with lock:
46
+ png_path = "image.png"
47
+ result.save(png_path)
48
+ return png_path # Return the path of the saved image
49
+ except Exception as e:
50
+ print(e)
51
+ return None
52
+
53
+ @app.route('/generate', methods=['POST'])
54
+ def generate():
55
+ data = request.json
56
+ model_str = data.get('model')
57
+ prompt = data.get('prompt')
58
+ seed = data.get('seed', 1)
59
+
60
+ if model_str not in models_load or models_load[model_str] is None:
61
+ return jsonify({"error": "Model not found or not loaded"}), 404
62
+
63
+ image_path = asyncio.run(infer(model_str, prompt, seed, inference_timeout))
64
+ if image_path is not None:
65
+ return send_file(image_path, mimetype='image/png') # Send the generated image file
66
+ return jsonify({"error": "Image generation failed"}), 500
67
+
68
+ if __name__ == '__main__':
69
+ app.run(debug=True)