Geek7 committed on
Commit
e8896f0
·
verified ·
1 Parent(s): 0968c17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -6
app.py CHANGED
@@ -1,9 +1,73 @@
 
 
1
  import os
2
- import subprocess
 
 
 
3
 
4
- if __name__ == "__main__":
5
- # Run awake.py in the background
6
- subprocess.Popen(["python", "wk.py"]) # Start awake.py
7
 
8
- # Run the Flask app using Gunicorn
9
- os.system("gunicorn -w 4 -b 0.0.0.0:7860 myapp:myapp") # 4 worker processes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
from all_models import models
from externalmod import gr_Interface_load
import asyncio
from threading import RLock

# Initialize Flask app and enable CORS
app = Flask(__name__)
CORS(app)

# Serializes writes to the shared output image file (see infer()).
lock = RLock()
# Hugging Face API token; None when the env var is unset — loading/inference
# then proceeds unauthenticated. NOTE(review): confirm the models allow that.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load models into a global dictionary: model name -> loaded interface, or
# None for models that failed to load (populated by load_fn below).
models_load = {}
18
+
19
def load_fn(models):
    """Load every model in *models* and publish the result in ``models_load``.

    Each entry maps the model name to the interface returned by
    ``gr_Interface_load``, or to ``None`` when loading failed, so the
    ``/generate`` route can answer "not loaded" instead of crashing on a
    missing key.

    Args:
        models: iterable of model names (repo ids under ``models/``).
    """
    global models_load
    loaded = {}
    for model in models:
        # Skip duplicates in the input list.
        if model in loaded:
            continue
        try:
            loaded[model] = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
        except Exception as error:
            # One bad model must not abort the rest; record the failure.
            print(f"Error loading model {model}: {error}")
            loaded[model] = None  # Handle model loading failures
    # Swap the finished dict in atomically rather than mutating the global
    # incrementally while requests may be reading it.
    models_load = loaded
31
+
32
# Populate models_load at import time from the list in all_models.
load_fn(models)

# Maximum number of seconds to wait for a single image generation (infer()).
inference_timeout = 600
35
+
36
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Generate one image with the given model and save it as a PNG.

    Args:
        model_str: key into ``models_load`` (must be loaded and non-None).
        prompt: text prompt forwarded to the model.
        seed: generation seed forwarded to the model.
        timeout: max seconds to wait before giving up.

    Returns:
        Path of the saved PNG on success, ``None`` on timeout or error.
    """
    kwargs = {"seed": seed}
    try:
        # The model call is blocking, so run it in a worker thread and bound
        # the wait; wait_for cancels the pending call on timeout.
        result = await asyncio.wait_for(
            asyncio.to_thread(
                models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        print(f"Inference timed out after {timeout}s for model {model_str}")
        return None
    except Exception as e:
        print(f"Inference error for model {model_str}: {e}")  # Log the error message
        return None
    if result is None:
        return None
    # NOTE(review): a single shared file name means concurrent requests can
    # overwrite each other's output before send_file reads it; the lock only
    # serializes the save within this process. Consider per-request temp files.
    with lock:
        png_path = "image.png"
        result.save(png_path)
    return png_path  # Return the path of the saved image
51
+
52
@app.route('/generate', methods=['POST'])
def generate():
    """POST /generate — run a model on a prompt and return the PNG.

    Expects a JSON body with ``model``, ``prompt`` and optional ``seed``
    (default 1). Responds with the image on success, 400 for a missing or
    non-JSON body, 404 for an unknown/unloaded model, 500 on generation
    failure.
    """
    # request.json is None for a missing or invalid JSON body, which would
    # make data.get(...) raise AttributeError and surface as an opaque 500;
    # parse leniently and answer 400 instead.
    data = request.get_json(silent=True)
    if data is None:
        return jsonify({"error": "Request body must be JSON"}), 400

    model_str = data.get('model')
    prompt = data.get('prompt')
    seed = data.get('seed', 1)

    print(f"Received request for model: '{model_str}', prompt: '{prompt}', seed: {seed}")

    if model_str not in models_load or models_load[model_str] is None:
        print(f"Model not found in models_load: {model_str}. Available models: {models_load.keys()}")
        return jsonify({"error": "Model not found or not loaded"}), 404

    # asyncio.run creates a fresh event loop per request — acceptable for a
    # low-traffic service, but a shared loop would scale better.
    image_path = asyncio.run(infer(model_str, prompt, seed, inference_timeout))
    if image_path is not None:
        return send_file(image_path, mimetype='image/png')

    print("Image generation failed for:", model_str)  # Log failure reason
    return jsonify({"error": "Image generation failed"}), 500
71
+
72
if __name__ == '__main__':
    # Development entry point only. NOTE(review): debug=True enables the
    # Werkzeug debugger/reloader and must not be used in production — deploy
    # behind a WSGI server instead.
    debug_mode = True
    app.run(debug=debug_mode)