Geek7 commited on
Commit
3c7030b
·
verified ·
1 Parent(s): 3285b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -41
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  from random import randint
3
  from all_models import models
@@ -5,23 +6,24 @@ from externalmod import gr_Interface_load
5
  import asyncio
6
  import os
7
  from threading import RLock
8
- from flask import send_file
9
 
 
10
 
11
  lock = RLock()
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
 
 
14
  def load_fn(models):
15
  global models_load
16
  models_load = {}
17
-
18
  for model in models:
19
  if model not in models_load.keys():
20
  try:
21
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
22
  except Exception as error:
23
  print(error)
24
- m = gr.Interface(lambda: None, ['text'], ['image'])
25
  models_load.update({model: m})
26
 
27
  load_fn(models)
@@ -31,57 +33,44 @@ MAX_SEED = 3999999999
31
  default_models = models[:num_models]
32
  inference_timeout = 600
33
 
34
- #async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
35
- # kwargs = {"seed": seed}
36
- #task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
37
- # await asyncio.sleep(0)
38
- # try:
39
- # result = await asyncio.wait_for(task, timeout=timeout)
40
- # except (Exception, asyncio.TimeoutError) as e:
41
- # print(e)
42
- # print(f"Task timed out: {model_str}")
43
- # if not task.done():
44
- # task.cancel()
45
- # result = None
46
- # if task.done() and result is not None:
47
- # with lock:
48
- # png_path = "image.png"
49
- # result.save(png_path)
50
- # return png_path
51
- #return None
52
-
53
-
54
-
55
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
56
  kwargs = {"seed": seed}
57
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
58
  await asyncio.sleep(0)
59
-
60
  try:
61
  result = await asyncio.wait_for(task, timeout=timeout)
62
  except (Exception, asyncio.TimeoutError) as e:
63
  print(e)
64
  print(f"Task timed out: {model_str}")
65
- if not task.done():
66
  task.cancel()
67
- return None # Return None if the task fails or times out
68
-
69
  if task.done() and result is not None:
70
  with lock:
71
  png_path = "image.png"
72
- result.save(png_path) # Save the image to a file
73
- return send_file(png_path, mimetype='image/png') # Directly use send_file here
74
-
75
- return None # Return None if no result is generated
76
 
 
 
 
 
 
 
 
77
 
78
- # Expose Gradio API
79
- def generate_api(model_str, prompt, seed=1):
80
- result = asyncio.run(infer(model_str, prompt, seed))
81
- if result:
82
- return result # Path to generated image
83
- return None
 
 
 
84
 
85
- # Launch Gradio API without frontend
86
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
87
- iface.launch(show_api=True, share=True)
 
1
+ from flask import Flask, request, jsonify, send_file
2
  import gradio as gr
3
  from random import randint
4
  from all_models import models
 
6
  import asyncio
7
  import os
8
  from threading import RLock
9
+ import io
10
 
11
+ app = Flask(__name__)
12
 
13
  lock = RLock()
14
  HF_TOKEN = os.environ.get("HF_TOKEN")
15
 
16
+ # Load the models
17
  def load_fn(models):
18
  global models_load
19
  models_load = {}
 
20
  for model in models:
21
  if model not in models_load.keys():
22
  try:
23
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
24
  except Exception as error:
25
  print(error)
26
+ m = gr.Interface(lambda: None, ['text'], ['image']) # Fallback
27
  models_load.update({model: m})
28
 
29
  load_fn(models)
 
33
  default_models = models[:num_models]
34
  inference_timeout = 600
35
 
36
+ # Async inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
38
  kwargs = {"seed": seed}
39
  task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
40
  await asyncio.sleep(0)
 
41
  try:
42
  result = await asyncio.wait_for(task, timeout=timeout)
43
  except (Exception, asyncio.TimeoutError) as e:
44
  print(e)
45
  print(f"Task timed out: {model_str}")
46
+ if not task.done():
47
  task.cancel()
48
+ result = None
 
49
  if task.done() and result is not None:
50
  with lock:
51
  png_path = "image.png"
52
+ result.save(png_path)
53
+ return png_path
54
+ return None
 
55
 
56
+ # Flask API endpoint
57
+ @app.route('/async_infer', methods=['POST'])
58
+ def async_infer():
59
+ data = request.get_json()
60
+ model_str = data.get('model_str')
61
+ prompt = data.get('prompt')
62
+ seed = data.get('seed', 1)
63
 
64
+ # Run the inference
65
+ try:
66
+ image_path = asyncio.run(infer(model_str, prompt, seed))
67
+ if image_path:
68
+ return send_file(image_path, mimetype='image/png')
69
+ else:
70
+ return jsonify({"error": "Image generation failed."}), 500
71
+ except Exception as e:
72
+ return jsonify({"error": str(e)}), 500
73
 
74
+ # Run the Flask app
75
+ if __name__ == '__main__':
76
+ app.run(debug=True)