Geek7 committed
Commit f47aac7 · verified · 1 Parent(s): 5689272

Update myapp.py

Files changed (1)
  1. myapp.py +43 -39
myapp.py CHANGED
@@ -1,21 +1,19 @@
 from flask import Flask, request, jsonify, send_file
-from flask_cors import CORS
-import os
+import gradio as gr
+from random import randint
 from all_models import models
 from externalmod import gr_Interface_load
 import asyncio
+import os
 from threading import RLock
+from PIL import Image

-# Initialize Flask app and enable CORS
 app = Flask(__name__)
-CORS(app)

 lock = RLock()
 HF_TOKEN = os.environ.get("HF_TOKEN")

-# Load models into a global dictionary
-models_load = {}
-
+# Load models
 def load_fn(models):
     global models_load
     models_load = {}
@@ -24,50 +22,56 @@ def load_fn(models):
         if model not in models_load.keys():
             try:
                 m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
-                models_load[model] = m
             except Exception as error:
-                print(f"Error loading model {model}: {error}")
-                models_load[model] = None # Handle model loading failures
+                print(error)
+                m = gr.Interface(lambda: None, ['text'], ['image'])
+            models_load.update({model: m})

 load_fn(models)

+num_models = 6
+MAX_SEED = 3999999999
+default_models = models[:num_models]
 inference_timeout = 600

+# Gradio inference function
 async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
     kwargs = {"seed": seed}
+    task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
+    await asyncio.sleep(0)
     try:
-        task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
-        await asyncio.sleep(0)
-
         result = await asyncio.wait_for(task, timeout=timeout)
-        if task.done() and result is not None:
-            with lock:
-                png_path = "image.png"
-                result.save(png_path)
-                return png_path # Return the path of the saved image
-    except Exception as e:
-        print(f"Inference error for model {model_str}: {e}") # Log the error message
-        return None
+    except (Exception, asyncio.TimeoutError) as e:
+        print(e)
+        print(f"Task timed out: {model_str}")
+        if not task.done():
+            task.cancel()
+        result = None
+    if task.done() and result is not None:
+        with lock:
+            png_path = "generated_image.png"
+            result.save(png_path) # Save the result as an image
+            return png_path
+    return None

-@app.route('/generate', methods=['POST'])
-def generate():
-    data = request.json
-    model_str = data.get('model')
-    prompt = data.get('prompt')
+# API function to perform inference
+@app.route('/generate-image', methods=['POST'])
+def generate_image():
+    data = request.get_json()
+    model_str = data['model_str']
+    prompt = data['prompt']
     seed = data.get('seed', 1)

-    print(f"Received request for model: '{model_str}', prompt: '{prompt}', seed: {seed}")
-
-    if model_str not in models_load or models_load[model_str] is None:
-        print(f"Model not found in models_load: {model_str}. Available models: {models_load.keys()}")
-        return jsonify({"error": "Model not found or not loaded"}), 404
-
-    image_path = asyncio.run(infer(model_str, prompt, seed, inference_timeout))
-    if image_path is not None:
-        return send_file(image_path, mimetype='image/png')
+    # Run Gradio inference
+    result_path = asyncio.run(infer(model_str, prompt, seed))

-    print("Image generation failed for:", model_str) # Log failure reason
-    return jsonify({"error": "Image generation failed"}), 500
+    if result_path:
+        # Send back the generated image file
+        return send_file(result_path, mimetype='image/png')
+    else:
+        return jsonify({"error": "Failed to generate image."}), 500
+

-if __name__ == '__main__':
-    app.run(debug=True)
+# Add this block to make sure your app runs when called
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=7860) # Run directly
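
For reference, the new /generate-image route added in this commit takes a JSON body with model_str, prompt, and an optional seed, and responds with the generated PNG (or a JSON error and status 500). Below is a minimal client sketch, assuming the app is running locally on the port set above (7860), the third-party requests library is installed, and the placeholder model id is replaced with an entry from all_models.models:

# Hypothetical client call for the /generate-image endpoint (sketch, not part of the commit).
import requests

payload = {
    "model_str": "some-model-id",  # placeholder; must match a key in models_load
    "prompt": "a watercolor fox in a forest",
    "seed": 42,  # optional; the server falls back to 1
}

resp = requests.post("http://localhost:7860/generate-image", json=payload, timeout=600)
if resp.ok:
    with open("output.png", "wb") as f:
        f.write(resp.content)  # save the returned PNG bytes
else:
    print(resp.status_code, resp.json())  # e.g. {"error": "Failed to generate image."}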