Geek7 committed on
Commit 3c6792a · verified · 1 Parent(s): 301808b

Update app.py

Files changed (1):
  1. app.py +34 -47
app.py CHANGED
@@ -1,21 +1,14 @@
- from flask import Flask, request, jsonify, send_file
- from flask_cors import CORS
- import os
+ import gradio as gr
+ from random import randint
  from all_models import models
  from externalmod import gr_Interface_load
  import asyncio
+ import os
  from threading import RLock

- # Initialize Flask app and enable CORS
- app = Flask(__name__)
- CORS(app)
-
  lock = RLock()
  HF_TOKEN = os.environ.get("HF_TOKEN")

- # Load models into a global dictionary
- models_load = {}
-
  def load_fn(models):
      global models_load
      models_load = {}
@@ -24,50 +17,44 @@ def load_fn(models):
          if model not in models_load.keys():
              try:
                  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
-                 models_load[model] = m
              except Exception as error:
-                 print(f"Error loading model {model}: {error}")
-                 models_load[model] = None  # Handle model loading failures
+                 print(error)
+                 m = gr.Interface(lambda: None, ['text'], ['image'])
+             models_load.update({model: m})

  load_fn(models)

+ num_models = 6
+ MAX_SEED = 3999999999
+ default_models = models[:num_models]
  inference_timeout = 600

  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
      kwargs = {"seed": seed}
+     task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
+     await asyncio.sleep(0)
      try:
-         task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
-         await asyncio.sleep(0)
-
          result = await asyncio.wait_for(task, timeout=timeout)
-         if task.done() and result is not None:
-             with lock:
-                 png_path = "image.png"
-                 result.save(png_path)
-             return png_path  # Return the path of the saved image
-     except Exception as e:
-         print(f"Inference error for model {model_str}: {e}")  # Log the error message
-         return None
-
- @app.route('/generate', methods=['POST'])
- def generate():
-     data = request.json
-     model_str = data.get('model')
-     prompt = data.get('prompt')
-     seed = data.get('seed', 1)
-
-     print(f"Received request for model: '{model_str}', prompt: '{prompt}', seed: {seed}")
-
-     if model_str not in models_load or models_load[model_str] is None:
-         print(f"Model not found in models_load: {model_str}. Available models: {models_load.keys()}")
-         return jsonify({"error": "Model not found or not loaded"}), 404
-
-     image_path = asyncio.run(infer(model_str, prompt, seed, inference_timeout))
-     if image_path is not None:
-         return send_file(image_path, mimetype='image/png')
-
-     print("Image generation failed for:", model_str)  # Log failure reason
-     return jsonify({"error": "Image generation failed"}), 500
-
- if __name__ == "__main__":
-     app.run(host='0.0.0.0', port=7860)  # Run directly if needed for testing
+     except (Exception, asyncio.TimeoutError) as e:
+         print(e)
+         print(f"Task timed out: {model_str}")
+         if not task.done():
+             task.cancel()
+         result = None
+     if task.done() and result is not None:
+         with lock:
+             png_path = "image.png"
+             result.save(png_path)
+         return png_path
+     return None
+
+ # Expose Gradio API
+ def generate_api(model_str, prompt, seed=1):
+     result = asyncio.run(infer(model_str, prompt, seed))
+     if result:
+         return result  # Path to generated image
+     return None
+
+ # Launch Gradio API without frontend
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
+ iface.launch(show_api=True, share=True)
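
With the Flask /generate route removed, clients now reach the app through the Gradio API that iface.launch(show_api=True, ...) exposes. Below is a minimal client sketch using the gradio_client package, not part of the commit: the app URL is a placeholder, the model id is hypothetical (it should be an entry from all_models.models), and "/predict" is assumed to be the default api_name Gradio assigns to a single gr.Interface.

# Client-side sketch (not part of this commit): call the new Gradio endpoint remotely.
# Assumes the app is reachable at a placeholder URL and that "/predict" is the
# auto-generated api_name for the lone gr.Interface.
from gradio_client import Client

client = Client("https://your-space-url.hf.space")  # placeholder URL

# Arguments follow the Interface inputs: model_str, prompt, seed.
# The model id below is hypothetical; use an entry from all_models.models.
image_path = client.predict(
    "some-user/some-model",   # model_str
    "a watercolor fox",       # prompt
    1,                        # seed
    api_name="/predict",
)
print(image_path)  # local path to the downloaded PNG on success

Since generate_api returns None when inference fails, failures no longer surface as HTTP status codes the way the old Flask route reported them, so callers should treat a missing or empty result as an error.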